Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
32806d5f
Commit
32806d5f
authored
Dec 27, 2023
by
Jun Liu
Browse files
Merge branch 'amd-develop' into amd-master
parents
e70a4d19
d0f355a3
Changes
138
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2251 additions
and
206 deletions
+2251
-206
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp
...impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp
+1
-2
include/ck/tensor_operation/gpu/device/impl/device_normalization_bwd_data_impl.hpp
...on/gpu/device/impl/device_normalization_bwd_data_impl.hpp
+465
-0
include/ck/tensor_operation/gpu/device/impl/device_normalization_bwd_gamma_beta_impl.hpp
.../device/impl/device_normalization_bwd_gamma_beta_impl.hpp
+23
-9
include/ck/tensor_operation/gpu/device/impl/device_normalization_fwd_impl.hpp
...eration/gpu/device/impl/device_normalization_fwd_impl.hpp
+2
-4
include/ck/tensor_operation/gpu/device/impl/device_normalization_fwd_splitk_impl.hpp
.../gpu/device/impl/device_normalization_fwd_splitk_impl.hpp
+2
-2
include/ck/tensor_operation/gpu/device/tensor_layout.hpp
include/ck/tensor_operation/gpu/device/tensor_layout.hpp
+0
-6
include/ck/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp
...k/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp
+2
-0
include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_bwd_data.hpp
...pu/grid/normalization/gridwise_normalization_bwd_data.hpp
+554
-0
include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_bwd_gamma_beta.hpp
...d/normalization/gridwise_normalization_bwd_gamma_beta.hpp
+10
-1
include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp
...eration/operator_transform/transform_conv_fwd_to_gemm.hpp
+7
-8
include/ck/utility/tuple_helper.hpp
include/ck/utility/tuple_helper.hpp
+12
-0
include/ck/wrapper/layout.hpp
include/ck/wrapper/layout.hpp
+159
-150
include/ck/wrapper/tensor.hpp
include/ck/wrapper/tensor.hpp
+314
-0
include/ck/wrapper/utils/layout_utils.hpp
include/ck/wrapper/utils/layout_utils.hpp
+335
-0
include/ck/wrapper/utils/tensor_utils.hpp
include/ck/wrapper/utils/tensor_utils.hpp
+290
-0
library/include/ck/library/reference_tensor_operation/cpu/reference_groupnorm_bwd.hpp
...eference_tensor_operation/cpu/reference_groupnorm_bwd.hpp
+25
-0
library/include/ck/library/reference_tensor_operation/cpu/reference_layernorm_bwd.hpp
...eference_tensor_operation/cpu/reference_layernorm_bwd.hpp
+24
-0
library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp
..._operation_instance/device_operation_instance_factory.hpp
+3
-3
library/include/ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp
..._instance/gpu/contraction/device_contraction_instance.hpp
+20
-4
library/include/ck/library/tensor_operation_instance/gpu/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_instance.hpp
...vice_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_instance.hpp
+3
-17
No files found.
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp
View file @
32806d5f
...
@@ -631,8 +631,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
...
@@ -631,8 +631,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
is_same_v
<
DLayout
,
ctc
::
G_NDHW_K
>
||
is_same_v
<
DLayout
,
ctc
::
GNWK
>
||
is_same_v
<
DLayout
,
ctc
::
G_NDHW_K
>
||
is_same_v
<
DLayout
,
ctc
::
GNWK
>
||
is_same_v
<
DLayout
,
ctc
::
GNHWK
>
||
is_same_v
<
DLayout
,
ctc
::
GNDHWK
>
||
is_same_v
<
DLayout
,
ctc
::
GNHWK
>
||
is_same_v
<
DLayout
,
ctc
::
GNDHWK
>
||
is_same_v
<
DLayout
,
ctc
::
NWGK
>
||
is_same_v
<
DLayout
,
ctc
::
NHWGK
>
||
is_same_v
<
DLayout
,
ctc
::
NWGK
>
||
is_same_v
<
DLayout
,
ctc
::
NHWGK
>
||
is_same_v
<
DLayout
,
ctc
::
NDHWGK
>
||
is_same_v
<
DLayout
,
ctc
::
GK
>
||
is_same_v
<
DLayout
,
ctc
::
NDHWGK
>
||
is_same_v
<
DLayout
,
ctc
::
G_K
>
)
is_same_v
<
DLayout
,
ctc
::
G_K
>
)
{
{
const
index_t
K
=
arg
.
ds_g_n_k_wos_lengths_
[
i
][
2
];
const
index_t
K
=
arg
.
ds_g_n_k_wos_lengths_
[
i
][
2
];
...
...
include/ck/tensor_operation/gpu/device/impl/device_normalization_bwd_data_impl.hpp
0 → 100644
View file @
32806d5f
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <vector>
#include "ck/tensor_operation/gpu/device/device_normalization_bwd_data.hpp"
#include "ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_bwd_data.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
// M is Invariant dimension, K is reduced dimension
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
template
<
typename
GridwiseNormalizationBwd
,
typename
DYDataType
,
typename
XDataType
,
typename
GammaDataType
,
typename
MeanInvStdDataType
,
typename
DXDataType
,
typename
GridDesc_M_K
>
__global__
void
kernel_normalization_bwd_data
(
const
GridDesc_M_K
dy_grid_desc_m_k
,
const
GridDesc_M_K
x_grid_desc_m_k
,
const
GridDesc_M_K
gamma_grid_desc_m_k
,
const
GridDesc_M_K
mean_grid_desc_m_k
,
const
GridDesc_M_K
inv_std_grid_desc_m_k
,
const
GridDesc_M_K
dx_grid_desc_m_k
,
index_t
num_k_block_tile_iteration
,
const
DYDataType
*
const
__restrict__
p_dy_global
,
const
XDataType
*
const
__restrict__
p_x_global
,
const
GammaDataType
*
const
__restrict__
p_gamma_global
,
const
MeanInvStdDataType
*
const
__restrict__
p_mean_global
,
const
MeanInvStdDataType
*
const
__restrict__
p_inv_std_global
,
DXDataType
*
const
__restrict__
p_dx_global
)
{
GridwiseNormalizationBwd
::
Run
(
dy_grid_desc_m_k
,
x_grid_desc_m_k
,
gamma_grid_desc_m_k
,
mean_grid_desc_m_k
,
inv_std_grid_desc_m_k
,
dx_grid_desc_m_k
,
num_k_block_tile_iteration
,
p_dy_global
,
p_x_global
,
p_gamma_global
,
p_mean_global
,
p_inv_std_global
,
p_dx_global
);
};
template
<
typename
DYDataType
,
typename
XDataType
,
typename
GammaDataType
,
typename
MeanInvStdDataType
,
typename
ComputeDataType
,
typename
DXDataType
,
index_t
Rank
,
index_t
NumReduceDim
,
index_t
BlockSize
,
index_t
MThreadClusterSize
,
index_t
KThreadClusterSize
,
index_t
MThreadSliceSize
,
index_t
KThreadSliceSize
,
bool
IsDYFastestDimReduced
,
index_t
DYSrcVectorSize
,
bool
IsXFastestDimReduced
,
index_t
XSrcVectorSize
,
bool
IsGammaFastestDimReduced
,
index_t
GammaSrcVectorSize
,
bool
IsMeanInvStdFastestDimReduced
,
index_t
MeanInvStdSrcVectorSize
,
bool
IsDxFastestDimReduced
,
index_t
DXDstVectorSize
>
struct
DeviceNormalizationBwdDataImpl
:
public
DeviceNormalizationBwdData
<
DYDataType
,
XDataType
,
GammaDataType
,
MeanInvStdDataType
,
DXDataType
,
Rank
,
NumReduceDim
>
{
static
constexpr
index_t
DYSrcVectorDim
=
IsDYFastestDimReduced
?
1
:
0
;
static
constexpr
index_t
XSrcVectorDim
=
IsXFastestDimReduced
?
1
:
0
;
static
constexpr
index_t
GammaSrcVectorDim
=
IsGammaFastestDimReduced
?
1
:
0
;
static
constexpr
index_t
MeanInvStdSrcVectorDim
=
IsMeanInvStdFastestDimReduced
?
1
:
0
;
static
constexpr
index_t
DXDstVectorDim
=
IsDxFastestDimReduced
?
1
:
0
;
static_assert
(
BlockSize
==
MThreadClusterSize
*
KThreadClusterSize
);
static_assert
(((
DYSrcVectorDim
==
0
&&
MThreadSliceSize
%
DYSrcVectorSize
==
0
)
||
(
DYSrcVectorDim
==
1
&&
KThreadSliceSize
%
DYSrcVectorSize
==
0
)),
"Invalid thread slice sizes and/or dy vector sizes configuration, please check!"
);
static_assert
(((
XSrcVectorDim
==
0
&&
MThreadSliceSize
%
XSrcVectorSize
==
0
)
||
(
XSrcVectorDim
==
1
&&
KThreadSliceSize
%
XSrcVectorSize
==
0
)),
"Invalid thread slice sizes and/or x vector sizes configuration, please check!"
);
static_assert
(
((
GammaSrcVectorDim
==
0
&&
MThreadSliceSize
%
GammaSrcVectorSize
==
0
)
||
(
GammaSrcVectorDim
==
1
&&
KThreadSliceSize
%
GammaSrcVectorSize
==
0
)),
"Invalid thread slice sizes and/or gamma vector sizes configuration, please check!"
);
static_assert
(
(
MeanInvStdSrcVectorDim
==
0
&&
MThreadSliceSize
%
MeanInvStdSrcVectorSize
==
0
)
||
(
MeanInvStdSrcVectorDim
==
1
&&
KThreadSliceSize
%
MeanInvStdSrcVectorSize
==
0
),
"Invalid thread slice sizes and/or mean and inverse std vector sizes configuration, please "
"check!"
);
static_assert
(((
DXDstVectorDim
==
0
&&
MThreadSliceSize
%
DXDstVectorSize
==
0
)
||
(
DXDstVectorDim
==
1
&&
KThreadSliceSize
%
DXDstVectorSize
==
0
)),
"Invalid thread slice sizes and/or dx vector sizes configuration, please check!"
);
static
constexpr
index_t
NumInvariantDim
=
Rank
-
NumReduceDim
;
static
constexpr
index_t
M_BlockTileSize
=
MThreadClusterSize
*
MThreadSliceSize
;
static
constexpr
index_t
K_BlockTileSize
=
KThreadClusterSize
*
KThreadSliceSize
;
static
constexpr
bool
reduceAllDim
=
(
NumInvariantDim
==
0
);
static_assert
(
!
reduceAllDim
);
static
auto
Make2dDescriptor
(
const
std
::
vector
<
index_t
>&
lengths
,
const
std
::
vector
<
index_t
>&
strides
,
int
numBlockTileIteration
)
{
const
auto
tupleLengths
=
make_tuple_from_array
(
lengths
,
Number
<
Rank
>
{});
const
auto
tupleStrides
=
make_tuple_from_array
(
strides
,
Number
<
Rank
>
{});
const
auto
desc
=
make_naive_tensor_descriptor
(
tupleLengths
,
tupleStrides
);
const
auto
grid_desc_m_k
=
[
&
]()
{
using
InvariantDims
=
typename
arithmetic_sequence_gen
<
0
,
NumInvariantDim
,
1
>::
type
;
using
ReduceDims
=
typename
arithmetic_sequence_gen
<
NumInvariantDim
,
Rank
,
1
>::
type
;
const
auto
reduceDimLengths
=
make_tuple_from_array_and_index_seq
(
lengths
,
ReduceDims
{});
const
auto
invariantDimLengths
=
make_tuple_from_array_and_index_seq
(
lengths
,
InvariantDims
{});
return
transform_tensor_descriptor
(
desc
,
make_tuple
(
make_merge_transform
(
invariantDimLengths
),
make_merge_transform
(
reduceDimLengths
)),
make_tuple
(
InvariantDims
{},
ReduceDims
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
}();
const
auto
invariantLength
=
grid_desc_m_k
.
GetLength
(
Number
<
0
>
{});
const
auto
reduceLength
=
grid_desc_m_k
.
GetLength
(
Number
<
1
>
{});
const
auto
pad_M
=
math
::
integer_least_multiple
(
invariantLength
,
M_BlockTileSize
)
-
invariantLength
;
const
auto
pad_K
=
K_BlockTileSize
*
numBlockTileIteration
-
reduceLength
;
auto
grid_desc_m_k_padded
=
transform_tensor_descriptor
(
grid_desc_m_k
,
make_tuple
(
make_right_pad_transform
(
invariantLength
,
pad_M
),
make_right_pad_transform
(
reduceLength
,
pad_K
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
return
grid_desc_m_k_padded
;
}
using
GridDesc_M_K
=
decltype
(
Make2dDescriptor
({
1
},
{
1
},
1
));
using
GridwiseNormalizationBwdDataGeneric
=
GridwiseNormalizationBwdData_mk_to_mk
<
DYDataType
,
XDataType
,
GammaDataType
,
MeanInvStdDataType
,
ComputeDataType
,
DXDataType
,
GridDesc_M_K
,
BlockSize
,
MThreadClusterSize
,
KThreadClusterSize
,
MThreadSliceSize
,
KThreadSliceSize
,
DYSrcVectorDim
,
DYSrcVectorSize
,
XSrcVectorDim
,
XSrcVectorSize
,
GammaSrcVectorDim
,
GammaSrcVectorSize
,
MeanInvStdSrcVectorDim
,
MeanInvStdSrcVectorSize
,
DXDstVectorDim
,
DXDstVectorSize
,
false
>
;
using
GridwiseNormalizationBwdDataSweepOnce
=
GridwiseNormalizationBwdData_mk_to_mk
<
DYDataType
,
XDataType
,
GammaDataType
,
MeanInvStdDataType
,
ComputeDataType
,
DXDataType
,
GridDesc_M_K
,
BlockSize
,
MThreadClusterSize
,
KThreadClusterSize
,
MThreadSliceSize
,
KThreadSliceSize
,
DYSrcVectorDim
,
DYSrcVectorSize
,
XSrcVectorDim
,
XSrcVectorSize
,
GammaSrcVectorDim
,
GammaSrcVectorSize
,
MeanInvStdSrcVectorDim
,
MeanInvStdSrcVectorSize
,
DXDstVectorDim
,
DXDstVectorSize
,
true
>
;
struct
Argument
:
public
BaseArgument
{
Argument
(
const
std
::
vector
<
index_t
>
lengths
,
const
std
::
vector
<
index_t
>
dyStrides
,
const
std
::
vector
<
index_t
>
xStrides
,
const
std
::
vector
<
index_t
>
gammaStrides
,
const
std
::
vector
<
index_t
>
meanStrides
,
const
std
::
vector
<
index_t
>
invStdStrides
,
const
std
::
vector
<
index_t
>
dxStrides
,
const
std
::
vector
<
index_t
>
reduceDims
,
const
DYDataType
*
p_dy
,
const
XDataType
*
p_x
,
const
GammaDataType
*
p_gamma
,
const
MeanInvStdDataType
*
p_mean
,
const
MeanInvStdDataType
*
p_invStd
,
DXDataType
*
p_dx
)
:
p_dy_
(
p_dy
),
p_x_
(
p_x
),
p_gamma_
(
p_gamma
),
p_mean_
(
p_mean
),
p_invStd_
(
p_invStd
),
p_dx_
(
p_dx
)
{
lengths_
=
shuffle_tensor_dimensions
<
Rank
,
NumReduceDim
>
(
lengths
,
reduceDims
);
dyStrides_
=
shuffle_tensor_dimensions
<
Rank
,
NumReduceDim
>
(
dyStrides
,
reduceDims
);
xStrides_
=
shuffle_tensor_dimensions
<
Rank
,
NumReduceDim
>
(
xStrides
,
reduceDims
);
gammaStrides_
=
shuffle_tensor_dimensions
<
Rank
,
NumReduceDim
>
(
gammaStrides
,
reduceDims
);
meanStrides_
=
shuffle_tensor_dimensions
<
Rank
,
NumReduceDim
>
(
meanStrides
,
reduceDims
);
invStdStrides_
=
shuffle_tensor_dimensions
<
Rank
,
NumReduceDim
>
(
invStdStrides
,
reduceDims
);
dxStrides_
=
shuffle_tensor_dimensions
<
Rank
,
NumReduceDim
>
(
dxStrides
,
reduceDims
);
std
::
tie
(
MRaw_
,
KRaw_
)
=
get_2d_lengths
<
Rank
,
NumReduceDim
>
(
lengths_
);
numBlockTileIteration_
=
math
::
integer_divide_ceil
(
KRaw_
,
K_BlockTileSize
);
gridSize_
=
math
::
integer_divide_ceil
(
MRaw_
,
M_BlockTileSize
);
dy_grid_desc_m_k_
=
Make2dDescriptor
(
lengths_
,
dyStrides_
,
numBlockTileIteration_
);
x_grid_desc_m_k_
=
Make2dDescriptor
(
lengths_
,
xStrides_
,
numBlockTileIteration_
);
gamma_grid_desc_m_k_
=
Make2dDescriptor
(
lengths_
,
gammaStrides_
,
numBlockTileIteration_
);
mean_grid_desc_m_k_
=
Make2dDescriptor
(
lengths_
,
meanStrides_
,
numBlockTileIteration_
);
inv_std_grid_desc_m_k_
=
Make2dDescriptor
(
lengths_
,
invStdStrides_
,
numBlockTileIteration_
);
dx_grid_desc_m_k_
=
Make2dDescriptor
(
lengths_
,
dxStrides_
,
numBlockTileIteration_
);
isSweeponce_
=
dy_grid_desc_m_k_
.
GetLength
(
Number
<
1
>
{})
<=
K_BlockTileSize
;
}
const
DYDataType
*
p_dy_
;
const
XDataType
*
p_x_
;
const
GammaDataType
*
p_gamma_
;
const
MeanInvStdDataType
*
p_mean_
;
const
MeanInvStdDataType
*
p_invStd_
;
DXDataType
*
p_dx_
;
std
::
vector
<
index_t
>
lengths_
;
std
::
vector
<
index_t
>
dyStrides_
;
std
::
vector
<
index_t
>
xStrides_
;
std
::
vector
<
index_t
>
gammaStrides_
;
std
::
vector
<
index_t
>
meanStrides_
;
std
::
vector
<
index_t
>
invStdStrides_
;
std
::
vector
<
index_t
>
dxStrides_
;
int
numBlockTileIteration_
;
size_t
gridSize_
;
// tensor descriptor
GridDesc_M_K
dy_grid_desc_m_k_
;
GridDesc_M_K
x_grid_desc_m_k_
;
GridDesc_M_K
gamma_grid_desc_m_k_
;
GridDesc_M_K
mean_grid_desc_m_k_
;
GridDesc_M_K
inv_std_grid_desc_m_k_
;
GridDesc_M_K
dx_grid_desc_m_k_
;
bool
isSweeponce_
;
index_t
MRaw_
;
// Invariant length
index_t
KRaw_
;
// reduce length
};
struct
Invoker
:
public
BaseInvoker
{
auto
KernelSelector
(
bool
isSweepOnce
)
{
return
isSweepOnce
?
kernel_normalization_bwd_data
<
GridwiseNormalizationBwdDataSweepOnce
,
DYDataType
,
XDataType
,
GammaDataType
,
MeanInvStdDataType
,
DXDataType
,
GridDesc_M_K
>
:
kernel_normalization_bwd_data
<
GridwiseNormalizationBwdDataGeneric
,
DYDataType
,
XDataType
,
GammaDataType
,
MeanInvStdDataType
,
DXDataType
,
GridDesc_M_K
>
;
}
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
const
auto
kernel_main
=
KernelSelector
(
arg
.
isSweeponce_
);
return
launch_and_time_kernel
(
stream_config
,
kernel_main
,
dim3
(
arg
.
gridSize_
),
dim3
(
BlockSize
),
0
,
arg
.
dy_grid_desc_m_k_
,
arg
.
x_grid_desc_m_k_
,
arg
.
gamma_grid_desc_m_k_
,
arg
.
mean_grid_desc_m_k_
,
arg
.
inv_std_grid_desc_m_k_
,
arg
.
dx_grid_desc_m_k_
,
arg
.
numBlockTileIteration_
,
arg
.
p_dy_
,
arg
.
p_x_
,
arg
.
p_gamma_
,
arg
.
p_mean_
,
arg
.
p_invStd_
,
arg
.
p_dx_
);
}
float
Run
(
const
BaseArgument
*
p_arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
override
{
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
),
stream_config
);
}
};
template
<
index_t
SrcVectorDim
,
index_t
SrcVectorSize
>
bool
IsVectorDimSizeValid
(
const
std
::
vector
<
index_t
>&
lengths
,
const
std
::
vector
<
index_t
>&
strides
)
{
if
constexpr
(
SrcVectorSize
==
1
)
return
true
;
// Fastest dimension is not reduced
if
constexpr
(
SrcVectorDim
==
0
)
{
if
constexpr
(
NumInvariantDim
==
0
)
return
false
;
if
(
strides
[
NumInvariantDim
-
1
]
!=
1
)
return
false
;
if
(
lengths
[
NumInvariantDim
-
1
]
%
SrcVectorSize
!=
0
)
return
false
;
}
else
// Fastest dimension is reduced
{
if
(
strides
[
Rank
-
1
]
!=
1
)
return
false
;
if
(
lengths
[
Rank
-
1
]
%
SrcVectorSize
!=
0
)
return
false
;
};
return
true
;
}
bool
IsSupportedArgument
(
const
BaseArgument
*
p_arg
)
override
{
const
Argument
*
p_arg_
=
dynamic_cast
<
const
Argument
*>
(
p_arg
);
bool
pass
=
true
;
pass
&=
IsVectorDimSizeValid
<
DYSrcVectorDim
,
DYSrcVectorSize
>
(
p_arg_
->
lengths_
,
p_arg_
->
dyStrides_
);
pass
&=
IsVectorDimSizeValid
<
XSrcVectorDim
,
XSrcVectorSize
>
(
p_arg_
->
lengths_
,
p_arg_
->
xStrides_
);
pass
&=
IsVectorDimSizeValid
<
GammaSrcVectorDim
,
GammaSrcVectorSize
>
(
p_arg_
->
lengths_
,
p_arg_
->
gammaStrides_
);
pass
&=
IsVectorDimSizeValid
<
MeanInvStdSrcVectorDim
,
MeanInvStdSrcVectorSize
>
(
p_arg_
->
lengths_
,
p_arg_
->
meanStrides_
);
pass
&=
IsVectorDimSizeValid
<
MeanInvStdSrcVectorDim
,
MeanInvStdSrcVectorSize
>
(
p_arg_
->
lengths_
,
p_arg_
->
invStdStrides_
);
pass
&=
IsVectorDimSizeValid
<
DXDstVectorDim
,
DXDstVectorSize
>
(
p_arg_
->
lengths_
,
p_arg_
->
dxStrides_
);
return
pass
;
}
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
std
::
vector
<
index_t
>
lengths
,
const
std
::
vector
<
index_t
>
dyStrides
,
const
std
::
vector
<
index_t
>
xStrides
,
const
std
::
vector
<
index_t
>
gammaStrides
,
const
std
::
vector
<
index_t
>
meanStrides
,
const
std
::
vector
<
index_t
>
invStdStrides
,
const
std
::
vector
<
index_t
>
dxStrides
,
const
std
::
vector
<
index_t
>
reduceDims
,
const
void
*
p_dy
,
const
void
*
p_x
,
const
void
*
p_gamma
,
const
void
*
p_mean
,
const
void
*
p_invStd
,
void
*
p_dx
)
override
{
if
(
lengths
.
size
()
!=
Rank
||
dyStrides
.
size
()
!=
Rank
||
xStrides
.
size
()
!=
Rank
||
gammaStrides
.
size
()
!=
Rank
||
meanStrides
.
size
()
!=
Rank
||
invStdStrides
.
size
()
!=
Rank
||
dxStrides
.
size
()
!=
Rank
)
throw
std
::
runtime_error
(
"dimension is incorrect"
);
return
std
::
make_unique
<
Argument
>
(
lengths
,
dyStrides
,
xStrides
,
gammaStrides
,
meanStrides
,
invStdStrides
,
dxStrides
,
reduceDims
,
static_cast
<
const
DYDataType
*>
(
p_dy
),
static_cast
<
const
XDataType
*>
(
p_x
),
static_cast
<
const
GammaDataType
*>
(
p_gamma
),
static_cast
<
const
MeanInvStdDataType
*>
(
p_mean
),
static_cast
<
const
MeanInvStdDataType
*>
(
p_invStd
),
static_cast
<
DXDataType
*>
(
p_dx
));
}
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
{
return
std
::
make_unique
<
Invoker
>
();
}
std
::
string
GetTypeString
()
const
override
{
auto
str
=
std
::
stringstream
();
// clang-format off
str
<<
"DeviceNormalizationBwdDataImpl<"
<<
BlockSize
<<
","
;
str
<<
"Cluster_MK_"
<<
MThreadClusterSize
<<
"_"
<<
KThreadClusterSize
<<
","
;
str
<<
"Slice_MK_"
<<
MThreadSliceSize
<<
"_"
<<
KThreadSliceSize
<<
","
;
str
<<
"DYSrcVectorSize"
<<
DYSrcVectorSize
<<
"_X"
<<
XSrcVectorSize
<<
"_Gamma"
<<
GammaSrcVectorSize
<<
"_MeanRstd"
<<
MeanInvStdSrcVectorSize
<<
"_Dx"
<<
DXDstVectorSize
;
str
<<
">"
;
// clang-format on
return
str
.
str
();
}
};
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/device/impl/device_normalization_bwd_gamma_beta_impl.hpp
View file @
32806d5f
...
@@ -14,7 +14,7 @@
...
@@ -14,7 +14,7 @@
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/kernel_launch.hpp"
// M is
i
nvari
e
nt dimension, K is reduced dimension
// M is
I
nvari
a
nt dimension, K is reduced dimension
namespace
ck
{
namespace
ck
{
namespace
tensor_operation
{
namespace
tensor_operation
{
namespace
device
{
namespace
device
{
...
@@ -87,7 +87,6 @@ struct DeviceNormalizationBwdGammaBetaImpl
...
@@ -87,7 +87,6 @@ struct DeviceNormalizationBwdGammaBetaImpl
Rank
,
Rank
,
NumReduceDim
>
NumReduceDim
>
{
{
static
constexpr
index_t
DYSrcVectorDim
=
IsDYFastestDimReduced
?
1
:
0
;
static
constexpr
index_t
DYSrcVectorDim
=
IsDYFastestDimReduced
?
1
:
0
;
static
constexpr
index_t
XSrcVectorDim
=
IsXFastestDimReduced
?
1
:
0
;
static
constexpr
index_t
XSrcVectorDim
=
IsXFastestDimReduced
?
1
:
0
;
static
constexpr
index_t
MeanInvStdSrcVectorDim
=
IsMeanInvStdFastestDimReduced
?
1
:
0
;
static
constexpr
index_t
MeanInvStdSrcVectorDim
=
IsMeanInvStdFastestDimReduced
?
1
:
0
;
...
@@ -102,18 +101,18 @@ struct DeviceNormalizationBwdGammaBetaImpl
...
@@ -102,18 +101,18 @@ struct DeviceNormalizationBwdGammaBetaImpl
(
XSrcVectorDim
==
1
&&
KThreadSliceSize
%
XSrcVectorSize
==
0
)),
(
XSrcVectorDim
==
1
&&
KThreadSliceSize
%
XSrcVectorSize
==
0
)),
"Invalid thread slice sizes and/or x vector sizes configuration, please check!"
);
"Invalid thread slice sizes and/or x vector sizes configuration, please check!"
);
static_assert
(
((
MThreadSliceSize
%
DGammaDstVectorSize
==
0
)
||
(
MThreadSliceSize
%
DBetaDstVectorSize
==
0
)),
"Invalid thread slice sizes and/or Gamma and beta vector sizes configuration, please "
"check!"
);
static_assert
(
static_assert
(
(
MeanInvStdSrcVectorDim
==
0
&&
MThreadSliceSize
%
MeanInvStdSrcVectorSize
==
0
)
||
(
MeanInvStdSrcVectorDim
==
0
&&
MThreadSliceSize
%
MeanInvStdSrcVectorSize
==
0
)
||
(
MeanInvStdSrcVectorDim
==
1
&&
KThreadSliceSize
%
MeanInvStdSrcVectorSize
==
0
),
(
MeanInvStdSrcVectorDim
==
1
&&
KThreadSliceSize
%
MeanInvStdSrcVectorSize
==
0
),
"Invalid thread slice sizes and/or mean and inverse std vector sizes configuration, please "
"Invalid thread slice sizes and/or mean and inverse std vector sizes configuration, please "
"check!"
);
"check!"
);
static_assert
(
((
MThreadSliceSize
%
DGammaDstVectorSize
==
0
)
||
(
MThreadSliceSize
%
DBetaDstVectorSize
==
0
)),
"Invalid thread slice sizes and/or Gamma and beta vector sizes configuration, please "
"check!"
);
static
constexpr
index_t
NumInvariantDim
=
Rank
-
NumReduceDim
;
static
constexpr
index_t
NumInvariantDim
=
Rank
-
NumReduceDim
;
static
constexpr
index_t
M_BlockTileSize
=
MThreadClusterSize
*
MThreadSliceSize
;
static
constexpr
index_t
M_BlockTileSize
=
MThreadClusterSize
*
MThreadSliceSize
;
static
constexpr
index_t
K_BlockTileSize
=
KThreadClusterSize
*
KThreadSliceSize
;
static
constexpr
index_t
K_BlockTileSize
=
KThreadClusterSize
*
KThreadSliceSize
;
...
@@ -298,7 +297,7 @@ struct DeviceNormalizationBwdGammaBetaImpl
...
@@ -298,7 +297,7 @@ struct DeviceNormalizationBwdGammaBetaImpl
GridDesc_M
dgamma_grid_desc_m_
;
GridDesc_M
dgamma_grid_desc_m_
;
GridDesc_M
dbeta_grid_desc_m_
;
GridDesc_M
dbeta_grid_desc_m_
;
index_t
MRaw_
;
//
i
nvari
e
nt length
index_t
MRaw_
;
//
I
nvari
a
nt length
index_t
KRaw_
;
// reduce length
index_t
KRaw_
;
// reduce length
};
};
...
@@ -457,6 +456,21 @@ struct DeviceNormalizationBwdGammaBetaImpl
...
@@ -457,6 +456,21 @@ struct DeviceNormalizationBwdGammaBetaImpl
{
{
return
std
::
make_unique
<
Invoker
>
();
return
std
::
make_unique
<
Invoker
>
();
}
}
std
::
string
GetTypeString
()
const
override
{
auto
str
=
std
::
stringstream
();
// clang-format off
str
<<
"DeviceNormalizationBwdGammaBetaImpl<"
<<
BlockSize
<<
","
;
str
<<
"Cluster_MK_"
<<
MThreadClusterSize
<<
"_"
<<
KThreadClusterSize
<<
","
;
str
<<
"Slice_MK_"
<<
MThreadSliceSize
<<
"_"
<<
KThreadSliceSize
<<
","
;
str
<<
"VectorSize_DY"
<<
DYSrcVectorSize
<<
"_X"
<<
XSrcVectorSize
;
str
<<
"_DGamma"
<<
DGammaDstVectorSize
<<
"_DBeta"
<<
DBetaDstVectorSize
<<
">"
;
// clang-format on
return
str
.
str
();
}
};
};
}
// namespace device
}
// namespace device
...
...
include/ck/tensor_operation/gpu/device/impl/device_normalization_fwd_impl.hpp
View file @
32806d5f
...
@@ -19,7 +19,7 @@ namespace tensor_operation {
...
@@ -19,7 +19,7 @@ namespace tensor_operation {
namespace
device
{
namespace
device
{
// Y = Normalization(X, Beta, Gamma)
// Y = Normalization(X, Beta, Gamma)
// M: Invari
e
nt length
// M: Invari
a
nt length
// K: Reduce length (Calculate mean and variance along K dimension)
// K: Reduce length (Calculate mean and variance along K dimension)
// eg. Length = [N, C, H, W], reduce dim = [C, H, W]
// eg. Length = [N, C, H, W], reduce dim = [C, H, W]
// Then, M = N, K = C * H * W
// Then, M = N, K = C * H * W
...
@@ -263,7 +263,7 @@ struct DeviceNormalizationFwdImpl : public DeviceNormalizationFwd<XDataType,
...
@@ -263,7 +263,7 @@ struct DeviceNormalizationFwdImpl : public DeviceNormalizationFwd<XDataType,
GridDesc_M
save_inv_std_grid_desc_m_
;
GridDesc_M
save_inv_std_grid_desc_m_
;
bool
isSweeponce_
;
bool
isSweeponce_
;
index_t
MRaw_
;
//
i
nvari
e
nt length
index_t
MRaw_
;
//
I
nvari
a
nt length
index_t
KRaw_
;
// reduce length
index_t
KRaw_
;
// reduce length
index_t
invariant_lowest_length_
;
index_t
invariant_lowest_length_
;
...
@@ -342,8 +342,6 @@ struct DeviceNormalizationFwdImpl : public DeviceNormalizationFwd<XDataType,
...
@@ -342,8 +342,6 @@ struct DeviceNormalizationFwdImpl : public DeviceNormalizationFwd<XDataType,
}
}
else
else
{
{
printf
(
"!!!! %d
\n
"
,
p_arg_
->
invariant_lowest_length_
);
if
(
p_arg_
->
xStrides_
[
NumInvariantDim
-
1
]
!=
1
)
if
(
p_arg_
->
xStrides_
[
NumInvariantDim
-
1
]
!=
1
)
return
false
;
return
false
;
...
...
include/ck/tensor_operation/gpu/device/impl/device_normalization_fwd_splitk_impl.hpp
View file @
32806d5f
...
@@ -108,7 +108,7 @@ namespace tensor_operation {
...
@@ -108,7 +108,7 @@ namespace tensor_operation {
namespace
device
{
namespace
device
{
// Y = Normalization(X, Beta, Gamma)
// Y = Normalization(X, Beta, Gamma)
// M: Invari
e
nt length
// M: Invari
a
nt length
// K: Reduce length (Calculate mean and variance along K dimension)
// K: Reduce length (Calculate mean and variance along K dimension)
// eg. Length = [N, C, H, W], reduce dim = [C, H, W]
// eg. Length = [N, C, H, W], reduce dim = [C, H, W]
// Then, M = N, K = C * H * W
// Then, M = N, K = C * H * W
...
@@ -468,7 +468,7 @@ struct DeviceNormalizationFwdSplitKImpl : public DeviceNormalizationFwd<XDataTyp
...
@@ -468,7 +468,7 @@ struct DeviceNormalizationFwdSplitKImpl : public DeviceNormalizationFwd<XDataTyp
Kernel2MeanVarGridDesc_M_KBlock
kernel2_mean_var_grid_desc_m_kblock_
;
Kernel2MeanVarGridDesc_M_KBlock
kernel2_mean_var_grid_desc_m_kblock_
;
Kernel2CountGridDesc_M_KBlock
kernel2_count_grid_desc_m_kblock_
;
Kernel2CountGridDesc_M_KBlock
kernel2_count_grid_desc_m_kblock_
;
index_t
MRaw_
;
//
i
nvari
e
nt length
index_t
MRaw_
;
//
I
nvari
a
nt length
index_t
KRaw_
;
// reduce length
index_t
KRaw_
;
// reduce length
index_t
invariant_lowest_length_
;
index_t
invariant_lowest_length_
;
...
...
include/ck/tensor_operation/gpu/device/tensor_layout.hpp
View file @
32806d5f
...
@@ -308,12 +308,6 @@ struct GNDHWK : public BaseTensorLayout
...
@@ -308,12 +308,6 @@ struct GNDHWK : public BaseTensorLayout
static
constexpr
const
char
*
name
=
"GNDHWK"
;
static
constexpr
const
char
*
name
=
"GNDHWK"
;
};
};
// for output bias
struct
GK
:
public
BaseTensorLayout
{
static
constexpr
const
char
*
name
=
"GK"
;
};
// output tensor
// output tensor
// packed NWGK/NHWGK/NDHWGK
// packed NWGK/NHWGK/NDHWGK
struct
NWGK
:
public
BaseTensorLayout
struct
NWGK
:
public
BaseTensorLayout
...
...
include/ck/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp
View file @
32806d5f
...
@@ -50,7 +50,9 @@ __global__ void
...
@@ -50,7 +50,9 @@ __global__ void
ignore
=
p_in_global
;
ignore
=
p_in_global
;
ignore
=
out_grid_desc
;
ignore
=
out_grid_desc
;
ignore
=
p_out_global
;
ignore
=
p_out_global
;
ignore
=
batch_count
;
ignore
=
block_2_tile_map
;
ignore
=
block_2_tile_map
;
ignore
=
compute_ptr_offset_of_batch
;
#endif
#endif
}
}
...
...
include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_bwd_data.hpp
0 → 100644
View file @
32806d5f
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp"
namespace
ck
{
// Tensor Shape
// dy, x = [M, K], gamma = [1, K], x_mean, inv_std = [M, 1]
// Flow:
// def normalization_backward_x(dy, x, gamma, x_mean, inv_std, reduce_axis, reduce_size):
// ds = np.sum(dy * gamma * x, axis=reduce_axis, keepdims=True)
// db = np.sum(dy * gamma, axis=reduce_axis, keepdims=True)
// b = (db * x_mean - ds) * inv_std ** (3) / reduce_size
// c = -b * x_mean - db * inv_std / reduce_size
// dx = inv_std * dy * gamma + b * x + c
// return dx
template
<
typename
DYDataType
,
typename
XDataType
,
typename
GammaDataType
,
typename
MeanInvStdDataType
,
typename
ComputeDataType
,
typename
DXDataType
,
typename
GridDesc_M_K
,
index_t
BlockSize
,
index_t
MThreadClusterSize
,
index_t
KThreadClusterSize
,
index_t
MThreadSliceSize
,
index_t
KThreadSliceSize
,
index_t
DYSrcVectorDim
,
index_t
DYSrcVectorSize
,
index_t
XSrcVectorDim
,
index_t
XSrcVectorSize
,
index_t
GammaSrcVectorDim
,
index_t
GammaSrcVectorSize
,
index_t
MeanInvStdSrcVectorDim
,
index_t
MeanInvStdSrcVectorSize
,
index_t
DXDstVectorDim
,
index_t
DXDstVectorSize
,
bool
SweepOnce
>
struct
GridwiseNormalizationBwdData_mk_to_mk
{
// if we just check ThreadSliceSize % VectorSize == 0, the performance may be poor (coalesce)
static_assert
(((
DYSrcVectorDim
==
0
&&
MThreadSliceSize
==
DYSrcVectorSize
)
||
(
DYSrcVectorDim
==
1
&&
KThreadSliceSize
==
DYSrcVectorSize
)),
"Invalid thread slice sizes and/or dy vector sizes configuration, please check!"
);
static_assert
(((
XSrcVectorDim
==
0
&&
MThreadSliceSize
==
XSrcVectorSize
)
||
(
XSrcVectorDim
==
1
&&
KThreadSliceSize
==
XSrcVectorSize
)),
"Invalid thread slice sizes and/or x vector sizes configuration, please check!"
);
static_assert
(
((
GammaSrcVectorDim
==
0
&&
MThreadSliceSize
==
GammaSrcVectorSize
)
||
(
GammaSrcVectorDim
==
1
&&
KThreadSliceSize
==
GammaSrcVectorSize
)),
"Invalid thread slice sizes and/or gamma vector sizes configuration, please check!"
);
static_assert
(
((
MeanInvStdSrcVectorDim
==
0
&&
MThreadSliceSize
==
MeanInvStdSrcVectorSize
)
||
(
MeanInvStdSrcVectorDim
==
1
&&
KThreadSliceSize
==
MeanInvStdSrcVectorSize
)),
"Invalid thread slice sizes and/or mean/inv_std vector sizes configuration, please check!"
);
static_assert
(((
DXDstVectorDim
==
0
&&
MThreadSliceSize
==
DXDstVectorSize
)
||
(
DXDstVectorDim
==
1
&&
KThreadSliceSize
==
DXDstVectorSize
)),
"Invalid thread slice sizes and/or dx vector sizes configuration, please check!"
);
using
ThreadClusterLengths_M_K
=
Sequence
<
MThreadClusterSize
,
KThreadClusterSize
>
;
using
DYThreadBufferDimAccessOrder
=
typename
conditional
<
DYSrcVectorDim
==
0
,
Sequence
<
1
,
0
>
,
Sequence
<
0
,
1
>>::
type
;
using
XThreadBufferDimAccessOrder
=
typename
conditional
<
XSrcVectorDim
==
0
,
Sequence
<
1
,
0
>
,
Sequence
<
0
,
1
>>::
type
;
using
GammaThreadBufferDimAccessOrder
=
typename
conditional
<
GammaSrcVectorDim
==
0
,
Sequence
<
1
,
0
>
,
Sequence
<
0
,
1
>>::
type
;
using
MeanInvStdThreadBufferDimAccessOrder
=
typename
conditional
<
MeanInvStdSrcVectorDim
==
0
,
Sequence
<
1
,
0
>
,
Sequence
<
0
,
1
>>::
type
;
using
DXThreadBufferDimAccessOrder
=
typename
conditional
<
DXDstVectorDim
==
0
,
Sequence
<
1
,
0
>
,
Sequence
<
0
,
1
>>::
type
;
using
ThreadClusterArrangeOrder
=
DYThreadBufferDimAccessOrder
;
static
constexpr
auto
thread_cluster_desc
=
make_cluster_descriptor
(
ThreadClusterLengths_M_K
{},
ThreadClusterArrangeOrder
{});
using
ThreadBufferLengths_M_K
=
Sequence
<
MThreadSliceSize
,
KThreadSliceSize
>
;
static
constexpr
auto
thread_buffer_desc_m_k
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MThreadSliceSize
>
{},
Number
<
KThreadSliceSize
>
{}));
static
constexpr
auto
thread_buffer_desc_m
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MThreadSliceSize
>
{}));
using
PassThroughOp
=
tensor_operation
::
element_wise
::
PassThrough
;
using
BlockwiseSumReduce
=
PartitionedBlockwiseReduction
<
ComputeDataType
,
BlockSize
,
ThreadClusterLengths_M_K
,
ThreadClusterArrangeOrder
,
reduce
::
Add
,
true
>
;
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
index_t
M_BlockTileSize
=
MThreadClusterSize
*
MThreadSliceSize
;
static
constexpr
index_t
K_BlockTileSize
=
KThreadClusterSize
*
KThreadSliceSize
;
__device__
static
void
Run
(
const
GridDesc_M_K
&
dy_grid_desc_m_k
,
const
GridDesc_M_K
&
x_grid_desc_m_k
,
const
GridDesc_M_K
&
gamma_grid_desc_m_k
,
const
GridDesc_M_K
&
mean_grid_desc_m_k
,
const
GridDesc_M_K
&
inv_std_grid_desc_m_k
,
const
GridDesc_M_K
&
dx_grid_desc_m_k
,
index_t
num_k_block_tile_iteration
,
const
DYDataType
*
const
__restrict__
p_dy_global
,
const
XDataType
*
const
__restrict__
p_x_global
,
const
GammaDataType
*
const
__restrict__
p_gamma_global
,
const
MeanInvStdDataType
*
const
__restrict__
p_mean_global
,
const
MeanInvStdDataType
*
const
__restrict__
p_inv_std_global
,
DXDataType
*
const
__restrict__
p_dx_global
)
{
// LDS
__shared__
ComputeDataType
p_reduce_work_buffer
[
BlockSize
];
auto
reduce_work_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
p_reduce_work_buffer
,
BlockSize
);
// Global
const
auto
dy_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_dy_global
,
dy_grid_desc_m_k
.
GetElementSpaceSize
());
const
auto
x_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_x_global
,
x_grid_desc_m_k
.
GetElementSpaceSize
());
auto
gamma_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_gamma_global
,
gamma_grid_desc_m_k
.
GetElementSpaceSize
());
const
auto
mean_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_mean_global
,
mean_grid_desc_m_k
.
GetElementSpaceSize
());
const
auto
inv_std_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_inv_std_global
,
inv_std_grid_desc_m_k
.
GetElementSpaceSize
());
auto
dx_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_dx_global
,
dx_grid_desc_m_k
.
GetElementSpaceSize
());
// VGPR
auto
dy_thread_buf
=
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
ComputeDataType
,
MThreadSliceSize
*
KThreadSliceSize
,
true
>
{};
auto
x_thread_buf
=
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
ComputeDataType
,
MThreadSliceSize
*
KThreadSliceSize
,
true
>
{};
auto
gamma_thread_buf
=
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
ComputeDataType
,
MThreadSliceSize
*
KThreadSliceSize
,
true
>
{};
auto
mean_thread_buf
=
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
ComputeDataType
,
MThreadSliceSize
*
KThreadSliceSize
,
true
>
{};
auto
inv_std_thread_buf
=
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
ComputeDataType
,
MThreadSliceSize
*
KThreadSliceSize
,
true
>
{};
auto
dx_thread_buf
=
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
ComputeDataType
,
MThreadSliceSize
*
KThreadSliceSize
,
true
>
{};
auto
ds_thread_buf
=
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
ComputeDataType
,
MThreadSliceSize
,
true
>
{};
auto
db_thread_buf
=
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
ComputeDataType
,
MThreadSliceSize
,
true
>
{};
// thread id
const
index_t
thread_local_id
=
get_thread_local_1d_id
();
const
index_t
block_global_id
=
get_block_1d_id
();
const
auto
thread_cluster_idx
=
thread_cluster_desc
.
CalculateBottomIndex
(
make_multi_index
(
thread_local_id
));
const
auto
thread_m_cluster_id
=
thread_cluster_idx
[
I0
];
const
auto
thread_k_cluster_id
=
thread_cluster_idx
[
I1
];
// IO
auto
threadwise_dy_load
=
ThreadwiseTensorSliceTransfer_v2
<
DYDataType
,
ComputeDataType
,
GridDesc_M_K
,
decltype
(
thread_buffer_desc_m_k
),
ThreadBufferLengths_M_K
,
DYThreadBufferDimAccessOrder
,
DYSrcVectorDim
,
DYSrcVectorSize
,
1
,
false
>
(
dy_grid_desc_m_k
,
make_multi_index
(
block_global_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
,
thread_k_cluster_id
*
KThreadSliceSize
));
auto
threadwise_x_load
=
ThreadwiseTensorSliceTransfer_v2
<
XDataType
,
ComputeDataType
,
GridDesc_M_K
,
decltype
(
thread_buffer_desc_m_k
),
ThreadBufferLengths_M_K
,
XThreadBufferDimAccessOrder
,
XSrcVectorDim
,
XSrcVectorSize
,
1
,
false
>
(
x_grid_desc_m_k
,
make_multi_index
(
block_global_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
,
thread_k_cluster_id
*
KThreadSliceSize
));
auto
threadwise_gamma_load
=
ThreadwiseTensorSliceTransfer_v2
<
GammaDataType
,
ComputeDataType
,
GridDesc_M_K
,
decltype
(
thread_buffer_desc_m_k
),
ThreadBufferLengths_M_K
,
XThreadBufferDimAccessOrder
,
GammaSrcVectorDim
,
GammaSrcVectorSize
,
1
,
false
>
(
gamma_grid_desc_m_k
,
make_multi_index
(
block_global_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
,
thread_k_cluster_id
*
KThreadSliceSize
));
auto
threadwise_mean_load
=
ThreadwiseTensorSliceTransfer_v2
<
MeanInvStdDataType
,
ComputeDataType
,
GridDesc_M_K
,
decltype
(
thread_buffer_desc_m_k
),
ThreadBufferLengths_M_K
,
MeanInvStdThreadBufferDimAccessOrder
,
MeanInvStdSrcVectorDim
,
MeanInvStdSrcVectorSize
,
1
,
false
>
(
mean_grid_desc_m_k
,
make_multi_index
(
block_global_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
,
thread_k_cluster_id
*
KThreadSliceSize
));
auto
threadwise_inv_std_load
=
ThreadwiseTensorSliceTransfer_v2
<
MeanInvStdDataType
,
ComputeDataType
,
GridDesc_M_K
,
decltype
(
thread_buffer_desc_m_k
),
ThreadBufferLengths_M_K
,
MeanInvStdThreadBufferDimAccessOrder
,
MeanInvStdSrcVectorDim
,
MeanInvStdSrcVectorSize
,
1
,
false
>
(
inv_std_grid_desc_m_k
,
make_multi_index
(
block_global_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
,
thread_k_cluster_id
*
KThreadSliceSize
));
auto
threadwise_dx_store
=
ThreadwiseTensorSliceTransfer_v1r3
<
ComputeDataType
,
DXDataType
,
decltype
(
thread_buffer_desc_m_k
),
GridDesc_M_K
,
PassThroughOp
,
ThreadBufferLengths_M_K
,
DXThreadBufferDimAccessOrder
,
DXDstVectorDim
,
DXDstVectorSize
,
InMemoryDataOperationEnum
::
Set
,
1
,
false
>
(
dx_grid_desc_m_k
,
make_multi_index
(
block_global_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
,
thread_k_cluster_id
*
KThreadSliceSize
),
PassThroughOp
{});
ComputeDataType
reduce_size
=
type_convert
<
ComputeDataType
>
(
dy_grid_desc_m_k
.
GetTransforms
()[
I2
].
GetUpperLengths
()[
I0
]);
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
I
)
{
ds_thread_buf
(
I
)
=
type_convert
<
ComputeDataType
>
(
0.0
f
);
db_thread_buf
(
I
)
=
type_convert
<
ComputeDataType
>
(
0.0
f
);
});
// Separate sweep once and sweep twice pipeline
// Sweep once: for small k, if KThreadClusterSize * KThreadSliceSize > K
// we don't need to use loop to read x, dy, gamma twice
if
constexpr
(
SweepOnce
)
{
threadwise_dy_load
.
Run
(
dy_grid_desc_m_k
,
dy_global_val_buf
,
thread_buffer_desc_m_k
,
make_tuple
(
I0
,
I0
),
dy_thread_buf
);
threadwise_x_load
.
Run
(
x_grid_desc_m_k
,
x_global_val_buf
,
thread_buffer_desc_m_k
,
make_tuple
(
I0
,
I0
),
x_thread_buf
);
threadwise_gamma_load
.
Run
(
gamma_grid_desc_m_k
,
gamma_global_val_buf
,
thread_buffer_desc_m_k
,
make_tuple
(
I0
,
I0
),
gamma_thread_buf
);
threadwise_mean_load
.
Run
(
mean_grid_desc_m_k
,
mean_global_val_buf
,
thread_buffer_desc_m_k
,
make_tuple
(
I0
,
I0
),
mean_thread_buf
);
threadwise_inv_std_load
.
Run
(
inv_std_grid_desc_m_k
,
inv_std_global_val_buf
,
thread_buffer_desc_m_k
,
make_tuple
(
I0
,
I0
),
inv_std_thread_buf
);
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
iM
)
{
constexpr
auto
offset_m
=
Number
<
thread_buffer_desc_m
.
CalculateOffset
(
make_tuple
(
iM
))
>
{};
static_for
<
0
,
KThreadSliceSize
,
1
>
{}([
&
](
auto
iK
)
{
constexpr
auto
offset_m_k
=
Number
<
thread_buffer_desc_m_k
.
CalculateOffset
(
make_tuple
(
iM
,
iK
))
>
{};
ds_thread_buf
(
offset_m
)
+=
dy_thread_buf
[
offset_m_k
]
*
gamma_thread_buf
[
offset_m_k
]
*
x_thread_buf
[
offset_m_k
];
db_thread_buf
(
offset_m
)
+=
dy_thread_buf
[
offset_m_k
]
*
gamma_thread_buf
[
offset_m_k
];
});
});
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
I
)
{
if
constexpr
(
I
>
0
)
block_sync_lds
();
BlockwiseSumReduce
::
Reduce
(
reduce_work_buf
,
ds_thread_buf
(
I
));
block_sync_lds
();
BlockwiseSumReduce
::
Reduce
(
reduce_work_buf
,
db_thread_buf
(
I
));
});
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
iM
)
{
constexpr
auto
offset_m
=
Number
<
thread_buffer_desc_m
.
CalculateOffset
(
make_tuple
(
iM
))
>
{};
static_for
<
0
,
KThreadSliceSize
,
1
>
{}([
&
](
auto
iK
)
{
constexpr
auto
offset_m_k
=
Number
<
thread_buffer_desc_m_k
.
CalculateOffset
(
make_tuple
(
iM
,
iK
))
>
{};
// b = (db * x_mean - ds) * rstd ** (3) / reduce_size
// c = -b * x_mean - db * rstd / reduce_size
// dx = rstd * dy * gamma + b * x + c
ComputeDataType
b
=
db_thread_buf
[
offset_m
]
*
mean_thread_buf
[
offset_m_k
]
-
ds_thread_buf
[
offset_m
];
b
*=
inv_std_thread_buf
[
offset_m_k
]
*
inv_std_thread_buf
[
offset_m_k
]
*
inv_std_thread_buf
[
offset_m_k
]
/
reduce_size
;
ComputeDataType
c
=
-
b
*
mean_thread_buf
(
offset_m_k
);
c
-=
db_thread_buf
[
offset_m
]
*
inv_std_thread_buf
[
offset_m_k
]
/
reduce_size
;
dx_thread_buf
(
offset_m_k
)
=
dy_thread_buf
[
offset_m_k
]
*
gamma_thread_buf
[
offset_m_k
]
*
inv_std_thread_buf
[
offset_m_k
]
+
b
*
x_thread_buf
[
offset_m_k
]
+
c
;
});
});
threadwise_dx_store
.
Run
(
thread_buffer_desc_m_k
,
make_tuple
(
I0
,
I0
),
dx_thread_buf
,
dx_grid_desc_m_k
,
dx_global_val_buf
);
}
// end of sweep once
else
// Sweep Twice pipeline
{
constexpr
auto
thread_copy_fwd_step_m_k
=
make_multi_index
(
0
,
K_BlockTileSize
);
for
(
index_t
reducedTiles
=
0
;
reducedTiles
<
num_k_block_tile_iteration
;
++
reducedTiles
)
{
threadwise_dy_load
.
Run
(
dy_grid_desc_m_k
,
dy_global_val_buf
,
thread_buffer_desc_m_k
,
make_tuple
(
I0
,
I0
),
dy_thread_buf
);
threadwise_x_load
.
Run
(
x_grid_desc_m_k
,
x_global_val_buf
,
thread_buffer_desc_m_k
,
make_tuple
(
I0
,
I0
),
x_thread_buf
);
threadwise_gamma_load
.
Run
(
gamma_grid_desc_m_k
,
gamma_global_val_buf
,
thread_buffer_desc_m_k
,
make_tuple
(
I0
,
I0
),
gamma_thread_buf
);
threadwise_dy_load
.
MoveSrcSliceWindow
(
dy_grid_desc_m_k
,
thread_copy_fwd_step_m_k
);
threadwise_x_load
.
MoveSrcSliceWindow
(
x_grid_desc_m_k
,
thread_copy_fwd_step_m_k
);
threadwise_gamma_load
.
MoveSrcSliceWindow
(
gamma_grid_desc_m_k
,
thread_copy_fwd_step_m_k
);
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
iM
)
{
constexpr
auto
offset_m
=
Number
<
thread_buffer_desc_m
.
CalculateOffset
(
make_tuple
(
iM
))
>
{};
static_for
<
0
,
KThreadSliceSize
,
1
>
{}([
&
](
auto
iK
)
{
constexpr
auto
offset_m_k
=
Number
<
thread_buffer_desc_m_k
.
CalculateOffset
(
make_tuple
(
iM
,
iK
))
>
{};
ds_thread_buf
(
offset_m
)
+=
dy_thread_buf
[
offset_m_k
]
*
gamma_thread_buf
[
offset_m_k
]
*
x_thread_buf
[
offset_m_k
];
db_thread_buf
(
offset_m
)
+=
dy_thread_buf
[
offset_m_k
]
*
gamma_thread_buf
[
offset_m_k
];
});
});
}
// end of first sweep
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
I
)
{
if
constexpr
(
I
>
0
)
block_sync_lds
();
BlockwiseSumReduce
::
Reduce
(
reduce_work_buf
,
ds_thread_buf
(
I
));
block_sync_lds
();
BlockwiseSumReduce
::
Reduce
(
reduce_work_buf
,
db_thread_buf
(
I
));
});
// reverse read for using dy, gamma and x in the cache
constexpr
auto
thread_copy_bwd_step_m_k
=
make_multi_index
(
0
,
-
K_BlockTileSize
);
auto
thread_copy_tail_m_k
=
(
num_k_block_tile_iteration
-
1
)
*
thread_copy_fwd_step_m_k
;
// move to tail
threadwise_dy_load
.
MoveSrcSliceWindow
(
dy_grid_desc_m_k
,
thread_copy_bwd_step_m_k
);
threadwise_x_load
.
MoveSrcSliceWindow
(
x_grid_desc_m_k
,
thread_copy_bwd_step_m_k
);
threadwise_gamma_load
.
MoveSrcSliceWindow
(
gamma_grid_desc_m_k
,
thread_copy_bwd_step_m_k
);
// move from start to tail
threadwise_mean_load
.
MoveSrcSliceWindow
(
mean_grid_desc_m_k
,
thread_copy_tail_m_k
);
threadwise_inv_std_load
.
MoveSrcSliceWindow
(
inv_std_grid_desc_m_k
,
thread_copy_tail_m_k
);
threadwise_dx_store
.
MoveDstSliceWindow
(
dx_grid_desc_m_k
,
thread_copy_tail_m_k
);
for
(
index_t
reducedTiles
=
0
;
reducedTiles
<
num_k_block_tile_iteration
;
++
reducedTiles
)
{
threadwise_dy_load
.
Run
(
dy_grid_desc_m_k
,
dy_global_val_buf
,
thread_buffer_desc_m_k
,
make_tuple
(
I0
,
I0
),
dy_thread_buf
);
threadwise_x_load
.
Run
(
x_grid_desc_m_k
,
x_global_val_buf
,
thread_buffer_desc_m_k
,
make_tuple
(
I0
,
I0
),
x_thread_buf
);
threadwise_gamma_load
.
Run
(
gamma_grid_desc_m_k
,
gamma_global_val_buf
,
thread_buffer_desc_m_k
,
make_tuple
(
I0
,
I0
),
gamma_thread_buf
);
threadwise_mean_load
.
Run
(
mean_grid_desc_m_k
,
mean_global_val_buf
,
thread_buffer_desc_m_k
,
make_tuple
(
I0
,
I0
),
mean_thread_buf
);
threadwise_inv_std_load
.
Run
(
inv_std_grid_desc_m_k
,
inv_std_global_val_buf
,
thread_buffer_desc_m_k
,
make_tuple
(
I0
,
I0
),
inv_std_thread_buf
);
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
iM
)
{
constexpr
auto
offset_m
=
Number
<
thread_buffer_desc_m
.
CalculateOffset
(
make_tuple
(
iM
))
>
{};
static_for
<
0
,
KThreadSliceSize
,
1
>
{}([
&
](
auto
iK
)
{
constexpr
auto
offset_m_k
=
Number
<
thread_buffer_desc_m_k
.
CalculateOffset
(
make_tuple
(
iM
,
iK
))
>
{};
// b = (db * x_mean - ds) * rstd ** (3) / reduce_size
// c = -b * x_mean - db * rstd / reduce_size
// dx = rstd * dy * gamma + b * x + c
ComputeDataType
b
=
db_thread_buf
[
offset_m
]
*
mean_thread_buf
[
offset_m_k
]
-
ds_thread_buf
[
offset_m
];
b
*=
inv_std_thread_buf
[
offset_m_k
]
*
inv_std_thread_buf
[
offset_m_k
]
*
inv_std_thread_buf
[
offset_m_k
]
/
reduce_size
;
ComputeDataType
c
=
-
b
*
mean_thread_buf
(
offset_m_k
);
c
-=
db_thread_buf
[
offset_m
]
*
inv_std_thread_buf
[
offset_m_k
]
/
reduce_size
;
dx_thread_buf
(
offset_m_k
)
=
dy_thread_buf
[
offset_m_k
]
*
gamma_thread_buf
[
offset_m_k
]
*
inv_std_thread_buf
[
offset_m_k
]
+
b
*
x_thread_buf
[
offset_m_k
]
+
c
;
});
});
threadwise_dx_store
.
Run
(
thread_buffer_desc_m_k
,
make_tuple
(
I0
,
I0
),
dx_thread_buf
,
dx_grid_desc_m_k
,
dx_global_val_buf
);
threadwise_dy_load
.
MoveSrcSliceWindow
(
dy_grid_desc_m_k
,
thread_copy_bwd_step_m_k
);
threadwise_x_load
.
MoveSrcSliceWindow
(
x_grid_desc_m_k
,
thread_copy_bwd_step_m_k
);
threadwise_gamma_load
.
MoveSrcSliceWindow
(
gamma_grid_desc_m_k
,
thread_copy_bwd_step_m_k
);
threadwise_mean_load
.
MoveSrcSliceWindow
(
mean_grid_desc_m_k
,
thread_copy_bwd_step_m_k
);
threadwise_inv_std_load
.
MoveSrcSliceWindow
(
inv_std_grid_desc_m_k
,
thread_copy_bwd_step_m_k
);
threadwise_dx_store
.
MoveDstSliceWindow
(
dx_grid_desc_m_k
,
thread_copy_bwd_step_m_k
);
}
}
}
};
}
// namespace ck
include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_bwd_gamma_beta.hpp
View file @
32806d5f
...
@@ -35,7 +35,7 @@ template <typename DYDataType,
...
@@ -35,7 +35,7 @@ template <typename DYDataType,
index_t
DBetaDstVectorSize
>
index_t
DBetaDstVectorSize
>
struct
GridwiseNormalizationBwdGammaBeta_mk_to_k
struct
GridwiseNormalizationBwdGammaBeta_mk_to_k
{
{
// if we just check ThreadSliceSize
&
VectorSize == 0, the performance may be poor
// if we just check ThreadSliceSize
%
VectorSize == 0, the performance may be poor
(coalesce)
static_assert
(((
DYSrcVectorDim
==
0
&&
MThreadSliceSize
==
DYSrcVectorSize
)
||
static_assert
(((
DYSrcVectorDim
==
0
&&
MThreadSliceSize
==
DYSrcVectorSize
)
||
(
DYSrcVectorDim
==
1
&&
KThreadSliceSize
==
DYSrcVectorSize
)),
(
DYSrcVectorDim
==
1
&&
KThreadSliceSize
==
DYSrcVectorSize
)),
"Invalid thread slice sizes and/or dy vector sizes configuration, please check!"
);
"Invalid thread slice sizes and/or dy vector sizes configuration, please check!"
);
...
@@ -44,6 +44,15 @@ struct GridwiseNormalizationBwdGammaBeta_mk_to_k
...
@@ -44,6 +44,15 @@ struct GridwiseNormalizationBwdGammaBeta_mk_to_k
(
XSrcVectorDim
==
1
&&
KThreadSliceSize
==
XSrcVectorSize
)),
(
XSrcVectorDim
==
1
&&
KThreadSliceSize
==
XSrcVectorSize
)),
"Invalid thread slice sizes and/or x vector sizes configuration, please check!"
);
"Invalid thread slice sizes and/or x vector sizes configuration, please check!"
);
// do not force SliceSize == MeanInvStdSrcVectorSize for groupnorm
static_assert
(
((
MeanInvStdSrcVectorDim
==
0
&&
MThreadSliceSize
%
MeanInvStdSrcVectorSize
==
0
)
||
(
MeanInvStdSrcVectorDim
==
1
&&
KThreadSliceSize
%
MeanInvStdSrcVectorSize
==
0
)),
"Invalid thread slice sizes and/or mean/inv_std vector sizes configuration, please check!"
);
static_assert
(
MThreadSliceSize
==
DGammaDstVectorSize
&&
MThreadSliceSize
==
DBetaDstVectorSize
,
"Invalid thread slice sizes and/or dx vector sizes configuration, please check!"
);
using
ThreadClusterLengths_M_K
=
Sequence
<
MThreadClusterSize
,
KThreadClusterSize
>
;
using
ThreadClusterLengths_M_K
=
Sequence
<
MThreadClusterSize
,
KThreadClusterSize
>
;
using
DYThreadBufferDimAccessOrder
=
using
DYThreadBufferDimAccessOrder
=
...
...
include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp
View file @
32806d5f
...
@@ -522,22 +522,21 @@ struct TransformConvFwdToGemm
...
@@ -522,22 +522,21 @@ struct TransformConvFwdToGemm
// for output bias
// for output bias
template
<
typename
CLayout
,
template
<
typename
CLayout
,
typename
std
::
enable_if
<
is_same_v
<
CLayout
,
tensor_layout
::
convolution
::
GK
>
||
typename
std
::
enable_if
<
is_same_v
<
CLayout
,
tensor_layout
::
convolution
::
G_K
>,
is_same_v
<
CLayout
,
tensor_layout
::
convolution
::
G_K
>
,
bool
>::
type
=
false
>
bool
>::
type
=
false
>
static
auto
static
auto
MakeCDescriptor_M_N
(
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
c_g_n_k_wos_lengths
,
MakeCDescriptor_M_N
(
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
c_g_n_k_wos_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
c_g_n_k_wos_strides
)
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
/* c_g_n_k_wos_strides */
)
{
{
const
index_t
N
=
c_g_n_k_wos_lengths
[
1
];
const
index_t
N
=
c_g_n_k_wos_lengths
[
1
];
const
index_t
K
=
c_g_n_k_wos_lengths
[
2
];
const
index_t
K
=
c_g_n_k_wos_lengths
[
2
];
const
index_t
KStride
=
c_g_n_k_wos_strides
[
2
];
const
index_t
NHoWo
=
const
index_t
NHoWo
=
N
*
ck
::
accumulate_n
<
index_t
>
(
N
*
ck
::
accumulate_n
<
index_t
>
(
c_g_n_k_wos_lengths
.
begin
()
+
3
,
NDimSpatial
,
1
,
std
::
multiplies
<>
());
c_g_n_k_wos_lengths
.
begin
()
+
3
,
NDimSpatial
,
1
,
std
::
multiplies
<>
());
const
auto
out_gemmm_gemmn_desc
=
const
auto
out_gemmm_gemmn_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
NHoWo
,
K
),
make_tuple
(
I0
,
I1
));
make_naive_tensor_descriptor
(
make_tuple
(
NHoWo
,
K
),
make_tuple
(
I0
,
KStride
));
return
out_gemmm_gemmn_desc
;
return
out_gemmm_gemmn_desc
;
}
}
...
...
include/ck/utility/tuple_helper.hpp
View file @
32806d5f
...
@@ -166,4 +166,16 @@ __host__ __device__ constexpr auto IsNestedTuple(const Tuple<Ts...>&)
...
@@ -166,4 +166,16 @@ __host__ __device__ constexpr auto IsNestedTuple(const Tuple<Ts...>&)
return
(
is_detected
<
is_tuple
,
Ts
>::
value
||
...);
return
(
is_detected
<
is_tuple
,
Ts
>::
value
||
...);
}
}
template
<
index_t
depth
=
0
,
typename
T
>
__host__
__device__
constexpr
auto
TupleDepth
(
const
T
&
)
{
return
depth
;
}
template
<
index_t
depth
=
0
,
typename
...
Ts
>
__host__
__device__
constexpr
auto
TupleDepth
(
const
Tuple
<
Ts
...
>&
)
{
return
math
::
max
(
TupleDepth
<
depth
+
1
>
(
Ts
{})...);
}
}
// namespace ck
}
// namespace ck
example/64_tensor_transforms/tensor_transform_wrapper
.hpp
→
include/ck/wrapper/layout
.hpp
View file @
32806d5f
...
@@ -3,27 +3,13 @@
...
@@ -3,27 +3,13 @@
#pragma once
#pragma once
#include "ck/ck.hpp"
#include "ck/wrapper/utils/layout_utils.hpp"
#include "ck/utility/number.hpp"
#include "ck/utility/tuple.hpp"
#include "ck/utility/tuple_helper.hpp"
#include "ck/utility/sequence.hpp"
#include "ck/utility/sequence_helper.hpp"
#include "ck/utility/is_detected.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
namespace
ck
{
namespace
ck
{
namespace
tensor_transform_
wrapper
{
namespace
wrapper
{
/**
/**
* \brief Layout wrapper
* \brief Layout wrapper that performs the tensor descriptor logic.
*
* \details
* Layout wrapper that performs the tensor descriptor logic.
*
*
* \tparam Shape Tuple of Number<> (for compile-time layout) or index_t
* \tparam Shape Tuple of Number<> (for compile-time layout) or index_t
* (dynamic layout). It is possible to pass nested shapes
* (dynamic layout). It is possible to pass nested shapes
...
@@ -32,21 +18,39 @@ namespace tensor_transform_wrapper {
...
@@ -32,21 +18,39 @@ namespace tensor_transform_wrapper {
* (dynamic layout). Stride tuple should be nested if shape tuple is
* (dynamic layout). Stride tuple should be nested if shape tuple is
* nested.
* nested.
*/
*/
template
<
typename
Shape
,
typename
Strides
=
Tuple
<
>
>
template
<
typename
Shape
,
typename
Strides
>
struct
Layout
struct
Layout
{
{
private:
private:
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
template
<
typename
T
>
// Generate default idxs tuple (idx with all merged nested shapes)
using
is_tuple
=
decltype
(
std
::
declval
<
T
&>
().
IsTuple
());
template
<
typename
...
Ts
>
__host__
__device__
constexpr
static
auto
GenerateDefaultIdxsTuple
(
const
Tuple
<
Ts
...
>&
)
{
return
generate_tuple
(
[
&
](
auto
)
{
if
constexpr
(
!
FlattenDescriptorType
::
IsKnownAtCompileTime
())
{
// runtime layout
return
index_t
(
0
);
}
else
{
// compiletime layout
return
I0
;
}
},
Number
<
Tuple
<
Ts
...
>::
Size
()
>
{});
}
// Generate packed (column-major) strides if not passed
// Generate packed (column-major) strides if not passed
template
<
typename
...
Ts
>
template
<
typename
...
Ts
>
__host__
__device__
constexpr
static
auto
__host__
__device__
constexpr
static
auto
GenerateColumnMajorPackedStrides
(
const
Tuple
<
Ts
...
>&
tupl
e
)
GenerateColumnMajorPackedStrides
(
const
Tuple
<
Ts
...
>&
shap
e
)
{
{
const
auto
unrolled_shape
=
UnrollNestedTuple
(
shape
);
return
generate_tuple
(
return
generate_tuple
(
[
&
](
auto
i
)
{
[
&
](
auto
i
)
{
if
constexpr
(
i
.
value
==
0
)
if
constexpr
(
i
.
value
==
0
)
...
@@ -56,10 +60,10 @@ struct Layout
...
@@ -56,10 +60,10 @@ struct Layout
else
else
{
{
return
TupleReduce
<
I0
.
value
,
i
.
value
>
([](
auto
x
,
auto
y
)
{
return
x
*
y
;
},
return
TupleReduce
<
I0
.
value
,
i
.
value
>
([](
auto
x
,
auto
y
)
{
return
x
*
y
;
},
tupl
e
);
unrolled_shap
e
);
}
}
},
},
Number
<
Tuple
<
Ts
...
>
::
Size
()
>
{});
Number
<
decltype
(
unrolled_shape
)
::
Size
()
>
{});
}
}
// Generate LowerDims in Compile-time for MergeTrasform using passed Type
// Generate LowerDims in Compile-time for MergeTrasform using passed Type
...
@@ -112,8 +116,8 @@ struct Layout
...
@@ -112,8 +116,8 @@ struct Layout
// Example shape: (2, (2, 2)), 2, (2, 2)
// Example shape: (2, (2, 2)), 2, (2, 2)
// Unrolled shape: 2, (2, 2), 2, (2, 2)
// Unrolled shape: 2, (2, 2), 2, (2, 2)
template
<
typename
...
ShapeDims
,
typename
...
IdxDims
>
template
<
typename
...
ShapeDims
,
typename
...
IdxDims
>
__host__
__device__
constexpr
static
auto
Unroll
Shape
Via
Idx
(
const
Tuple
<
ShapeDims
...
>&
shape
,
__host__
__device__
constexpr
static
auto
Align
Shape
To
Idx
(
const
Tuple
<
ShapeDims
...
>&
shape
,
const
Tuple
<
IdxDims
...
>&
idx
)
const
Tuple
<
IdxDims
...
>&
idx
)
{
{
if
constexpr
(
!
IsNestedTuple
(
Tuple
<
IdxDims
...
>
{}))
if
constexpr
(
!
IsNestedTuple
(
Tuple
<
IdxDims
...
>
{}))
{
{
...
@@ -125,7 +129,7 @@ struct Layout
...
@@ -125,7 +129,7 @@ struct Layout
// Iterate over shape tuple elements:
// Iterate over shape tuple elements:
// 1. If corresponding idx element is tuple then return (will be unrolled)
// 1. If corresponding idx element is tuple then return (will be unrolled)
// 2. If no, pack in tuple. It will be restored during unroll.
// 2. If no, pack in tuple. It will be restored during unroll.
auto
unroll
ed_shape
_via_idx
=
generate_tuple
(
auto
align
ed_shape
=
generate_tuple
(
[
&
](
auto
i
)
{
[
&
](
auto
i
)
{
if
constexpr
(
is_detected
<
is_tuple
,
if
constexpr
(
is_detected
<
is_tuple
,
tuple_element_t
<
i
,
Tuple
<
IdxDims
...
>>>::
value
)
tuple_element_t
<
i
,
Tuple
<
IdxDims
...
>>>::
value
)
...
@@ -140,37 +144,34 @@ struct Layout
...
@@ -140,37 +144,34 @@ struct Layout
Number
<
Tuple
<
IdxDims
...
>::
Size
()
>
{});
Number
<
Tuple
<
IdxDims
...
>::
Size
()
>
{});
// Unroll and process next step
// Unroll and process next step
return
Unroll
Shape
Via
Idx
(
UnrollNestedTuple
<
0
,
1
>
(
unroll
ed_shape
_via_idx
),
return
Align
Shape
To
Idx
(
UnrollNestedTuple
<
0
,
1
>
(
align
ed_shape
),
UnrollNestedTuple
<
0
,
1
>
(
idx
));
UnrollNestedTuple
<
0
,
1
>
(
idx
));
}
}
}
}
template
<
typename
...
ShapeDims
,
typename
DescriptorToMerge
>
template
<
typename
...
ShapeDims
,
typename
DescriptorToMerge
>
__host__
__device__
constexpr
static
auto
MakeMerge1d
(
const
Tuple
<
ShapeDims
...
>&
shape
,
__host__
__device__
constexpr
static
auto
MakeMerge1d
(
const
Tuple
<
ShapeDims
...
>&
shape
,
DescriptorToMerge
&
desc
)
const
DescriptorToMerge
&
desc
)
{
{
// Reverse each element in tuple
// Reverse each element in tuple
using
ReversedUnrolledShape
=
decltype
(
TupleReverse
(
UnrollNestedTuple
(
shape
)));
const
auto
merge_elems
=
TupleReverse
(
UnrollNestedTuple
(
shape
));
const
auto
merge_elems
=
ReversedUnrolledShape
{};
// Generate reverted indexes (column major traverse)
// Generate reverted indexes (column major traverse)
using
MergeElemsSequence
=
using
MergeElemsSequence
=
typename
arithmetic_sequence_gen
<
0
,
merge_elems
.
Size
(),
1
>::
type
;
typename
arithmetic_sequence_gen
<
0
,
ReversedUnrolledShape
::
Size
(),
1
>::
type
;
const
auto
lower_dims
=
make_tuple
(
MergeElemsSequence
::
Reverse
());
const
auto
lower_dims
=
make_tuple
(
MergeElemsSequence
::
Reverse
());
const
auto
upper_dims
=
make_tuple
(
Sequence
<
0
>
{});
const
auto
upper_dims
=
make_tuple
(
Sequence
<
0
>
{});
// Merge to 1d
// Merge to 1d
return
transform_tensor_descriptor
(
return
transform_tensor_descriptor
(
desc
,
make_tuple
(
make_merge_transform
(
merge_elems
)),
lower_dims
,
upper_dims
);
desc
,
make_tuple
(
make_merge_transform
(
merge_elems
)),
lower_dims
,
upper_dims
);
}
}
// Merge nested shape dims
// Merge nested shape dims
when corresponding index is also nested.
// Input desc shape: 2, 2, 2, 2, 2, 2
// Input desc shape: 2, 2, 2, 2, 2, 2
// Example idx: 1, 1, 1, 1
// Example idx: 1, 1, 1, 1
// Example shape: 2, (2, 2), 2, (2, 2)
// Example shape: 2, (2, 2), 2, (2, 2)
// Merged shape: 2, 4, 2, 4
// Merged shape: 2, 4, 2, 4
template
<
typename
...
ShapeDims
,
typename
...
IdxDims
,
typename
DescriptorToMerge
>
template
<
typename
...
ShapeDims
,
typename
...
IdxDims
,
typename
DescriptorToMerge
>
__host__
__device__
constexpr
static
auto
__host__
__device__
constexpr
static
auto
CreateMergedDescriptor
(
MakeMerges
(
const
Tuple
<
ShapeDims
...
>&
shape
,
const
Tuple
<
IdxDims
...
>&
,
DescriptorToMerge
&
desc
)
const
Tuple
<
ShapeDims
...
>&
shape
,
const
Tuple
<
IdxDims
...
>&
,
DescriptorToMerge
&
desc
)
{
{
const
auto
transforms
=
generate_tuple
(
const
auto
transforms
=
generate_tuple
(
[
&
](
auto
i
)
{
[
&
](
auto
i
)
{
...
@@ -206,14 +207,38 @@ struct Layout
...
@@ -206,14 +207,38 @@ struct Layout
return
transform_tensor_descriptor
(
desc
,
transforms
,
lower_dims
,
upper_dims
);
return
transform_tensor_descriptor
(
desc
,
transforms
,
lower_dims
,
upper_dims
);
}
}
template
<
typename
LayoutShape
,
typename
LayoutStrides
>
__host__
__device__
static
auto
MakeFlattenDescriptor
(
const
LayoutShape
&
shape
,
const
LayoutStrides
&
strides
)
{
const
auto
unrolled_shape
=
UnrollNestedTuple
(
shape
);
const
auto
unrolled_strides
=
UnrollNestedTuple
(
strides
);
static_assert
(
unrolled_shape
.
Size
()
==
unrolled_strides
.
Size
(),
"Size of strides and shape are not consistent."
);
return
make_naive_tensor_descriptor
(
unrolled_shape
,
unrolled_strides
);
}
// If the stride is not passed, you can infer it from `GenerateColumnMajorPackedStrides`.
using
DeducedStrides
=
std
::
conditional_t
<
is_same_v
<
Strides
,
Tuple
<>>
,
remove_cvref_t
<
decltype
(
GenerateColumnMajorPackedStrides
(
Shape
{}))
>
,
Strides
>
;
using
FlattenDescriptorType
=
remove_cvref_t
<
decltype
(
MakeFlattenDescriptor
(
Shape
{},
DeducedStrides
{}))
>
;
using
Descriptor1dType
=
remove_cvref_t
<
decltype
(
MakeMerge1d
(
Shape
{},
FlattenDescriptorType
{}))
>
;
using
DefaultIdxsTupleType
=
remove_cvref_t
<
decltype
(
GenerateDefaultIdxsTuple
(
Shape
{}))
>
;
template
<
typename
...
ShapeDims
,
typename
...
IdxDims
>
template
<
typename
...
ShapeDims
,
typename
...
IdxDims
>
__host__
__device__
constexpr
auto
TransformDesc
(
const
Tuple
<
ShapeDims
...
>&
shape
,
__host__
__device__
constexpr
static
auto
const
Tuple
<
IdxDims
...
>&
idx
)
const
TransformDesc
(
const
Tuple
<
ShapeDims
...
>&
shape
,
const
Tuple
<
IdxDims
...
>&
idx
,
const
FlattenDescriptorType
&
naive_descriptor
)
{
{
if
constexpr
(
Tuple
<
IdxDims
...
>::
Size
()
==
I1
)
if
constexpr
(
Tuple
<
IdxDims
...
>::
Size
()
==
I1
)
{
{
// 1d idx path
// 1d idx path
return
MakeMerge1d
(
shape
,
descriptor
_
);
return
MakeMerge1d
(
shape
,
naive_
descriptor
);
}
}
else
else
{
{
...
@@ -224,62 +249,55 @@ struct Layout
...
@@ -224,62 +249,55 @@ struct Layout
static_assert
(
Tuple
<
ShapeDims
...
>::
Size
()
==
Tuple
<
IdxDims
...
>::
Size
(),
static_assert
(
Tuple
<
ShapeDims
...
>::
Size
()
==
Tuple
<
IdxDims
...
>::
Size
(),
"Idx rank and Shape rank must be the same (except 1d)."
);
"Idx rank and Shape rank must be the same (except 1d)."
);
// Unroll while IdxDims is nested
// Unroll while IdxDims is nested
const
auto
unroll
ed_shape
_via_idx
=
Unroll
Shape
Via
Idx
(
shape
,
idx
);
const
auto
align
ed_shape
=
Align
Shape
To
Idx
(
shape
,
idx
);
// Transform correct form of shape
// Transform correct form of shape
return
Mak
eMerge
s
(
unrolled_shape_via_idx
,
UnrollNestedTuple
(
idx
),
descriptor
_
);
return
Creat
eMerge
dDescriptor
(
aligned_shape
,
UnrollNestedTuple
(
idx
),
naive_
descriptor
);
}
}
}
}
template
<
typename
LayoutShape
,
typename
LayoutStrides
>
using
MergedNestsDescriptorType
=
remove_cvref_t
<
decltype
(
TransformDesc
(
__host__
__device__
static
auto
MakeNaiveDescriptor
(
const
LayoutShape
&
shape
,
Shape
{},
DefaultIdxsTupleType
{},
FlattenDescriptorType
{}))
>
;
const
LayoutStrides
&
strides
)
{
const
auto
unrolled_shape
=
UnrollNestedTuple
(
shape
);
if
constexpr
(
ck
::
is_same_v
<
LayoutStrides
,
Tuple
<>>
)
{
// If shape is packed
const
auto
column_major_packed_strides
=
GenerateColumnMajorPackedStrides
(
unrolled_shape
);
return
make_naive_tensor_descriptor
(
unrolled_shape
,
column_major_packed_strides
);
}
else
{
const
auto
unrolled_strides
=
UnrollNestedTuple
(
strides
);
static_assert
(
unrolled_shape
.
Size
()
==
unrolled_strides
.
Size
(),
"Size of strides and shape are not consistent."
);
return
make_naive_tensor_descriptor
(
unrolled_shape
,
unrolled_strides
);
}
}
public:
public:
using
NaiveDescriptorType
=
remove_cvref_t
<
decltype
(
MakeNaiveDescriptor
(
Shape
{},
Strides
{}))
>
;
__host__
__device__
constexpr
auto
GetElementSpaceSize
()
const
{
return
flatten_descriptor_
.
GetElementSpaceSize
();
}
__host__
__device__
Layout
()
=
delete
;
/**
/**
* \brief Layout constructor.
* \brief Layout constructor.
*
*
* \param shape Shape for layout.
* \param shape Shape for layout.
* \param strides Strides for layout (optional if tensor is packed).
* \param strides Strides for layout (optional if tensor is packed).
* \return Layout object.
*/
*/
__host__
__device__
Layout
()
=
delete
;
__host__
__device__
constexpr
Layout
(
const
Shape
&
shape
,
const
Strides
&
strides
)
__host__
__device__
Layout
(
const
S
hape
&
shape
,
const
S
trides
&
strides
)
:
descriptor_
{}
:
flatten_descriptor_
{},
s
hape
_
(
shape
)
,
s
trides
_
(
strides
)
{
{
// Construct if runtime mode
// Construct if runtime mode
if
constexpr
(
!
Naive
DescriptorType
::
IsKnownAtCompileTime
())
if
constexpr
(
!
Flatten
DescriptorType
::
IsKnownAtCompileTime
())
{
{
// Keep only shape, strides are not need for transforms
flatten_descriptor_
=
MakeFlattenDescriptor
(
shape_
,
strides_
);
shape_
=
shape
;
descriptor_1d_
=
MakeMerge1d
(
shape_
,
flatten_descriptor_
);
descriptor_
=
MakeNaiveDescriptor
(
shape
,
strides
);
merged_nests_descriptor_
=
TransformDesc
(
shape_
,
DefaultIdxsTupleType
{},
flatten_descriptor_
);
}
}
}
}
__host__
__device__
Layout
(
const
Shape
&
shape
)
:
descriptor_
{}
/**
* \brief Layout constructor (with default packed column-major strides).
*
* \param shape Shape for layout.
*/
__host__
__device__
constexpr
Layout
(
const
Shape
&
shape
)
:
flatten_descriptor_
{},
shape_
(
shape
),
strides_
(
GenerateColumnMajorPackedStrides
(
shape_
))
{
{
if
constexpr
(
!
Naive
DescriptorType
::
IsKnownAtCompileTime
())
if
constexpr
(
!
Flatten
DescriptorType
::
IsKnownAtCompileTime
())
{
{
shape_
=
shape
;
flatten_descriptor_
=
MakeFlattenDescriptor
(
shape_
,
strides_
);
descriptor_
=
MakeNaiveDescriptor
(
shape
,
Strides
{});
descriptor_1d_
=
MakeMerge1d
(
shape_
,
flatten_descriptor_
);
merged_nests_descriptor_
=
TransformDesc
(
shape_
,
DefaultIdxsTupleType
{},
flatten_descriptor_
);
}
}
}
}
...
@@ -292,7 +310,9 @@ struct Layout
...
@@ -292,7 +310,9 @@ struct Layout
template
<
typename
Idxs
>
template
<
typename
Idxs
>
__host__
__device__
constexpr
index_t
operator
()()
const
__host__
__device__
constexpr
index_t
operator
()()
const
{
{
using
TransformedDesc
=
decltype
(
TransformDesc
(
Shape
{},
Idxs
{}));
static_assert
(
FlattenDescriptorType
::
IsKnownAtCompileTime
(),
"Compiletime operator used on runtime layout."
);
using
TransformedDesc
=
decltype
(
TransformDesc
(
Shape
{},
Idxs
{},
FlattenDescriptorType
{}));
using
UnrolledIdx
=
decltype
(
UnrollNestedTuple
(
Idxs
{}));
using
UnrolledIdx
=
decltype
(
UnrollNestedTuple
(
Idxs
{}));
return
TransformedDesc
{}.
CalculateOffset
(
UnrolledIdx
{});
return
TransformedDesc
{}.
CalculateOffset
(
UnrolledIdx
{});
}
}
...
@@ -306,9 +326,22 @@ struct Layout
...
@@ -306,9 +326,22 @@ struct Layout
template
<
typename
...
Ts
>
template
<
typename
...
Ts
>
__host__
__device__
index_t
operator
()(
const
Tuple
<
Ts
...
>&
Idx
)
const
__host__
__device__
index_t
operator
()(
const
Tuple
<
Ts
...
>&
Idx
)
const
{
{
// Static to construct transformed_desc only once
if
constexpr
(
!
IsNestedTuple
(
Tuple
<
Ts
...
>
{})
&&
Tuple
<
Ts
...
>::
Size
()
==
1
)
static
const
auto
transformed_desc
=
TransformDesc
(
shape_
,
Idx
);
{
return
transformed_desc
.
CalculateOffset
(
UnrollNestedTuple
(
Idx
));
// if 1d access
return
descriptor_1d_
.
CalculateOffset
(
Idx
);
}
else
if
constexpr
(
!
IsNestedTuple
(
Tuple
<
Ts
...
>
{})
&&
Tuple
<
Ts
...
>::
Size
()
==
Shape
::
Size
())
{
// if Shape::Size() access (merged nested shapes)
return
merged_nests_descriptor_
.
CalculateOffset
(
UnrollNestedTuple
(
Idx
));
}
else
{
// Custom index, need to transform descriptor
const
auto
transformed_desc
=
TransformDesc
(
shape_
,
Idx
,
flatten_descriptor_
);
return
transformed_desc
.
CalculateOffset
(
UnrollNestedTuple
(
Idx
));
}
}
}
/**
/**
...
@@ -338,7 +371,7 @@ struct Layout
...
@@ -338,7 +371,7 @@ struct Layout
*
*
* \return Calculated size.
* \return Calculated size.
*/
*/
__host__
__device__
constexpr
index_t
GetLength
()
const
__host__
__device__
constexpr
index_t
GetLength
s
()
const
{
{
const
auto
unrolled_shape
=
UnrollNestedTuple
(
shape_
);
const
auto
unrolled_shape
=
UnrollNestedTuple
(
shape_
);
return
TupleReduce
<
I0
.
value
,
unrolled_shape
.
Size
()
>
([](
auto
x
,
auto
y
)
{
return
x
*
y
;
},
return
TupleReduce
<
I0
.
value
,
unrolled_shape
.
Size
()
>
([](
auto
x
,
auto
y
)
{
return
x
*
y
;
},
...
@@ -346,80 +379,56 @@ struct Layout
...
@@ -346,80 +379,56 @@ struct Layout
}
}
/**
/**
* \brief
Dimension
getter.
* \brief
Shape
getter.
*
*
* \tparam IDim Dimension idx.
* \return Shape.
* \return Calculated size.
*/
*/
template
<
index_t
IDim
>
__host__
__device__
constexpr
const
Shape
&
GetShape
()
const
{
return
shape_
;
}
__host__
__device__
constexpr
auto
Get
()
const
{
const
auto
elem
=
shape_
.
At
(
Number
<
IDim
>
{});
return
elem
;
}
private:
/**
NaiveDescriptorType
descriptor_
;
* \brief Strides getter.
Shape
shape_
;
*
};
* \return Strides.
*/
// Layout helpers
__host__
__device__
constexpr
const
DeducedStrides
&
GetStrides
()
const
{
return
strides_
;
}
// Length getter (product if tuple)
template
<
index_t
idx
,
typename
Shape
,
typename
Strides
>
__host__
__device__
constexpr
index_t
size
(
const
Layout
<
Shape
,
Strides
>&
layout
)
{
return
layout
.
template
GetLength
<
idx
>();
}
// Get shape size (product of dims if tuple)
template
<
typename
...
ShapeDims
>
__host__
__device__
constexpr
index_t
size
(
const
Tuple
<
ShapeDims
...
>&
shape
)
{
using
UnrolledShape
=
decltype
(
UnrollNestedTuple
(
shape
));
return
TupleReduce
<
0
,
UnrolledShape
::
Size
()
>
([](
auto
x
,
auto
y
)
{
return
x
*
y
;
},
UnrolledShape
{});
}
// Get dim size (could be returned from get function)
template
<
typename
T
>
__host__
__device__
T
constexpr
size
(
const
T
&
dim
)
{
return
dim
;
}
// Get layout size (product of shapes)
template
<
typename
Shape
,
typename
Strides
>
__host__
__device__
constexpr
index_t
size
(
const
Layout
<
Shape
,
Strides
>&
layout
)
{
return
layout
.
GetLength
();
}
// Get shape element size
/**
template
<
index_t
idx
,
typename
...
ShapeDims
>
* \brief Get default lengths (tuple filled with Shape length elements).
__host__
__device__
constexpr
index_t
size
(
const
Tuple
<
ShapeDims
...
>&
shape
)
*
{
* \return Default lengths.
return
size
(
shape
.
At
(
Number
<
idx
>
{}));
*/
}
__host__
__device__
constexpr
auto
GetDefaultLengthsTuple
()
const
{
return
generate_tuple
([
&
](
auto
i
)
{
return
GetLength
<
i
>
();
},
Number
<
Shape
::
Size
()
>
{});
}
// Dim getter (tuple if tuple)
/**
template
<
index_t
idx
,
typename
Shape
,
typename
Strides
>
* \brief Get default start idx (tuple filled with 0s of the same size as Shape).
__host__
__device__
constexpr
auto
get
(
const
Layout
<
Shape
,
Strides
>&
layout
)
*
{
* \return Default start idx.
return
layout
.
template
Get
<
idx
>();
*/
}
__host__
__device__
constexpr
auto
GetDefaultStartIdxs
()
const
{
return
GenerateDefaultIdxsTuple
(
shape_
);
}
template
<
typename
Shape
,
typename
Strides
>
/**
__host__
__device__
constexpr
Layout
<
Shape
,
Strides
>
make_layout
(
const
Shape
&
shape
,
* \brief Get default descriptor (with the same size as Shape)
const
Strides
&
strides
)
*
{
* \return Default descriptor.
return
Layout
<
Shape
,
Strides
>
(
shape
,
strides
);
*/
}
__host__
__device__
constexpr
MergedNestsDescriptorType
GetDefaultDescriptor
()
{
return
merged_nests_descriptor_
;
}
template
<
typename
Shape
>
private:
__host__
__device__
constexpr
Layout
<
Shape
>
make_layout
(
const
Shape
&
shape
)
FlattenDescriptorType
flatten_descriptor_
;
{
Descriptor1dType
descriptor_1d_
;
return
Layout
<
Shape
>
(
shape
);
MergedNestsDescriptorType
merged_nests_descriptor_
;
}
const
Shape
shape_
;
const
DeducedStrides
strides_
;
};
}
// namespace
tensor_transform_
wrapper
}
// namespace wrapper
}
// namespace ck
}
// namespace ck
include/ck/wrapper/tensor.hpp
0 → 100644
View file @
32806d5f
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "utils/tensor_utils.hpp"
#include "utils/layout_utils.hpp"
namespace
ck
{
namespace
wrapper
{
/**
* \brief Tensor wrapper that performs static and dynamic buffer logic.
*
* \tparam BufferAddressSpace Memory type (Generic, Global, LDS, VGPR, SGPR).
* \tparam ElementType Element data type.
* \tparam Shape Tensor shape (layout component).
* \tparam Strides Tensor strides (layout component).
* \tparam NumVectors Number of vectors (only for VGPR, SGPR).
* \tparam ScalarPerVector Scalars per vector (only for VGPR, SGPR).
*/
template
<
MemoryTypeEnum
BufferAddressSpace
,
typename
ElementType
,
typename
Shape
,
typename
Strides
,
index_t
NumVectors
,
// param for Register memory
index_t
ScalarPerVector
// param for Register memory
>
struct
Tensor
{
private:
// Check if Tuple contains Slice object
template
<
typename
T
>
constexpr
static
bool
IsSlicing
(
T
&&
)
{
return
is_detected
<
is_slice
,
T
>::
value
;
}
template
<
typename
...
Ts
>
constexpr
static
bool
IsSlicing
(
Tuple
<
Ts
...
>&&
)
{
return
(
IsSlicing
(
Ts
{})
||
...);
}
// Calculate first index of new tensor after slice
// It is needed to calculate offset for new tensor
template
<
typename
...
Ts
>
constexpr
auto
GetStartIdxForSlicedTensor
(
const
Tuple
<
Ts
...
>&
idx
)
const
{
const
auto
start_idx_for_sliced_tensor
=
generate_tuple
(
[
&
](
auto
i
)
{
constexpr
auto
num_i
=
Number
<
i
>
{};
if
constexpr
(
is_detected
<
is_tuple
,
tuple_element_t
<
i
.
value
,
Tuple
<
Ts
...
>>>::
value
)
{
// if tuple then recurrence
return
GetStartIdxForSlicedTensor
(
idx
.
At
(
num_i
));
}
else
if
constexpr
(
is_detected
<
is_slice
,
tuple_element_t
<
i
.
value
,
Tuple
<
Ts
...
>>>::
value
)
{
// if slice, return the beginning of the interval
return
idx
.
At
(
num_i
).
from_
;
}
else
{
// if one dim selected
return
idx
.
At
(
num_i
);
}
},
Number
<
Tuple
<
Ts
...
>::
Size
()
>
{});
return
start_idx_for_sliced_tensor
;
}
// Calculate new tensor shape after slice
template
<
typename
...
Ts
,
typename
ShapeTmpType
>
constexpr
auto
GetShapeFromSlicedTensor
(
const
Tuple
<
Ts
...
>&
idx
,
const
ShapeTmpType
&
shape
)
const
{
// Pack each value in tuple to remove empty tuples after generation
auto
new_shape
=
generate_tuple
(
[
&
](
auto
i
)
{
constexpr
auto
num_i
=
Number
<
i
>
{};
if
constexpr
(
is_detected
<
is_tuple
,
tuple_element_t
<
i
.
value
,
Tuple
<
Ts
...
>>>::
value
)
{
if
constexpr
(
!
IsSlicing
(
tuple_element_t
<
i
.
value
,
Tuple
<
Ts
...
>>
{}))
{
// if tuple does not have any slice then we can remove dimension
return
Tuple
<>
{};
}
else
{
// if tuple then recurrence
return
make_tuple
(
GetShapeFromSlicedTensor
(
idx
.
At
(
num_i
),
shape
.
At
(
num_i
)));
}
}
else
if
constexpr
(
is_detected
<
is_slice
,
tuple_element_t
<
i
.
value
,
Tuple
<
Ts
...
>>>::
value
)
{
// calculate new dimension
const
auto
&
dim
=
size
(
shape
.
At
(
num_i
));
const
auto
val
=
idx
.
At
(
num_i
).
range
(
dim
);
return
make_tuple
(
val
);
}
else
{
// remove dimension for just value
return
Tuple
<>
{};
}
},
Number
<
Tuple
<
Ts
...
>::
Size
()
>
{});
// Remove empty tuples (deleted elements) and return
return
UnrollNestedTuple
<
0
,
1
>
(
new_shape
);
}
template
<
typename
...
Ts
,
typename
StridesTmpType
>
constexpr
auto
GetStridesFromSlicedTensor
(
const
Tuple
<
Ts
...
>&
idx
,
const
StridesTmpType
&
strides
)
const
{
// Pack each value in tuple to remove empty tuples after generation
auto
new_strides
=
generate_tuple
(
[
&
](
auto
i
)
{
constexpr
auto
num_i
=
Number
<
i
>
{};
if
constexpr
(
is_detected
<
is_tuple
,
tuple_element_t
<
i
.
value
,
Tuple
<
Ts
...
>>>::
value
)
{
if
constexpr
(
!
IsSlicing
(
tuple_element_t
<
i
.
value
,
Tuple
<
Ts
...
>>
{}))
{
// if tuple does not have any slice then we can remove dimension
return
Tuple
<>
{};
}
else
{
// if tuple then recurrence
return
make_tuple
(
GetStridesFromSlicedTensor
(
idx
.
At
(
num_i
),
strides
.
At
(
num_i
)));
}
}
else
if
constexpr
(
is_detected
<
is_slice
,
tuple_element_t
<
i
.
value
,
Tuple
<
Ts
...
>>>::
value
)
{
// Stride will be the same
return
make_tuple
(
strides
.
At
(
num_i
));
}
else
{
// remove dimension for just value
return
Tuple
<>
{};
}
},
Number
<
Tuple
<
Ts
...
>::
Size
()
>
{});
// Remove empty tuples (deleted elements) and return
return
UnrollNestedTuple
<
0
,
1
>
(
new_strides
);
}
public:
using
ElementSpaceSize
=
decltype
(
Layout
<
Shape
,
Strides
>
{
Shape
{},
Strides
{}}.
GetElementSpaceSize
());
// SpaceSize type for buffer
using
TensorElementType
=
ElementType
;
// DataType
static
constexpr
MemoryTypeEnum
TensorBufferAddressSpace
=
BufferAddressSpace
;
static
constexpr
bool
IsDynamicBuffer
=
!
(
BufferAddressSpace
==
MemoryTypeEnum
::
Sgpr
||
BufferAddressSpace
==
MemoryTypeEnum
::
Vgpr
);
__host__
__device__
Tensor
()
=
delete
;
__host__
__device__
Tensor
(
ElementType
*
pointer
,
const
Layout
<
Shape
,
Strides
>&
layout
)
:
layout_
(
layout
),
buffer_
(
make_dynamic_buffer
<
BufferAddressSpace
>
(
pointer
,
layout
.
GetElementSpaceSize
()))
{
}
__host__
__device__
Tensor
(
const
Layout
<
Shape
,
Strides
>&
layout
)
:
layout_
(
layout
)
{
static_assert
(
!
IsDynamicBuffer
,
"Wrong BufferAddressSpace for register."
);
}
__host__
__device__
constexpr
const
Layout
<
Shape
,
Strides
>&
GetLayout
()
const
{
return
layout_
;
}
// Getter for new sliced tensor
template
<
typename
...
Ts
,
enable_if_t
<
IsSlicing
(
Tuple
<
Ts
...>{}),
bool
>
=
false
>
__host__
__device__
auto
operator
[](
const
Tuple
<
Ts
...
>&
idx
)
const
{
static_assert
(
IsDynamicBuffer
,
"Register slice is not supported"
);
// Calculate offset based on first idx for new tensor
const
index_t
offset
=
layout_
(
GetStartIdxForSlicedTensor
(
idx
));
auto
new_shape
=
GetShapeFromSlicedTensor
(
idx
,
layout_
.
GetShape
());
if
constexpr
(
is_same_v
<
Strides
,
Tuple
<>>
)
{
auto
new_layout
=
make_layout
(
new_shape
);
return
make_tensor
<
BufferAddressSpace
>
(
buffer_
.
p_data_
+
offset
,
new_layout
);
}
else
{
auto
new_strides
=
GetStridesFromSlicedTensor
(
idx
,
layout_
.
GetStrides
());
auto
new_layout
=
make_layout
(
new_shape
,
new_strides
);
return
make_tensor
<
BufferAddressSpace
>
(
buffer_
.
p_data_
+
offset
,
new_layout
);
}
}
template
<
typename
...
Ts
,
enable_if_t
<
IsSlicing
(
Tuple
<
Ts
...>{}),
bool
>
=
false
>
__host__
__device__
auto
operator
()(
const
Tuple
<
Ts
...
>&
idx
)
const
{
return
this
->
operator
[](
idx
);
}
template
<
typename
...
Idxs
,
enable_if_t
<
IsSlicing
(
Tuple
<
Idxs
...>{}),
bool
>
=
false
>
__host__
__device__
auto
operator
()(
Idxs
...
idxs
)
const
{
return
this
->
operator
[](
make_tuple
(
idxs
...));
}
// Getter for the const value
template
<
typename
...
Ts
,
enable_if_t
<!
IsSlicing
(
Tuple
<
Ts
...>{}),
bool
>
=
false
>
__host__
__device__
const
ElementType
&
operator
[](
const
Tuple
<
Ts
...
>&
idx
)
const
{
if
constexpr
(
IsDynamicBuffer
)
{
const
index_t
offset
=
layout_
(
idx
);
return
buffer_
[
offset
];
}
else
{
if
constexpr
(
is_same_v
<
Strides
,
Tuple
<>>
)
{
constexpr
index_t
offset
=
Layout
<
Shape
,
Strides
>
{
Shape
{}}.
template
operator
()
<
Tuple
<
Ts
...>
>
();
return
buffer_
[
Number
<
offset
>
{}];
}
else
{
constexpr
index_t
offset
=
Layout
<
Shape
,
Strides
>
{
Shape
{},
Strides
{}}.
template
operator
()
<
Tuple
<
Ts
...>
>
();
return
buffer_
[
Number
<
offset
>
{}];
}
}
}
template
<
typename
...
Ts
,
enable_if_t
<!
IsSlicing
(
Tuple
<
Ts
...>{}),
bool
>
=
false
>
__host__
__device__
const
ElementType
&
operator
()(
const
Tuple
<
Ts
...
>&
idx
)
const
{
return
this
->
operator
[](
idx
);
}
template
<
typename
...
Idxs
,
enable_if_t
<!
IsSlicing
(
Tuple
<
Idxs
...>{}),
bool
>
=
false
>
__host__
__device__
const
ElementType
&
operator
()(
Idxs
...
idxs
)
const
{
return
this
->
operator
[](
make_tuple
(
idxs
...));
}
// Getter for the value reference
template
<
typename
...
Ts
,
enable_if_t
<!
IsSlicing
(
Tuple
<
Ts
...>{}),
bool
>
=
false
>
__host__
__device__
ElementType
&
operator
[](
const
Tuple
<
Ts
...
>&
idx
)
{
if
constexpr
(
IsDynamicBuffer
)
{
const
index_t
offset
=
layout_
(
idx
);
return
buffer_
(
offset
);
}
else
{
if
constexpr
(
is_same_v
<
Strides
,
Tuple
<>>
)
{
constexpr
index_t
offset
=
Layout
<
Shape
,
Strides
>
{
Shape
{}}.
template
operator
()
<
Tuple
<
Ts
...>
>
();
return
buffer_
(
Number
<
offset
>
{});
}
else
{
constexpr
index_t
offset
=
Layout
<
Shape
,
Strides
>
{
Shape
{},
Strides
{}}.
template
operator
()
<
Tuple
<
Ts
...>
>
();
return
buffer_
(
Number
<
offset
>
{});
}
}
}
template
<
typename
...
Ts
,
enable_if_t
<!
IsSlicing
(
Tuple
<
Ts
...>{}),
bool
>
=
false
>
__host__
__device__
ElementType
&
operator
()(
const
Tuple
<
Ts
...
>&
idx
)
{
return
this
->
operator
[](
idx
);
}
template
<
typename
...
Idxs
,
enable_if_t
<!
IsSlicing
(
Tuple
<
Idxs
...>{}),
bool
>
=
false
>
__host__
__device__
ElementType
&
operator
()(
Idxs
...
idxs
)
{
return
this
->
operator
[](
make_tuple
(
idxs
...));
}
__host__
__device__
constexpr
auto
GetDefaultDescriptor
()
{
return
layout_
.
GetDefaultDescriptor
();
}
private:
using
DynamicBufferType
=
DynamicBuffer
<
BufferAddressSpace
,
ElementType
,
ElementSpaceSize
,
true
/*InvalidElementUseNumericalZeroValue*/
>
;
using
StaticBufferType
=
StaticBufferTupleOfVector
<
BufferAddressSpace
,
ElementType
,
NumVectors
,
ScalarPerVector
,
true
/*InvalidElementUseNumericalZeroValue*/
>
;
// If register use static buffer, else use dynamic buffer
using
Buffer
=
std
::
conditional_t
<
IsDynamicBuffer
,
DynamicBufferType
,
StaticBufferType
>
;
const
Layout
<
Shape
,
Strides
>
layout_
;
Buffer
buffer_
;
};
}
// namespace wrapper
}
// namespace ck
include/ck/wrapper/utils/layout_utils.hpp
0 → 100644
View file @
32806d5f
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/ck.hpp"
#include "ck/utility/number.hpp"
#include "ck/utility/tuple.hpp"
#include "ck/utility/tuple_helper.hpp"
#include "ck/utility/sequence.hpp"
#include "ck/utility/sequence_helper.hpp"
#include "ck/utility/is_detected.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
namespace
ck
{
namespace
wrapper
{
// Disable from doxygen docs generation
/// @cond
// forward declaration
template
<
typename
Shape
,
typename
Strides
>
struct
Layout
;
template
<
typename
T
>
using
is_tuple
=
decltype
(
std
::
declval
<
T
&>
().
IsTuple
());
/// @endcond
// make_*
/**
* \brief Make layout function.
*
* \tparam Shape Shape for layout.
* \tparam Strides Strides for layout.
* \return Constructed layout.
*/
template
<
typename
Shape
,
typename
Strides
>
__host__
__device__
constexpr
Layout
<
Shape
,
Strides
>
make_layout
(
const
Shape
&
shape
,
const
Strides
&
strides
)
{
return
Layout
<
Shape
,
Strides
>
(
shape
,
strides
);
}
/**
* \brief Make layout function with packed strides
* (column-major).
*
* \tparam Shape Shape for layout.
* \return Constructed layout.
*/
template
<
typename
Shape
>
__host__
__device__
constexpr
Layout
<
Shape
,
Tuple
<>>
make_layout
(
const
Shape
&
shape
)
{
return
Layout
<
Shape
,
Tuple
<>>
(
shape
);
}
// Layout helpers
// get
// Get dim (could be returned from get with empty Idxs)
/**
* \private
*/
template
<
typename
T
>
__host__
__device__
T
constexpr
get
(
const
T
&
dim
)
{
return
dim
;
}
/**
* \brief Get element from tuple (Shape/Strides/Idxs).
*
* \tparam idx Index to lookup.
* \param tuple Tuple to lookup.
* \return Requsted element.
*/
template
<
index_t
idx
,
typename
...
Dims
>
__host__
__device__
constexpr
auto
get
(
const
Tuple
<
Dims
...
>&
tuple
)
{
return
tuple
.
At
(
Number
<
idx
>
{});
}
/**
* \brief Get sub layout.
*
* \tparam idx Index to lookup.
* \param layout Layout to create sub layout.
* \return Requsted sub layout.
*/
template
<
index_t
idx
,
typename
Shape
,
typename
Strides
>
__host__
__device__
constexpr
auto
get
(
const
Layout
<
Shape
,
Strides
>&
layout
)
{
const
auto
&
shape
=
layout
.
GetShape
();
const
auto
&
new_shape
=
get
<
idx
>
(
shape
);
static_assert
(
is_detected
<
is_tuple
,
decltype
(
new_shape
)
>::
value
,
"Shape of sub layout must be tuple"
);
if
constexpr
(
is_same_v
<
Strides
,
Tuple
<>>
)
{
// If stride not passed, create without strides
return
make_layout
(
new_shape
);
}
else
{
const
auto
&
strides
=
layout
.
GetStrides
();
const
auto
&
new_strides
=
get
<
idx
>
(
strides
);
static_assert
(
is_detected
<
is_tuple
,
decltype
(
new_strides
)
>::
value
,
"Strides of sub layout must be tuple"
);
return
make_layout
(
new_shape
,
new_strides
);
}
}
/**
* \brief Hierarchical get.
*
* \tparam Idxs Indexes to lookup.
* \param elem Element to lookup.
* \return Requsted element.
*/
template
<
index_t
Idx
,
index_t
...
Idxs
,
typename
T
>
__host__
__device__
constexpr
auto
get
(
const
T
&
elem
)
{
return
get
<
Idxs
...
>
(
get
<
Idx
>
(
elem
));
}
// size
// Get dim size (could be returned from get function)
/**
* \private
*/
template
<
typename
T
>
__host__
__device__
T
constexpr
size
(
const
T
&
dim
)
{
return
dim
;
}
/**
* \brief Length get (product if tuple).
*
* \tparam idx Index to lookup.
* \param layout Layout to get Shape of.
* \return Requsted length.
*/
template
<
index_t
idx
,
typename
Shape
,
typename
Strides
>
__host__
__device__
constexpr
index_t
size
(
const
Layout
<
Shape
,
Strides
>&
layout
)
{
return
layout
.
template
GetLength
<
idx
>();
}
/**
* \brief Shape size (product of dims).
*
* \param shape Shape to lookup.
* \return Requsted size.
*/
template
<
typename
...
ShapeDims
>
__host__
__device__
constexpr
index_t
size
(
const
Tuple
<
ShapeDims
...
>&
shape
)
{
const
auto
unrolled_shape
=
UnrollNestedTuple
(
shape
);
return
TupleReduce
<
0
,
unrolled_shape
.
Size
()
>
([](
auto
x
,
auto
y
)
{
return
x
*
y
;
},
unrolled_shape
);
}
/**
* \brief Layout size (product of dims).
*
* \param layout Layout to calculate shape size.
* \return Requsted size.
*/
template
<
typename
Shape
,
typename
Strides
>
__host__
__device__
constexpr
index_t
size
(
const
Layout
<
Shape
,
Strides
>&
layout
)
{
return
layout
.
GetLengths
();
}
/**
* \brief Length get from tuple (product if tuple).
*
* \tparam idx Index to lookup.
* \param tuple Tuple to lookup.
* \return Requsted length.
*/
template
<
index_t
idx
,
typename
...
Ts
>
__host__
__device__
constexpr
index_t
size
(
const
Tuple
<
Ts
...
>&
tuple
)
{
return
size
(
tuple
.
At
(
Number
<
idx
>
{}));
}
/**
* \brief Hierarchical size.
*
* \tparam Idx First index to lookup (to avoid empty Idxs).
* \tparam Idxs Next indexes to lookup.
* \param elem Element to lookup.
* \return Requsted element.
*/
template
<
index_t
Idx
,
index_t
...
Idxs
,
typename
T
>
__host__
__device__
constexpr
auto
size
(
const
T
&
elem
)
{
return
size
(
get
<
Idx
,
Idxs
...
>
(
elem
));
}
// rank
/**
* \brief Get layout rank (num elements in shape).
*
* \param layout Layout to calculate rank.
* \return Requsted rank.
*/
template
<
typename
Shape
,
typename
Strides
>
__host__
__device__
constexpr
auto
rank
([[
maybe_unused
]]
const
Layout
<
Shape
,
Strides
>&
layout
)
{
return
Shape
::
Size
();
}
/**
* \brief Get tuple rank (num elements in tuple).
* Return 1 if scalar passed.
*
* \param tuple Tuple to calculate rank.
* \return Requsted rank.
*/
template
<
typename
...
Dims
>
__host__
__device__
constexpr
auto
rank
([[
maybe_unused
]]
const
Tuple
<
Dims
...
>&
tuple
)
{
return
Tuple
<
Dims
...
>::
Size
();
}
/**
* \private
*/
template
<
index_t
IDim
>
__host__
__device__
constexpr
index_t
rank
(
const
Number
<
IDim
>&
)
{
return
1
;
}
/**
* \private
*/
__host__
__device__
constexpr
index_t
rank
(
const
index_t
&
)
{
return
1
;
}
/**
* \brief Hierarchical rank.
*
* \tparam Idxs Indexes to lookup.
* \param elem Element to lookup.
* \return Requsted rank.
*/
template
<
index_t
...
Idxs
,
typename
T
>
__host__
__device__
constexpr
auto
rank
(
const
T
&
elem
)
{
return
rank
(
get
<
Idxs
...
>
(
elem
));
}
// depth
/**
* \brief Get depth of the layout shape (return 0 if scalar).
*
* \param layout Layout to calculate depth.
* \return Requsted depth.
*/
template
<
typename
Shape
,
typename
Strides
>
__host__
__device__
constexpr
auto
depth
(
const
Layout
<
Shape
,
Strides
>&
layout
)
{
const
auto
&
shape
=
layout
.
GetShape
();
return
TupleDepth
(
shape
);
}
/**
* \brief Get depth of the tuple. (return 0 if scalar)
*
* \param tuple Tuple to calculate depth.
* \return Requsted depth.
*/
template
<
typename
...
Dims
>
__host__
__device__
constexpr
auto
depth
(
const
Tuple
<
Dims
...
>&
tuple
)
{
return
TupleDepth
(
tuple
);
}
/**
* \private
*/
template
<
index_t
IDim
>
__host__
__device__
constexpr
index_t
depth
(
const
Number
<
IDim
>&
)
{
return
0
;
}
/**
* \private
*/
__host__
__device__
constexpr
index_t
depth
(
const
index_t
&
)
{
return
0
;
}
/**
* \brief Hierarchical depth.
*
* \tparam Idxs Indexes to lookup.
* \param elem Element to lookup.
* \return Requsted depth.
*/
template
<
index_t
...
Idxs
,
typename
T
>
__host__
__device__
constexpr
auto
depth
(
const
T
&
elem
)
{
return
depth
(
get
<
Idxs
...
>
(
elem
));
}
/**
* \brief Get Layout strides.
*
* \param layout Layout to get strides from.
* \return Requsted strides.
*/
template
<
typename
Shape
,
typename
Strides
>
__host__
__device__
constexpr
const
auto
&
stride
(
const
Layout
<
Shape
,
Strides
>&
layout
)
{
return
layout
.
GetStrides
();
}
/**
* \brief Get Layout shape.
*
* \param layout Layout to get shape from.
* \return Requsted shape.
*/
template
<
typename
Shape
,
typename
Strides
>
__host__
__device__
constexpr
const
auto
&
shape
(
const
Layout
<
Shape
,
Strides
>&
layout
)
{
return
layout
.
GetShape
();
}
}
// namespace wrapper
}
// namespace ck
include/ck/wrapper/utils/tensor_utils.hpp
0 → 100644
View file @
32806d5f
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/ck.hpp"
#include "ck/utility/number.hpp"
#include "ck/utility/tuple.hpp"
#include "ck/utility/tuple_helper.hpp"
#include "ck/utility/dynamic_buffer.hpp"
#include "ck/utility/amd_address_space.hpp"
namespace
ck
{
namespace
wrapper
{
/**
* \brief Memory type, allowed members:
* - Generic,
* - Global,
* - LDS,
* - SGPR,
* - VGPR,
*/
using
MemoryTypeEnum
=
AddressSpaceEnum
;
// Disable from doxygen docs generation
/// @cond
// forward declarations
template
<
typename
Shape
,
typename
Strides
>
struct
Layout
;
template
<
MemoryTypeEnum
BufferAddressSpace
,
typename
ElementType
,
typename
Shape
,
typename
Strides
,
index_t
NumVectors
,
// params for Register memory
index_t
ScalarPerVector
// param for Register memory
>
struct
Tensor
;
template
<
typename
FromType
,
typename
ToType
>
struct
Slice
{
__host__
__device__
constexpr
Slice
()
:
from_
(),
to_
()
{}
__host__
__device__
constexpr
Slice
(
FromType
from
,
ToType
to
)
:
from_
(
from
),
to_
(
to
)
{}
template
<
typename
T
>
__host__
__device__
constexpr
auto
range
(
const
T
&
dim
)
const
{
if
constexpr
(
is_same_v
<
FromType
,
index_t
>
||
is_same_v
<
ToType
,
index_t
>
||
is_same_v
<
T
,
index_t
>
)
{
assert
(
dim
>=
to_
&&
from_
>=
0
&&
(
to_
<
0
||
to_
>
from_
)
&&
"Invalid range"
);
if
(
to_
<
0
)
{
return
dim
-
from_
+
to_
+
1
;
}
else
{
// workaround if one end of the interval is index_t and the second one is Number
return
static_cast
<
index_t
>
(
to_
)
-
static_cast
<
index_t
>
(
from_
);
}
}
else
{
static_assert
(
dim
>=
to_
&&
from_
>=
Number
<
0
>
{}
&&
(
to_
<
0
||
to_
>
from_
),
"Invalid range"
);
if
constexpr
(
to_
<
0
)
{
return
dim
-
from_
+
to_
+
Number
<
1
>
{};
}
else
{
return
to_
-
from_
;
}
}
}
__host__
__device__
static
constexpr
bool
IsSlice
()
{
return
true
;
}
const
FromType
from_
;
const
ToType
to_
;
};
template
<
typename
T
>
using
is_slice
=
decltype
(
std
::
declval
<
T
&>
().
IsSlice
());
template
<
typename
T
>
using
is_tuple
=
decltype
(
std
::
declval
<
T
&>
().
IsTuple
());
/// @endcond
/**
* \brief Make tensor function.
*
* \tparam MemoryType Type of memory.
* \param pointer Pointer to the memory.
* \param layout Tensor layout.
* \return Constructed tensor.
*/
template
<
MemoryTypeEnum
MemoryType
,
typename
ElementType
,
typename
Shape
,
typename
Strides
>
constexpr
auto
make_tensor
(
ElementType
*
pointer
,
const
Layout
<
Shape
,
Strides
>&
layout
)
{
return
Tensor
<
MemoryType
,
ElementType
,
Shape
,
Strides
,
0
/*NumVectors*/
,
0
/*ScalarPerVector*/
>
(
pointer
,
layout
);
}
/**
* \brief Make SGPR or VGPR tensor function.
*
* \tparam MemoryType Type of memory.
* \tparam NumVectors Number of vectors.
* \tparam ScalarPerVector Scalars per vector.
* \tparam ElementType Memory data type.
* \param layout Tensor layout.
* \return Constructed tensor.
*/
template
<
MemoryTypeEnum
MemoryType
,
index_t
NumVectors
,
index_t
ScalarPerVector
,
typename
ElementType
,
typename
Shape
,
typename
Strides
>
constexpr
auto
make_register_tensor
(
const
Layout
<
Shape
,
Strides
>&
layout
)
{
static_assert
(
!
IsNestedTuple
(
Shape
{}),
"Register tensor with nested layout is not supported"
);
return
Tensor
<
MemoryType
,
ElementType
,
Shape
,
Strides
,
NumVectors
,
ScalarPerVector
>
(
layout
);
}
/**
* \brief Get Tensor Layout.
*
* \param tensor Tensor to get layout of.
* \return Requsted layout.
*/
template
<
MemoryTypeEnum
BufferAddressSpace
,
typename
ElementType
,
typename
Shape
,
typename
Strides
,
index_t
NumVectors
,
index_t
ScalarPerVector
>
__host__
__device__
constexpr
const
auto
&
layout
(
const
Tensor
<
BufferAddressSpace
,
ElementType
,
Shape
,
Strides
,
NumVectors
,
ScalarPerVector
>&
tensor
)
{
return
tensor
.
GetLayout
();
}
/**
* \brief Product of tensor shape dims.
*
* \tparam Idxs Indexes to access specific shape dim (optional).
* \param tensor Tensor to get Shape of.
* \return Requsted size.
*/
template
<
index_t
...
Idxs
,
MemoryTypeEnum
BufferAddressSpace
,
typename
ElementType
,
typename
Shape
,
typename
Strides
,
index_t
NumVectors
,
index_t
ScalarPerVector
>
__host__
__device__
constexpr
index_t
size
(
const
Tensor
<
BufferAddressSpace
,
ElementType
,
Shape
,
Strides
,
NumVectors
,
ScalarPerVector
>&
tensor
)
{
return
size
<
Idxs
...
>
(
tensor
.
GetLayout
());
}
/**
* \brief Rank of Shape tuple.
*
* \tparam Idxs Indexes to access specific shape dim (optional).
* \param tensor Tensor to get rank of.
* \return Requsted rank.
*/
template
<
index_t
...
Idxs
,
MemoryTypeEnum
BufferAddressSpace
,
typename
ElementType
,
typename
Shape
,
typename
Strides
,
index_t
NumVectors
,
index_t
ScalarPerVector
>
__host__
__device__
constexpr
index_t
rank
(
const
Tensor
<
BufferAddressSpace
,
ElementType
,
Shape
,
Strides
,
NumVectors
,
ScalarPerVector
>&
tensor
)
{
return
rank
<
Idxs
...
>
(
tensor
.
GetLayout
());
}
/**
* \brief Depth of Shape tuple.
*
* \tparam Idxs Indexes to access specific shape dim (optional).
* \param tensor Tensor to get depth of.
* \return Requsted depth.
*/
template
<
index_t
...
Idxs
,
MemoryTypeEnum
BufferAddressSpace
,
typename
ElementType
,
typename
Shape
,
typename
Strides
,
index_t
NumVectors
,
index_t
ScalarPerVector
>
__host__
__device__
constexpr
index_t
depth
(
const
Tensor
<
BufferAddressSpace
,
ElementType
,
Shape
,
Strides
,
NumVectors
,
ScalarPerVector
>&
tensor
)
{
return
depth
<
Idxs
...
>
(
tensor
.
GetLayout
());
}
/**
* \brief Get Tensor strides.
*
* \param tensor Tensor to get strides from.
* \return Requsted strides.
*/
template
<
MemoryTypeEnum
BufferAddressSpace
,
typename
ElementType
,
typename
Shape
,
typename
Strides
,
index_t
NumVectors
,
index_t
ScalarPerVector
>
__host__
__device__
constexpr
const
auto
&
stride
(
const
Tensor
<
BufferAddressSpace
,
ElementType
,
Shape
,
Strides
,
NumVectors
,
ScalarPerVector
>&
tensor
)
{
return
stride
(
tensor
.
GetLayout
());
}
/**
* \brief Get Tensor shape.
*
* \param tensor Tensor to get shape from.
* \return Requsted shape.
*/
template
<
MemoryTypeEnum
BufferAddressSpace
,
typename
ElementType
,
typename
Shape
,
typename
Strides
,
index_t
NumVectors
,
index_t
ScalarPerVector
>
__host__
__device__
constexpr
const
auto
&
shape
(
const
Tensor
<
BufferAddressSpace
,
ElementType
,
Shape
,
Strides
,
NumVectors
,
ScalarPerVector
>&
tensor
)
{
return
shape
(
tensor
.
GetLayout
());
}
/**
* \brief Get dim slice.
*
* \param from Beginning of the interval.
* \param to End of the interval. (could be also negative to index from the end)
* \return Requested slice. Could be used to create sliced tensor from other tensor.
*/
template
<
typename
FromType
,
typename
ToType
>
constexpr
auto
slice
(
const
FromType
from
,
const
ToType
to
)
{
return
Slice
<
FromType
,
ToType
>
(
from
,
to
);
}
/**
* \brief Get dim slice. (Assumed that from is equal to 1)
*
* \param to End of the interval. (could be also negative to index from the end)
* \return Requested slice. Could be used to create sliced tensor from other tensor.
*/
template
<
typename
ToType
>
constexpr
auto
slice
(
const
ToType
to
)
{
if
constexpr
(
is_same_v
<
ToType
,
index_t
>
)
{
return
Slice
<
index_t
,
ToType
>
(
0
,
to
);
}
else
{
return
Slice
<
Number
<
0
>
,
ToType
>
(
Number
<
0
>
{},
to
);
}
}
/**
* \brief Get whole dim slice (from = 0, to = -1).
*
* \return Requested slice. Could be used to create sliced tensor from other tensor.
*/
constexpr
auto
slice
()
{
return
Slice
<
Number
<
0
>
,
Number
<-
1
>>
(
Number
<
0
>
{},
Number
<-
1
>
{});
}
}
// namespace wrapper
}
// namespace ck
library/include/ck/library/reference_tensor_operation/cpu/reference_groupnorm_bwd.hpp
View file @
32806d5f
...
@@ -16,6 +16,31 @@ namespace ck {
...
@@ -16,6 +16,31 @@ namespace ck {
namespace
tensor_operation
{
namespace
tensor_operation
{
namespace
host
{
namespace
host
{
// def normalization_backward_x(dy, x, gamma, x_mean, rstd, reduce_axis, reduce_size):
// ds = np.sum(dy * gamma * x, axis=reduce_axis, keepdims=True)
// db = np.sum(dy * gamma, axis=reduce_axis, keepdims=True)
// b = (db * x_mean - ds) * rstd ** (3) / reduce_size
// c = -b * x_mean - db * rstd / reduce_size
// dx = rstd * dy * gamma + b * x + c
// return dx
// def normalization_backward_gamma_beta(dy, x, x_mean, rstd, reduce_axis):
// # Assume shape of gamma and beta are the same
// dgamma = np.sum(dy * (x - x_mean) * rstd, axis=reduce_axis, keepdims=True)
// dbeta = np.sum(dy, axis=reduce_axis, keepdims=True)
// return dgamma, dbeta
// def groupnorm_backward(dy, x, gamma, x_mean, rstd):
// # dy, x = [N, H, W, G, C], gamma = [1, 1, 1, G, C], x_mean, rstd = [N, 1, 1, G, 1]
// N, H, W, G, C = x.shape
// dx = normalization_input_backward(
// dy, x, gamma, x_mean, rstd, (1, 2, 4), H * W * C)
// dgamma, dbeta = normalization_gamma_beta_backward(
// dy, x, x_mean, rstd, (0, 1, 2))
// return dx, dgamma, dbeta
// Reference (Layernorm and groupnorm):
// https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/cpu/group_norm_kernel.cpp#L655
template
<
typename
DYDataType
,
template
<
typename
DYDataType
,
typename
XDataType
,
typename
XDataType
,
typename
GammaDataType
,
typename
GammaDataType
,
...
...
library/include/ck/library/reference_tensor_operation/cpu/reference_layernorm_bwd.hpp
View file @
32806d5f
...
@@ -16,6 +16,30 @@ namespace ck {
...
@@ -16,6 +16,30 @@ namespace ck {
namespace
tensor_operation
{
namespace
tensor_operation
{
namespace
host
{
namespace
host
{
// def normalization_backward_x(dy, x, gamma, x_mean, rstd, reduce_axis, reduce_size):
// ds = np.sum(dy * gamma * x, axis=reduce_axis, keepdims=True)
// db = np.sum(dy * gamma, axis=reduce_axis, keepdims=True)
// b = (db * x_mean - ds) * rstd ** (3) / reduce_size
// c = -b * x_mean - db * rstd / reduce_size
// dx = rstd * dy * gamma + b * x + c
// return dx
// def normalization_beta_backward_gamma_beta(dy, x, x_mean, rstd, reduce_axis):
// # Assume shape of gamma and beta are the same
// dgamma = np.sum(dy * (x - x_mean) * rstd, axis=reduce_axis, keepdims=True)
// dbeta = np.sum(dy, axis=reduce_axis, keepdims=True)
// return dgamma, dbeta
// def layernorm_backward(dy, x, gamma, x_mean, rstd):
// # dy, x = [M, K], gamma = [1, K], x_mean, rstd = [M, 1]
// # dx = [M, K], dgamma, dbeta = [1, K]
// M, K = x.shape
// dx = normalization_input_backward(dy, x, gamma, x_mean, rstd, 1, K)
// dgamma, dbeta = normalization_gamma_beta_backward(dy, x, x_mean, rstd, 0)
// return dx, dgamma, dbeta
// Reference (Layernorm and groupnorm):
// https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/cpu/layer_norm_kernel.cpp#L196
template
<
typename
DYDataType
,
template
<
typename
DYDataType
,
typename
XDataType
,
typename
XDataType
,
typename
GammaDataType
,
typename
GammaDataType
,
...
...
library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp
View file @
32806d5f
...
@@ -86,9 +86,9 @@ using NHWGK = ck::tensor_layout::convolution::NHWGK;
...
@@ -86,9 +86,9 @@ using NHWGK = ck::tensor_layout::convolution::NHWGK;
using
NDHWGK
=
ck
::
tensor_layout
::
convolution
::
NDHWGK
;
using
NDHWGK
=
ck
::
tensor_layout
::
convolution
::
NDHWGK
;
//
//
using
GK
=
ck
::
tensor_layout
::
convolution
::
G_K
;
using
G
_
K
=
ck
::
tensor_layout
::
convolution
::
G_K
;
using
GK_Tuple
=
ck
::
Tuple
<
GK
>
;
using
GK_Tuple
=
ck
::
Tuple
<
G
_
K
>
;
using
GK_GK_Tuple
=
ck
::
Tuple
<
GK
,
GK
>
;
using
GK_GK_Tuple
=
ck
::
Tuple
<
G
_
K
,
G
_
K
>
;
// pointwise functor
// pointwise functor
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
...
...
library/include/ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp
View file @
32806d5f
...
@@ -61,7 +61,11 @@ using device_contraction_kk_instance = std::tuple<
...
@@ -61,7 +61,11 @@ using device_contraction_kk_instance = std::tuple<
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
128
,
128
,
32
,
16
,
4
,
4
,
32
,
32
,
2
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
4
,
ComputeDataType
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
128
,
128
,
32
,
16
,
4
,
4
,
32
,
32
,
2
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
4
,
ComputeDataType
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
128
,
32
,
128
,
16
,
4
,
4
,
32
,
32
,
1
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
1
,
1
,
S
<
1
,
8
,
1
,
16
>
,
4
,
ComputeDataType
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
128
,
32
,
128
,
16
,
4
,
4
,
32
,
32
,
1
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
1
,
1
,
S
<
1
,
8
,
1
,
16
>
,
4
,
ComputeDataType
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
64
,
64
,
32
,
16
,
4
,
4
,
32
,
32
,
2
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
4
,
ComputeDataType
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
64
,
64
,
32
,
16
,
4
,
4
,
32
,
32
,
2
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
4
,
ComputeDataType
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
64
,
32
,
64
,
16
,
4
,
4
,
32
,
32
,
1
,
2
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
4
,
ComputeDataType
>
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
64
,
32
,
64
,
16
,
4
,
4
,
32
,
32
,
1
,
2
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
4
,
ComputeDataType
>
,
// Small scalar per vector
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
256
,
128
,
128
,
16
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
4
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
4
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
1
,
ComputeDataType
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
128
,
128
,
32
,
16
,
4
,
4
,
32
,
32
,
2
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
4
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
4
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
2
,
ComputeDataType
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
64
,
64
,
32
,
16
,
4
,
4
,
32
,
32
,
2
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
4
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
4
,
1
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
1
,
ComputeDataType
>
// clang-format on
// clang-format on
>
;
>
;
...
@@ -96,7 +100,11 @@ using device_contraction_kn_instance = std::tuple<
...
@@ -96,7 +100,11 @@ using device_contraction_kn_instance = std::tuple<
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
256
,
128
,
64
,
16
,
4
,
1
,
32
,
32
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
1
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
,
ComputeDataType
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
256
,
128
,
64
,
16
,
4
,
1
,
32
,
32
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
1
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
,
ComputeDataType
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
256
,
128
,
64
,
16
,
4
,
4
,
32
,
32
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
4
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
,
ComputeDataType
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
256
,
128
,
64
,
16
,
4
,
4
,
32
,
32
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
4
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
,
ComputeDataType
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
256
,
64
,
128
,
16
,
4
,
1
,
32
,
32
,
1
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
1
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
,
ComputeDataType
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
256
,
64
,
128
,
16
,
4
,
1
,
32
,
32
,
1
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
1
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
,
ComputeDataType
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
256
,
64
,
128
,
16
,
4
,
4
,
32
,
32
,
1
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
4
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
,
ComputeDataType
>
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
256
,
64
,
128
,
16
,
4
,
4
,
32
,
32
,
1
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
4
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
,
ComputeDataType
>
,
// Small scalar per vector
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
256
,
128
,
128
,
16
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
4
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
4
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
1
,
ComputeDataType
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
128
,
128
,
32
,
16
,
4
,
4
,
32
,
32
,
2
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
4
,
1
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
4
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
2
,
ComputeDataType
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
64
,
64
,
32
,
16
,
4
,
4
,
32
,
32
,
2
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
4
,
1
,
S
<
4
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
4
,
1
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
1
,
ComputeDataType
>
// clang-format on
// clang-format on
>
;
>
;
...
@@ -131,7 +139,11 @@ using device_contraction_mk_instance = std::tuple<
...
@@ -131,7 +139,11 @@ using device_contraction_mk_instance = std::tuple<
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
256
,
128
,
64
,
16
,
1
,
4
,
32
,
32
,
2
,
1
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
1
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
,
ComputeDataType
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
256
,
128
,
64
,
16
,
1
,
4
,
32
,
32
,
2
,
1
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
1
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
,
ComputeDataType
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
256
,
128
,
64
,
16
,
4
,
4
,
32
,
32
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
4
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
,
ComputeDataType
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
256
,
128
,
64
,
16
,
4
,
4
,
32
,
32
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
4
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
,
ComputeDataType
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
256
,
64
,
128
,
16
,
1
,
4
,
32
,
32
,
1
,
2
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
1
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
,
ComputeDataType
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
256
,
64
,
128
,
16
,
1
,
4
,
32
,
32
,
1
,
2
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
1
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
,
ComputeDataType
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
256
,
64
,
128
,
16
,
4
,
4
,
32
,
32
,
1
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
4
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
,
ComputeDataType
>
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
256
,
64
,
128
,
16
,
4
,
4
,
32
,
32
,
1
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
4
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
,
ComputeDataType
>
,
// Small scalar per vector
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
256
,
128
,
128
,
16
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
4
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
4
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
1
,
ComputeDataType
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
128
,
128
,
32
,
16
,
4
,
4
,
32
,
32
,
2
,
1
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
4
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
4
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
2
,
ComputeDataType
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
64
,
64
,
32
,
16
,
4
,
4
,
32
,
32
,
2
,
1
,
S
<
4
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
4
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
4
,
1
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
1
,
ComputeDataType
>
// clang-format on
// clang-format on
>
;
>
;
...
@@ -166,7 +178,11 @@ using device_contraction_mn_instance = std::tuple<
...
@@ -166,7 +178,11 @@ using device_contraction_mn_instance = std::tuple<
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
256
,
128
,
64
,
16
,
1
,
1
,
32
,
32
,
2
,
1
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
1
,
0
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
1
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
,
ComputeDataType
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
256
,
128
,
64
,
16
,
1
,
1
,
32
,
32
,
2
,
1
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
1
,
0
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
1
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
,
ComputeDataType
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
256
,
128
,
64
,
16
,
4
,
4
,
32
,
32
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
4
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
4
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
,
ComputeDataType
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
256
,
128
,
64
,
16
,
4
,
4
,
32
,
32
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
4
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
4
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
,
ComputeDataType
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
256
,
64
,
128
,
16
,
1
,
1
,
32
,
32
,
1
,
2
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
1
,
0
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
1
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
,
ComputeDataType
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
256
,
64
,
128
,
16
,
1
,
1
,
32
,
32
,
1
,
2
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
1
,
0
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
1
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
,
ComputeDataType
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
256
,
64
,
128
,
16
,
4
,
4
,
32
,
32
,
1
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
4
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
4
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
,
ComputeDataType
>
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
256
,
64
,
128
,
16
,
4
,
4
,
32
,
32
,
1
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
4
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
4
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
,
ComputeDataType
>
,
// Small scalar per vector
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
256
,
128
,
128
,
16
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
4
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
4
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
1
,
ComputeDataType
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
128
,
128
,
32
,
16
,
4
,
4
,
32
,
32
,
2
,
1
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
4
,
1
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
4
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
2
,
ComputeDataType
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
64
,
64
,
32
,
16
,
4
,
4
,
32
,
32
,
2
,
1
,
S
<
4
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
4
,
1
,
S
<
4
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
4
,
1
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
1
,
ComputeDataType
>
// clang-format on
// clang-format on
>
;
>
;
...
...
library/
src
/tensor_operation_instance/gpu/
gemm/
device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_instance.
c
pp
→
library/
include/ck/library
/tensor_operation_instance/gpu/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_instance.
h
pp
View file @
32806d5f
...
@@ -25,10 +25,6 @@ using S = ck::Sequence<Is...>;
...
@@ -25,10 +25,6 @@ using S = ck::Sequence<Is...>;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
static
constexpr
auto
MNKPadding
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
;
// Compilation parameters for a[m, k] * b[k, n] = c[m, n]
// Compilation parameters for a[m, k] * b[k, n] = c[m, n]
template
<
ck
::
tensor_operation
::
device
::
GemmSpecialization
GemmSpec
>
template
<
ck
::
tensor_operation
::
device
::
GemmSpecialization
GemmSpec
>
using
device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_instances
=
std
::
tuple
<
using
device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_instances
=
std
::
tuple
<
...
@@ -37,7 +33,7 @@ using device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_instances = std::tuple<
...
@@ -37,7 +33,7 @@ using device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_instances = std::tuple<
//#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | |
//#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | |
//#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | |
//#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | |
//#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
//#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// pipeline v1, 1 wave
// pipeline v1, 1 wave
DeviceGemm_Xdl_CShuffle
<
Row
,
Row
,
Row
,
F8
,
F8
,
F8
,
F32
,
F8
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
1
,
256
,
256
,
128
,
64
,
16
,
4
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
1
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
0
,
1
,
1
,
S
<
1
,
64
,
1
,
4
>
,
16
,
LoopScheduler
::
Default
,
PipelineVersion
::
v1
>
,
DeviceGemm_Xdl_CShuffle
<
Row
,
Row
,
Row
,
F8
,
F8
,
F8
,
F32
,
F8
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
1
,
256
,
256
,
128
,
64
,
16
,
4
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
1
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
0
,
1
,
1
,
S
<
1
,
64
,
1
,
4
>
,
16
,
LoopScheduler
::
Default
,
PipelineVersion
::
v1
>
,
DeviceGemm_Xdl_CShuffle
<
Row
,
Row
,
Row
,
F8
,
F8
,
F8
,
F32
,
F8
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
1
,
256
,
256
,
128
,
64
,
16
,
16
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
16
,
1
,
1
,
1
,
S
<
1
,
64
,
1
,
4
>
,
16
,
LoopScheduler
::
Default
,
PipelineVersion
::
v1
>
,
DeviceGemm_Xdl_CShuffle
<
Row
,
Row
,
Row
,
F8
,
F8
,
F8
,
F32
,
F8
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
1
,
256
,
256
,
128
,
64
,
16
,
16
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
16
,
1
,
1
,
1
,
S
<
1
,
64
,
1
,
4
>
,
16
,
LoopScheduler
::
Default
,
PipelineVersion
::
v1
>
,
DeviceGemm_Xdl_CShuffle
<
Row
,
Row
,
Row
,
F8
,
F8
,
F8
,
F32
,
F8
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
1
,
256
,
128
,
256
,
64
,
16
,
4
,
32
,
32
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
0
,
1
,
1
,
S
<
1
,
64
,
1
,
4
>
,
16
,
LoopScheduler
::
Default
,
PipelineVersion
::
v1
>
,
DeviceGemm_Xdl_CShuffle
<
Row
,
Row
,
Row
,
F8
,
F8
,
F8
,
F32
,
F8
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
1
,
256
,
128
,
256
,
64
,
16
,
4
,
32
,
32
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
0
,
1
,
1
,
S
<
1
,
64
,
1
,
4
>
,
16
,
LoopScheduler
::
Default
,
PipelineVersion
::
v1
>
,
...
@@ -75,7 +71,8 @@ using device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_instances = std::tuple<
...
@@ -75,7 +71,8 @@ using device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_instances = std::tuple<
DeviceGemm_Xdl_CShuffle
<
Row
,
Row
,
Row
,
F8
,
F8
,
F8
,
F32
,
F8
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
1
,
256
,
64
,
128
,
64
,
16
,
16
,
32
,
32
,
1
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
16
,
1
,
1
,
1
,
S
<
1
,
64
,
1
,
4
>
,
16
,
LoopScheduler
::
Interwave
,
PipelineVersion
::
v1
>
DeviceGemm_Xdl_CShuffle
<
Row
,
Row
,
Row
,
F8
,
F8
,
F8
,
F32
,
F8
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
1
,
256
,
64
,
128
,
64
,
16
,
16
,
32
,
32
,
1
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
16
,
1
,
1
,
1
,
S
<
1
,
64
,
1
,
4
>
,
16
,
LoopScheduler
::
Interwave
,
PipelineVersion
::
v1
>
#endif
#endif
#if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES
#if 0
//CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES
// pipeline v2, 1 wave
// pipeline v2, 1 wave
,
,
DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 256, 128, 64, 16, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16, LoopScheduler::Default, PipelineVersion::v2>,
DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 256, 128, 64, 16, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16, LoopScheduler::Default, PipelineVersion::v2>,
...
@@ -98,17 +95,6 @@ using device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_instances = std::tuple<
...
@@ -98,17 +95,6 @@ using device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_instances = std::tuple<
// clang-format on
// clang-format on
>
;
>
;
void
add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm
<
Row
,
Row
,
Row
,
F8
,
F8
,
F8
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_instances
<
GemmDefault
>
{});
add_device_operation_instances
(
instances
,
device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_instances
<
MNKPadding
>
{});
}
}
// namespace instance
}
// namespace instance
}
// namespace device
}
// namespace device
}
// namespace tensor_operation
}
// namespace tensor_operation
...
...
Prev
1
2
3
4
5
6
7
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment