gaoqiong / composable_kernel_ROCM / Commits / 29dcb956

Unverified commit 29dcb956, authored Feb 08, 2024 by Illia Silin, committed by GitHub on Feb 08, 2024.

    Merge pull request #33 from ROCm/lwpck-1292

    Merge from the public repo.

Parents: 29deceb6, cbcc844e
Changes: 393 (showing 20 changed files with 2846 additions and 53 deletions, +2846 -53)
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp                            +1   -1
include/ck/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp                            +4   -3
include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_bwd_data.hpp        +554 -0
include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_bwd_gamma_beta.hpp  +10  -1
include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp                                            +1   -1
include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp                 +7   -8
include/ck/utility/amd_buffer_addressing.hpp                                                  +10  -0
include/ck/utility/amd_lds.hpp                                                                +43  -0
include/ck/utility/amd_wmma.hpp                                                               +13  -10
include/ck/utility/amd_xdlops.hpp                                                             +13  -9
include/ck/utility/data_type.hpp                                                              +65  -0
include/ck/utility/is_known_at_compile_time.hpp                                               +7   -1
include/ck/utility/tuple_helper.hpp                                                           +111 -0
include/ck/utility/type_convert.hpp                                                           +23  -19
include/ck/wrapper/layout.hpp                                                                 +475 -0
include/ck/wrapper/operations/copy.hpp                                                        +238 -0
include/ck/wrapper/operations/gemm.hpp                                                        +337 -0
include/ck/wrapper/tensor.hpp                                                                 +434 -0
include/ck/wrapper/traits/blockwise_gemm_xdl_traits.hpp                                       +69  -0
include/ck/wrapper/utils/layout_utils.hpp                                                     +431 -0
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp

@@ -54,7 +54,7 @@ __global__ void
         const Block2CTileMap block_2_ctile_map)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
+    defined(__gfx94__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
     GridwiseGemm::template Run<HasMainKBlockLoop>(
include/ck/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp

@@ -35,9 +35,8 @@ __global__ void
         const Block2ETileMap block_2_tile_map,
         const ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
-    defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx1030__) || defined(__gfx1100__) || \
-    defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx941__) || defined(__gfx942__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
+    defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__))
     GridwiseTensorRearrangeKernel::Run(in_grid_desc,
                                        p_in_global,
                                        out_grid_desc,

@@ -50,7 +49,9 @@ __global__ void
     ignore = p_in_global;
     ignore = out_grid_desc;
     ignore = p_out_global;
     ignore = batch_count;
     ignore = block_2_tile_map;
     ignore = compute_ptr_offset_of_batch;
 #endif
 }
include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_bwd_data.hpp  (new file, 0 → 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/data_type.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp"

namespace ck {

// Tensor Shape
// dy, x = [M, K], gamma = [1, K], x_mean, inv_std = [M, 1]
//
// Flow:
// def normalization_backward_x(dy, x, gamma, x_mean, inv_std, reduce_axis, reduce_size):
//     ds = np.sum(dy * gamma * x, axis=reduce_axis, keepdims=True)
//     db = np.sum(dy * gamma, axis=reduce_axis, keepdims=True)
//     b = (db * x_mean - ds) * inv_std ** (3) / reduce_size
//     c = -b * x_mean - db * inv_std / reduce_size
//     dx = inv_std * dy * gamma + b * x + c
//     return dx

template <typename DYDataType, typename XDataType, typename GammaDataType,
          typename MeanInvStdDataType, typename ComputeDataType, typename DXDataType,
          typename GridDesc_M_K, index_t BlockSize,
          index_t MThreadClusterSize, index_t KThreadClusterSize,
          index_t MThreadSliceSize, index_t KThreadSliceSize,
          index_t DYSrcVectorDim, index_t DYSrcVectorSize,
          index_t XSrcVectorDim, index_t XSrcVectorSize,
          index_t GammaSrcVectorDim, index_t GammaSrcVectorSize,
          index_t MeanInvStdSrcVectorDim, index_t MeanInvStdSrcVectorSize,
          index_t DXDstVectorDim, index_t DXDstVectorSize,
          bool SweepOnce>
struct GridwiseNormalizationBwdData_mk_to_mk
{
    // if we just check ThreadSliceSize % VectorSize == 0, the performance may be poor (coalesce)
    static_assert(((DYSrcVectorDim == 0 && MThreadSliceSize == DYSrcVectorSize) ||
                   (DYSrcVectorDim == 1 && KThreadSliceSize == DYSrcVectorSize)),
                  "Invalid thread slice sizes and/or dy vector sizes configuration, please check!");

    static_assert(((XSrcVectorDim == 0 && MThreadSliceSize == XSrcVectorSize) ||
                   (XSrcVectorDim == 1 && KThreadSliceSize == XSrcVectorSize)),
                  "Invalid thread slice sizes and/or x vector sizes configuration, please check!");

    static_assert(
        ((GammaSrcVectorDim == 0 && MThreadSliceSize == GammaSrcVectorSize) ||
         (GammaSrcVectorDim == 1 && KThreadSliceSize == GammaSrcVectorSize)),
        "Invalid thread slice sizes and/or gamma vector sizes configuration, please check!");

    static_assert(
        ((MeanInvStdSrcVectorDim == 0 && MThreadSliceSize == MeanInvStdSrcVectorSize) ||
         (MeanInvStdSrcVectorDim == 1 && KThreadSliceSize == MeanInvStdSrcVectorSize)),
        "Invalid thread slice sizes and/or mean/inv_std vector sizes configuration, please check!");

    static_assert(((DXDstVectorDim == 0 && MThreadSliceSize == DXDstVectorSize) ||
                   (DXDstVectorDim == 1 && KThreadSliceSize == DXDstVectorSize)),
                  "Invalid thread slice sizes and/or dx vector sizes configuration, please check!");

    using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;

    using DYThreadBufferDimAccessOrder =
        typename conditional<DYSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type;
    using XThreadBufferDimAccessOrder =
        typename conditional<XSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type;
    using GammaThreadBufferDimAccessOrder =
        typename conditional<GammaSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type;
    using MeanInvStdThreadBufferDimAccessOrder =
        typename conditional<MeanInvStdSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type;
    using DXThreadBufferDimAccessOrder =
        typename conditional<DXDstVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type;

    using ThreadClusterArrangeOrder = DYThreadBufferDimAccessOrder;

    static constexpr auto thread_cluster_desc =
        make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});

    using ThreadBufferLengths_M_K = Sequence<MThreadSliceSize, KThreadSliceSize>;

    static constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
        make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
    static constexpr auto thread_buffer_desc_m =
        make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}));

    using PassThroughOp = tensor_operation::element_wise::PassThrough;

    using BlockwiseSumReduce = PartitionedBlockwiseReduction<ComputeDataType,
                                                             BlockSize,
                                                             ThreadClusterLengths_M_K,
                                                             ThreadClusterArrangeOrder,
                                                             reduce::Add,
                                                             true>;

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};

    static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
    static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;

    __device__ static void Run(const GridDesc_M_K& dy_grid_desc_m_k,
                               const GridDesc_M_K& x_grid_desc_m_k,
                               const GridDesc_M_K& gamma_grid_desc_m_k,
                               const GridDesc_M_K& mean_grid_desc_m_k,
                               const GridDesc_M_K& inv_std_grid_desc_m_k,
                               const GridDesc_M_K& dx_grid_desc_m_k,
                               index_t num_k_block_tile_iteration,
                               const DYDataType* const __restrict__ p_dy_global,
                               const XDataType* const __restrict__ p_x_global,
                               const GammaDataType* const __restrict__ p_gamma_global,
                               const MeanInvStdDataType* const __restrict__ p_mean_global,
                               const MeanInvStdDataType* const __restrict__ p_inv_std_global,
                               DXDataType* const __restrict__ p_dx_global)
    {
        // LDS
        __shared__ ComputeDataType p_reduce_work_buffer[BlockSize];

        auto reduce_work_buf =
            make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_buffer, BlockSize);

        // Global
        const auto dy_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_dy_global, dy_grid_desc_m_k.GetElementSpaceSize());
        const auto x_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_x_global, x_grid_desc_m_k.GetElementSpaceSize());
        auto gamma_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_gamma_global, gamma_grid_desc_m_k.GetElementSpaceSize());
        const auto mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_mean_global, mean_grid_desc_m_k.GetElementSpaceSize());
        const auto inv_std_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_inv_std_global, inv_std_grid_desc_m_k.GetElementSpaceSize());
        auto dx_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_dx_global, dx_grid_desc_m_k.GetElementSpaceSize());

        // VGPR
        auto dy_thread_buf = StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType,
                                          MThreadSliceSize * KThreadSliceSize, true>{};
        auto x_thread_buf = StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType,
                                         MThreadSliceSize * KThreadSliceSize, true>{};
        auto gamma_thread_buf = StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType,
                                             MThreadSliceSize * KThreadSliceSize, true>{};
        auto mean_thread_buf = StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType,
                                            MThreadSliceSize * KThreadSliceSize, true>{};
        auto inv_std_thread_buf = StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType,
                                               MThreadSliceSize * KThreadSliceSize, true>{};
        auto dx_thread_buf = StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType,
                                          MThreadSliceSize * KThreadSliceSize, true>{};
        auto ds_thread_buf =
            StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>{};
        auto db_thread_buf =
            StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>{};

        // thread id
        const index_t thread_local_id = get_thread_local_1d_id();
        const index_t block_global_id = get_block_1d_id();
        const auto thread_cluster_idx =
            thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
        const auto thread_m_cluster_id = thread_cluster_idx[I0];
        const auto thread_k_cluster_id = thread_cluster_idx[I1];

        // IO
        auto threadwise_dy_load = ThreadwiseTensorSliceTransfer_v2<
            DYDataType, ComputeDataType, GridDesc_M_K, decltype(thread_buffer_desc_m_k),
            ThreadBufferLengths_M_K, DYThreadBufferDimAccessOrder, DYSrcVectorDim, DYSrcVectorSize,
            1, false>(
            dy_grid_desc_m_k,
            make_multi_index(block_global_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
                             thread_k_cluster_id * KThreadSliceSize));

        auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2<
            XDataType, ComputeDataType, GridDesc_M_K, decltype(thread_buffer_desc_m_k),
            ThreadBufferLengths_M_K, XThreadBufferDimAccessOrder, XSrcVectorDim, XSrcVectorSize,
            1, false>(
            x_grid_desc_m_k,
            make_multi_index(block_global_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
                             thread_k_cluster_id * KThreadSliceSize));

        auto threadwise_gamma_load = ThreadwiseTensorSliceTransfer_v2<
            GammaDataType, ComputeDataType, GridDesc_M_K, decltype(thread_buffer_desc_m_k),
            ThreadBufferLengths_M_K, XThreadBufferDimAccessOrder, GammaSrcVectorDim,
            GammaSrcVectorSize, 1, false>(
            gamma_grid_desc_m_k,
            make_multi_index(block_global_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
                             thread_k_cluster_id * KThreadSliceSize));

        auto threadwise_mean_load = ThreadwiseTensorSliceTransfer_v2<
            MeanInvStdDataType, ComputeDataType, GridDesc_M_K, decltype(thread_buffer_desc_m_k),
            ThreadBufferLengths_M_K, MeanInvStdThreadBufferDimAccessOrder, MeanInvStdSrcVectorDim,
            MeanInvStdSrcVectorSize, 1, false>(
            mean_grid_desc_m_k,
            make_multi_index(block_global_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
                             thread_k_cluster_id * KThreadSliceSize));

        auto threadwise_inv_std_load = ThreadwiseTensorSliceTransfer_v2<
            MeanInvStdDataType, ComputeDataType, GridDesc_M_K, decltype(thread_buffer_desc_m_k),
            ThreadBufferLengths_M_K, MeanInvStdThreadBufferDimAccessOrder, MeanInvStdSrcVectorDim,
            MeanInvStdSrcVectorSize, 1, false>(
            inv_std_grid_desc_m_k,
            make_multi_index(block_global_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
                             thread_k_cluster_id * KThreadSliceSize));

        auto threadwise_dx_store = ThreadwiseTensorSliceTransfer_v1r3<
            ComputeDataType, DXDataType, decltype(thread_buffer_desc_m_k), GridDesc_M_K,
            PassThroughOp, ThreadBufferLengths_M_K, DXThreadBufferDimAccessOrder, DXDstVectorDim,
            DXDstVectorSize, InMemoryDataOperationEnum::Set, 1, false>(
            dx_grid_desc_m_k,
            make_multi_index(block_global_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
                             thread_k_cluster_id * KThreadSliceSize),
            PassThroughOp{});

        ComputeDataType reduce_size = type_convert<ComputeDataType>(
            dy_grid_desc_m_k.GetTransforms()[I2].GetUpperLengths()[I0]);

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            ds_thread_buf(I) = type_convert<ComputeDataType>(0.0f);
            db_thread_buf(I) = type_convert<ComputeDataType>(0.0f);
        });

        // Separate sweep once and sweep twice pipeline
        // Sweep once: for small k, if KThreadClusterSize * KThreadSliceSize > K
        // we don't need to use loop to read x, dy, gamma twice
        if constexpr(SweepOnce)
        {
            threadwise_dy_load.Run(dy_grid_desc_m_k, dy_global_val_buf, thread_buffer_desc_m_k,
                                   make_tuple(I0, I0), dy_thread_buf);
            threadwise_x_load.Run(x_grid_desc_m_k, x_global_val_buf, thread_buffer_desc_m_k,
                                  make_tuple(I0, I0), x_thread_buf);
            threadwise_gamma_load.Run(gamma_grid_desc_m_k, gamma_global_val_buf,
                                      thread_buffer_desc_m_k, make_tuple(I0, I0), gamma_thread_buf);
            threadwise_mean_load.Run(mean_grid_desc_m_k, mean_global_val_buf,
                                     thread_buffer_desc_m_k, make_tuple(I0, I0), mean_thread_buf);
            threadwise_inv_std_load.Run(inv_std_grid_desc_m_k, inv_std_global_val_buf,
                                        thread_buffer_desc_m_k, make_tuple(I0, I0),
                                        inv_std_thread_buf);

            static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
                constexpr auto offset_m =
                    Number<thread_buffer_desc_m.CalculateOffset(make_tuple(iM))>{};
                static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
                    constexpr auto offset_m_k =
                        Number<thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK))>{};
                    ds_thread_buf(offset_m) += dy_thread_buf[offset_m_k] *
                                               gamma_thread_buf[offset_m_k] *
                                               x_thread_buf[offset_m_k];
                    db_thread_buf(offset_m) +=
                        dy_thread_buf[offset_m_k] * gamma_thread_buf[offset_m_k];
                });
            });

            static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
                if constexpr(I > 0)
                    block_sync_lds();
                BlockwiseSumReduce::Reduce(reduce_work_buf, ds_thread_buf(I));
                block_sync_lds();
                BlockwiseSumReduce::Reduce(reduce_work_buf, db_thread_buf(I));
            });

            static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
                constexpr auto offset_m =
                    Number<thread_buffer_desc_m.CalculateOffset(make_tuple(iM))>{};
                static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
                    constexpr auto offset_m_k =
                        Number<thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK))>{};

                    // b = (db * x_mean - ds) * rstd ** (3) / reduce_size
                    // c = -b * x_mean - db * rstd / reduce_size
                    // dx = rstd * dy * gamma + b * x + c
                    ComputeDataType b = db_thread_buf[offset_m] * mean_thread_buf[offset_m_k] -
                                        ds_thread_buf[offset_m];
                    b *= inv_std_thread_buf[offset_m_k] * inv_std_thread_buf[offset_m_k] *
                         inv_std_thread_buf[offset_m_k] / reduce_size;

                    ComputeDataType c = -b * mean_thread_buf(offset_m_k);
                    c -= db_thread_buf[offset_m] * inv_std_thread_buf[offset_m_k] / reduce_size;

                    dx_thread_buf(offset_m_k) = dy_thread_buf[offset_m_k] *
                                                    gamma_thread_buf[offset_m_k] *
                                                    inv_std_thread_buf[offset_m_k] +
                                                b * x_thread_buf[offset_m_k] + c;
                });
            });

            threadwise_dx_store.Run(thread_buffer_desc_m_k, make_tuple(I0, I0), dx_thread_buf,
                                    dx_grid_desc_m_k, dx_global_val_buf);
        } // end of sweep once
        else // Sweep Twice pipeline
        {
            constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileSize);

            for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
            {
                threadwise_dy_load.Run(dy_grid_desc_m_k, dy_global_val_buf, thread_buffer_desc_m_k,
                                       make_tuple(I0, I0), dy_thread_buf);
                threadwise_x_load.Run(x_grid_desc_m_k, x_global_val_buf, thread_buffer_desc_m_k,
                                      make_tuple(I0, I0), x_thread_buf);
                threadwise_gamma_load.Run(gamma_grid_desc_m_k, gamma_global_val_buf,
                                          thread_buffer_desc_m_k, make_tuple(I0, I0),
                                          gamma_thread_buf);

                threadwise_dy_load.MoveSrcSliceWindow(dy_grid_desc_m_k, thread_copy_fwd_step_m_k);
                threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);
                threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k,
                                                         thread_copy_fwd_step_m_k);

                static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
                    constexpr auto offset_m =
                        Number<thread_buffer_desc_m.CalculateOffset(make_tuple(iM))>{};
                    static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
                        constexpr auto offset_m_k =
                            Number<thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK))>{};
                        ds_thread_buf(offset_m) += dy_thread_buf[offset_m_k] *
                                                   gamma_thread_buf[offset_m_k] *
                                                   x_thread_buf[offset_m_k];
                        db_thread_buf(offset_m) +=
                            dy_thread_buf[offset_m_k] * gamma_thread_buf[offset_m_k];
                    });
                });
            } // end of first sweep

            static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
                if constexpr(I > 0)
                    block_sync_lds();
                BlockwiseSumReduce::Reduce(reduce_work_buf, ds_thread_buf(I));
                block_sync_lds();
                BlockwiseSumReduce::Reduce(reduce_work_buf, db_thread_buf(I));
            });

            // reverse read for using dy, gamma and x in the cache
            constexpr auto thread_copy_bwd_step_m_k = make_multi_index(0, -K_BlockTileSize);
            auto thread_copy_tail_m_k = (num_k_block_tile_iteration - 1) * thread_copy_fwd_step_m_k;

            // move to tail
            threadwise_dy_load.MoveSrcSliceWindow(dy_grid_desc_m_k, thread_copy_bwd_step_m_k);
            threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k);
            threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, thread_copy_bwd_step_m_k);

            // move from start to tail
            threadwise_mean_load.MoveSrcSliceWindow(mean_grid_desc_m_k, thread_copy_tail_m_k);
            threadwise_inv_std_load.MoveSrcSliceWindow(inv_std_grid_desc_m_k, thread_copy_tail_m_k);
            threadwise_dx_store.MoveDstSliceWindow(dx_grid_desc_m_k, thread_copy_tail_m_k);

            for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
            {
                threadwise_dy_load.Run(dy_grid_desc_m_k, dy_global_val_buf, thread_buffer_desc_m_k,
                                       make_tuple(I0, I0), dy_thread_buf);
                threadwise_x_load.Run(x_grid_desc_m_k, x_global_val_buf, thread_buffer_desc_m_k,
                                      make_tuple(I0, I0), x_thread_buf);
                threadwise_gamma_load.Run(gamma_grid_desc_m_k, gamma_global_val_buf,
                                          thread_buffer_desc_m_k, make_tuple(I0, I0),
                                          gamma_thread_buf);
                threadwise_mean_load.Run(mean_grid_desc_m_k, mean_global_val_buf,
                                         thread_buffer_desc_m_k, make_tuple(I0, I0),
                                         mean_thread_buf);
                threadwise_inv_std_load.Run(inv_std_grid_desc_m_k, inv_std_global_val_buf,
                                            thread_buffer_desc_m_k, make_tuple(I0, I0),
                                            inv_std_thread_buf);

                static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
                    constexpr auto offset_m =
                        Number<thread_buffer_desc_m.CalculateOffset(make_tuple(iM))>{};
                    static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
                        constexpr auto offset_m_k =
                            Number<thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK))>{};

                        // b = (db * x_mean - ds) * rstd ** (3) / reduce_size
                        // c = -b * x_mean - db * rstd / reduce_size
                        // dx = rstd * dy * gamma + b * x + c
                        ComputeDataType b = db_thread_buf[offset_m] * mean_thread_buf[offset_m_k] -
                                            ds_thread_buf[offset_m];
                        b *= inv_std_thread_buf[offset_m_k] * inv_std_thread_buf[offset_m_k] *
                             inv_std_thread_buf[offset_m_k] / reduce_size;

                        ComputeDataType c = -b * mean_thread_buf(offset_m_k);
                        c -= db_thread_buf[offset_m] * inv_std_thread_buf[offset_m_k] / reduce_size;

                        dx_thread_buf(offset_m_k) = dy_thread_buf[offset_m_k] *
                                                        gamma_thread_buf[offset_m_k] *
                                                        inv_std_thread_buf[offset_m_k] +
                                                    b * x_thread_buf[offset_m_k] + c;
                    });
                });

                threadwise_dx_store.Run(thread_buffer_desc_m_k, make_tuple(I0, I0), dx_thread_buf,
                                        dx_grid_desc_m_k, dx_global_val_buf);

                threadwise_dy_load.MoveSrcSliceWindow(dy_grid_desc_m_k, thread_copy_bwd_step_m_k);
                threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k);
                threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k,
                                                         thread_copy_bwd_step_m_k);
                threadwise_mean_load.MoveSrcSliceWindow(mean_grid_desc_m_k,
                                                        thread_copy_bwd_step_m_k);
                threadwise_inv_std_load.MoveSrcSliceWindow(inv_std_grid_desc_m_k,
                                                           thread_copy_bwd_step_m_k);
                threadwise_dx_store.MoveDstSliceWindow(dx_grid_desc_m_k, thread_copy_bwd_step_m_k);
            }
        }
    }
};

} // namespace ck
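The Python-style "Flow" comment at the top of this new file is effectively the specification of the kernel. As a reference point, here is a minimal host-side C++ sketch of the same arithmetic on a single [M, K] problem; it is not part of the commit, the function and variable names are illustrative, and it ignores the tiling, vectorization, and LDS reduction the kernel implements.

// Host-side reference for the backward-data flow described above.
// Assumes row-major dy, x of shape [M, K], gamma of shape [K],
// x_mean, inv_std of shape [M]; returns dx of shape [M, K].
#include <cstddef>
#include <vector>

std::vector<float> normalization_backward_x_ref(const std::vector<float>& dy,
                                                const std::vector<float>& x,
                                                const std::vector<float>& gamma,
                                                const std::vector<float>& x_mean,
                                                const std::vector<float>& inv_std,
                                                std::size_t M,
                                                std::size_t K)
{
    std::vector<float> dx(M * K);
    const float reduce_size = static_cast<float>(K);
    for(std::size_t m = 0; m < M; ++m)
    {
        // ds = sum(dy * gamma * x), db = sum(dy * gamma) over the reduced axis
        float ds = 0.f, db = 0.f;
        for(std::size_t k = 0; k < K; ++k)
        {
            ds += dy[m * K + k] * gamma[k] * x[m * K + k];
            db += dy[m * K + k] * gamma[k];
        }
        const float rstd = inv_std[m];
        const float b    = (db * x_mean[m] - ds) * rstd * rstd * rstd / reduce_size;
        const float c    = -b * x_mean[m] - db * rstd / reduce_size;
        for(std::size_t k = 0; k < K; ++k)
            dx[m * K + k] = rstd * dy[m * K + k] * gamma[k] + b * x[m * K + k] + c;
    }
    return dx;
}

The SweepOnce specialization in the kernel corresponds to the case where one K-tile covers the whole reduced axis, so the two loops over k above collapse into a single pass over data already resident in registers.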
include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_bwd_gamma_beta.hpp

@@ -35,7 +35,7 @@ template <typename DYDataType,
           index_t DBetaDstVectorSize>
 struct GridwiseNormalizationBwdGammaBeta_mk_to_k
 {
-    // if we just check ThreadSliceSize & VectorSize == 0, the performance may be poor
+    // if we just check ThreadSliceSize % VectorSize == 0, the performance may be poor (coalesce)
     static_assert(((DYSrcVectorDim == 0 && MThreadSliceSize == DYSrcVectorSize) ||
                    (DYSrcVectorDim == 1 && KThreadSliceSize == DYSrcVectorSize)),
                   "Invalid thread slice sizes and/or dy vector sizes configuration, please check!");

@@ -44,6 +44,15 @@ struct GridwiseNormalizationBwdGammaBeta_mk_to_k
                    (XSrcVectorDim == 1 && KThreadSliceSize == XSrcVectorSize)),
                   "Invalid thread slice sizes and/or x vector sizes configuration, please check!");

+    // do not force SliceSize == MeanInvStdSrcVectorSize for groupnorm
+    static_assert(
+        ((MeanInvStdSrcVectorDim == 0 && MThreadSliceSize % MeanInvStdSrcVectorSize == 0) ||
+         (MeanInvStdSrcVectorDim == 1 && KThreadSliceSize % MeanInvStdSrcVectorSize == 0)),
+        "Invalid thread slice sizes and/or mean/inv_std vector sizes configuration, please check!");
+
+    static_assert(MThreadSliceSize == DGammaDstVectorSize && MThreadSliceSize == DBetaDstVectorSize,
+                  "Invalid thread slice sizes and/or dx vector sizes configuration, please check!");
+
     using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;

     using DYThreadBufferDimAccessOrder =
include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp

@@ -328,7 +328,7 @@ struct WmmaSelector
     }

 #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
     template <>
-    static constexpr auto GetWmma<int4_t, int, 16, 16>()
+    static constexpr auto GetWmma<int4_t, int4_t, int, 16, 16>()
     {
         return WmmaInstr::wmma_i32_16x16x16_iu4;
     }
include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp

@@ -522,22 +522,21 @@ struct TransformConvFwdToGemm
     // for output bias
     template <typename CLayout,
-              typename std::enable_if<is_same_v<CLayout, tensor_layout::convolution::GK> ||
-                                          is_same_v<CLayout, tensor_layout::convolution::G_K>,
+              typename std::enable_if<is_same_v<CLayout, tensor_layout::convolution::G_K>,
                                       bool>::type = false>
-    static auto MakeCDescriptor_M_N(const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
-                                    const std::array<index_t, NDimSpatial + 3>& /* c_g_n_k_wos_strides */)
+    static auto MakeCDescriptor_M_N(const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
+                                    const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_strides)
     {
-        const index_t N = c_g_n_k_wos_lengths[1];
-        const index_t K = c_g_n_k_wos_lengths[2];
+        const index_t N       = c_g_n_k_wos_lengths[1];
+        const index_t K       = c_g_n_k_wos_lengths[2];
+        const index_t KStride = c_g_n_k_wos_strides[2];

         const index_t NHoWo =
             N * ck::accumulate_n<index_t>(
                     c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());

         const auto out_gemmm_gemmn_desc =
-            make_naive_tensor_descriptor(make_tuple(NHoWo, K), make_tuple(I0, I1));
+            make_naive_tensor_descriptor(make_tuple(NHoWo, K), make_tuple(I0, KStride));

         return out_gemmm_gemmn_desc;
     }
include/ck/utility/amd_buffer_addressing.hpp

@@ -972,6 +972,15 @@ __device__ void amd_direct_load_global_to_lds(const T* global_base_ptr,
     const int32x4_t src_resource = make_wave_buffer_resource(global_ptr, src_element_space_size);
     const index_t global_offset_bytes = is_valid ? global_offset * sizeof(T) : 0x80000000;

+#if CK_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM
+    T* lds_ptr = lds_base_ptr + lds_offset;
+    auto const lds_ptr_sgpr =
+        __builtin_amdgcn_readfirstlane((reinterpret_cast<uintptr_t>(lds_ptr)));
+    asm volatile("s_mov_b32 m0, %0; \n\t"
+                 "buffer_load_dword %1, %2, 0 offen lds;\n\t" ::"s"(lds_ptr_sgpr),
+                 "v"(global_offset_bytes),
+                 "s"(src_resource));
+#else
     // LDS pointer must be attributed with the LDS address space.
     __attribute__((address_space(3))) uint32_t* lds_ptr =
         reinterpret_cast<__attribute__((address_space(3))) uint32_t*>(

@@ -979,6 +988,7 @@ __device__ void amd_direct_load_global_to_lds(const T* global_base_ptr,
     llvm_amdgcn_raw_buffer_load_lds(
         src_resource, lds_ptr, sizeof(uint32_t), global_offset_bytes, 0, 0, 0);
+#endif
 }

 } // namespace ck
include/ck/utility/amd_lds.hpp  (new file, 0 → 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/amd_address_space.hpp"
#include "ck/utility/dynamic_buffer.hpp"
#include "ck/utility/math.hpp"

namespace ck {
namespace lds_utils {

/** \brief Allocate a given number of buffers in LDS and return them as a tuple.
 *
 * \tparam DataType Data type of elements to be stored in LDS.
 * \tparam NumBuffers Number of buffers to be allocated.
 * \param lds_ptr Address of the beginning of LDS space.
 * \param num_elems_per_buffer Number of elements to allocate per single buffer.
 * \param start_offset_elems Number of elements to move from the start of LDS for the allocation
 *        of the first buffer.
 * \param lds_alignment Alignment of every buffer allocation given as a number of elements.
 * \return Tuple of dynamic buffers representing memory allocated in LDS.
 */
template <typename DataType, index_t NumBuffers>
__device__ static auto AllocateLdsBuffers(void* lds_ptr,
                                          int32_t num_elems_per_buffer,
                                          int32_t start_offset_elems,
                                          int32_t lds_alignment)
{
    const DataType* lds_start = static_cast<DataType*>(lds_ptr) + start_offset_elems;
    const int32_t single_buffer_offset =
        math::integer_least_multiple(num_elems_per_buffer, lds_alignment);
    return generate_tuple(
        [&](auto i) {
            const int32_t local_offset = i * single_buffer_offset;
            return make_dynamic_buffer<AddressSpaceEnum::Lds>(lds_start + local_offset,
                                                              num_elems_per_buffer);
        },
        Number<NumBuffers>{});
}

} // namespace lds_utils
} // namespace ck
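The doc comment above fixes the layout policy: buffers sit back to back, with each slot rounded up to the requested alignment. A small host-side C++ sketch of that offset arithmetic follows; the helper and values are illustrative only and not part of CK.

// Host-side illustration of how AllocateLdsBuffers spaces its buffers:
// buffer i starts at start_offset_elems + i * round_up(num_elems_per_buffer, lds_alignment).
#include <cstdint>
#include <cstdio>

static int32_t integer_least_multiple(int32_t x, int32_t alignment)
{
    return ((x + alignment - 1) / alignment) * alignment; // round x up to a multiple of alignment
}

int main()
{
    const int32_t num_elems_per_buffer = 100;
    const int32_t start_offset_elems   = 16;
    const int32_t lds_alignment        = 64;
    const int32_t single_buffer_offset = integer_least_multiple(num_elems_per_buffer, lds_alignment);

    for(int32_t i = 0; i < 3; ++i) // e.g. NumBuffers == 3
        std::printf("buffer %d starts at element %d\n", i, start_offset_elems + i * single_buffer_offset);
    // prints: buffer 0 starts at element 16, buffer 1 at 144, buffer 2 at 272
}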
include/ck/utility/amd_wmma.hpp

@@ -9,6 +9,9 @@
 // TODO: Add arch limitation
 namespace ck {

+#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__)
+#define __gfx11__
+#endif
 /********************************WAVE32 MODE***********************************************/
 // src: fp16, dst: fp32

@@ -25,7 +28,7 @@ struct intrin_wmma_f32_16x16x16_f16_w32<16, 16>
         // delete them.
         // amd_assembly_wmma_f32_16x16x16_f16_w32(
         // reg_a, reg_b, reg_c.template AsType<float8_t>()(Number<0>{}));
-#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
+#if defined(__gfx11__)
         reg_c.template AsType<float8_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(
             reg_a, reg_b, reg_c.template AsType<float8_t>()[Number<0>{}]);
 #else

@@ -46,7 +49,7 @@ struct intrin_wmma_f32_16x16x16_bf16_w32<16, 16>
     template <class FloatC>
     __device__ static void Run(const bhalf16_t& reg_a, const bhalf16_t& reg_b, FloatC& reg_c)
     {
-#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
+#if defined(__gfx11__)
         reg_c.template AsType<float8_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32(
             reg_a, reg_b, reg_c.template AsType<float8_t>()[Number<0>{}]);

@@ -71,7 +74,7 @@ struct intrin_wmma_f16_16x16x16_f16_w32<16, 16, Opsel>
         // opsel usage
         // false: D0.[0:15] = result
         // true : D0.[16:31]= result
-#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
+#if defined(__gfx11__)
         reg_c.template AsType<half16_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32(
             reg_a, reg_b, reg_c.template AsType<half16_t>()[Number<0>{}], Opsel);
 #else

@@ -95,7 +98,7 @@ struct intrin_wmma_bf16_16x16x16_bf16_w32<16, 16, Opsel>
         // opsel usage
         // false: D0.[0:15] = result
         // true : D0.[16:31]= result
-#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
+#if defined(__gfx11__)
         reg_c.template AsType<bhalf16_t>()(Number<0>{}) = __builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32(
             reg_a, reg_b, reg_c.template AsType<bhalf16_t>()[Number<0>{}], Opsel);

@@ -117,7 +120,7 @@ struct intrin_wmma_i32_16x16x16_iu8_w32<16, 16, neg_a, neg_b, clamp>
     template <class FloatC>
     __device__ static void Run(const int8x16_t& reg_a, const int8x16_t& reg_b, FloatC& reg_c)
     {
-#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
+#if defined(__gfx11__)
         reg_c.template AsType<int32x8_t>()(Number<0>{}) = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(
             neg_a,

@@ -145,7 +148,7 @@ struct intrin_wmma_f32_16x16x16_f16_w64<16, 16>
     template <class FloatC>
     __device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c)
     {
-#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
+#if defined(__gfx11__)
         reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_f16_w64(
             reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}]);
 #else

@@ -166,7 +169,7 @@ struct intrin_wmma_f32_16x16x16_bf16_w64<16, 16>
     template <class FloatC>
     __device__ static void Run(const bhalf16_t& reg_a, const bhalf16_t& reg_b, FloatC& reg_c)
     {
-#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
+#if defined(__gfx11__)
         reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_bf16_w64(
             reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}]);

@@ -191,7 +194,7 @@ struct intrin_wmma_f16_16x16x16_f16_w64<16, 16, Opsel>
         // opsel usage
         // false: D0.[0:15] = result
         // true : D0.[16:31]= result
-#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
+#if defined(__gfx11__)
         reg_c.template AsType<half8_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f16_16x16x16_f16_w64(
             reg_a, reg_b, reg_c.template AsType<half8_t>()[Number<0>{}], Opsel);
 #else

@@ -215,7 +218,7 @@ struct intrin_wmma_bf16_16x16x16_bf16_w64<16, 16, Opsel>
         // opsel usage
         // false: D0.[0:15] = result
         // true : D0.[16:31]= result
-#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
+#if defined(__gfx11__)
         reg_c.template AsType<bhalf8_t>()(Number<0>{}) = __builtin_amdgcn_wmma_bf16_16x16x16_bf16_w64(
             reg_a, reg_b, reg_c.template AsType<bhalf8_t>()[Number<0>{}], Opsel);

@@ -237,7 +240,7 @@ struct intrin_wmma_i32_16x16x16_iu8_w64<16, 16, neg_a, neg_b, clamp>
     template <class FloatC>
     __device__ static void Run(const int8x16_t& reg_a, const int8x16_t& reg_b, FloatC& reg_c)
     {
-#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
+#if defined(__gfx11__)
         reg_c.template AsType<int32x4_t>()(Number<0>{}) = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w64(
             neg_a,
include/ck/utility/amd_xdlops.hpp

@@ -4,6 +4,10 @@
 #pragma once

 namespace ck {

+// Define the common macro for MI300 models
+#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
+#define __gfx94__
+#endif
 // fp32
 template <index_t MPerWave, index_t NPerWave>

@@ -341,7 +345,7 @@ struct intrin_mfma_f64_16x16x4f64<16, 16>
     template <class FloatC>
     __device__ static void Run(const double& reg_a, const double& reg_b, FloatC& reg_c)
     {
-#if defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
+#if defined(__gfx90a__) || defined(__gfx94__)
         reg_c.template AsType<double4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f64_16x16x4f64(
             reg_a, reg_b, reg_c.template AsType<double4_t>()[Number<0>{}], 0, 0, 0);
 #else

@@ -361,7 +365,7 @@ struct intrin_mfma_f32_32x32x16f8f8<32, 32>
     template <class FloatC>
     __device__ static void Run(const f8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c)
     {
-#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
+#if defined(__gfx94__)
         reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8(
             bit_cast<long>(reg_a),

@@ -393,7 +397,7 @@ struct intrin_mfma_f32_16x16x32f8f8<16, 16>
     template <class FloatC>
     __device__ static void Run(const f8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c)
     {
-#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
+#if defined(__gfx94__)
         reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8(
             bit_cast<long>(reg_a), bit_cast<long>(reg_b),

@@ -424,7 +428,7 @@ struct intrin_mfma_f32_32x32x16bf8bf8<32, 32>
     template <class FloatC>
     __device__ static void Run(const bf8x8_t& reg_a, const bf8x8_t& reg_b, FloatC& reg_c)
     {
-#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
+#if defined(__gfx94__)
         reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8(
             bit_cast<long>(reg_a),

@@ -456,7 +460,7 @@ struct intrin_mfma_f32_16x16x32bf8bf8<16, 16>
     template <class FloatC>
     __device__ static void Run(const bf8x8_t& reg_a, const bf8x8_t& reg_b, FloatC& reg_c)
     {
-#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
+#if defined(__gfx94__)
         reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_bf8_bf8(
             bit_cast<long>(reg_a), bit_cast<long>(reg_b),

@@ -487,7 +491,7 @@ struct intrin_mfma_f32_32x32x16f8bf8<32, 32>
     template <class FloatC>
     __device__ static void Run(const f8x8_t& reg_a, const bf8x8_t& reg_b, FloatC& reg_c)
     {
-#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
+#if defined(__gfx94__)
         reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8(
             bit_cast<long>(reg_a),

@@ -519,7 +523,7 @@ struct intrin_mfma_f32_16x16x32f8bf8<16, 16>
     template <class FloatC>
     __device__ static void Run(const f8x8_t& reg_a, const bf8x8_t& reg_b, FloatC& reg_c)
     {
-#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
+#if defined(__gfx94__)
         reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_fp8_bf8(
             bit_cast<long>(reg_a), bit_cast<long>(reg_b),

@@ -550,7 +554,7 @@ struct intrin_mfma_f32_32x32x16bf8f8<32, 32>
     template <class FloatC>
     __device__ static void Run(const bf8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c)
     {
-#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
+#if defined(__gfx94__)
         reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x16_bf8_fp8(
             bit_cast<long>(reg_a),

@@ -582,7 +586,7 @@ struct intrin_mfma_f32_16x16x32bf8f8<16, 16>
     template <class FloatC>
     __device__ static void Run(const bf8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c)
     {
-#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
+#if defined(__gfx94__)
         reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_bf8_fp8(
             bit_cast<long>(reg_a), bit_cast<long>(reg_b),
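The `__gfx94__` family macro defined at the top of this file (and mirrored in type_convert.hpp below) is what lets every per-ASIC `defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)` chain collapse to a single check. Below is a standalone sketch of the pattern using placeholder macro names rather than the real compiler-defined `__gfx94x__` macros, purely for illustration.

// Family-macro pattern: derive one umbrella macro from per-ASIC macros,
// then guard code on the umbrella instead of repeating the whole chain.
#include <cstdio>

#define TARGET_GFX941 1 // pretend the per-ASIC macro came from the compiler command line

#if defined(TARGET_GFX940) || defined(TARGET_GFX941) || defined(TARGET_GFX942)
#define TARGET_GFX94_FAMILY 1
#endif

int main()
{
#if defined(TARGET_GFX94_FAMILY)
    std::printf("building for a gfx94x (MI300-class) target\n");
#else
    std::printf("not a gfx94x target\n");
#endif
}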
include/ck/utility/data_type.hpp

@@ -189,6 +189,7 @@ struct vector_type<T, 1>
     }
 };

+int static err = 0;
 template <typename T>
 struct vector_type<T, 2>
 {

The remaining hunks add the same fallback `else { return err; }` branch to the AsType accessors of every wider vector_type specialization; each hunk follows this pattern (only the member name changes: d2x1_, d4x1_, d8x1_, d16x1_, d32x1_, d64x1_, d128x1_, d256x1_):

@@ -221,6 +222,10 @@ struct vector_type<T, 2>
         {
             return data_.d2x1_;
         }
+        else
+        {
+            return err;
+        }
     }

     template <typename X>

@@ -236,6 +241,10 @@ struct vector_type<T, 2>
         {
             return data_.d2x1_;
         }
+        else
+        {
+            return err;
+        }
     }
 };

@@ -278,6 +287,10 @@ and @@ -298,6 +311,10 @@ struct vector_type<T, 4>   (member d4x1_)
@@ -347,6 +364,10 @@ and @@ -372,6 +393,10 @@ struct vector_type<T, 8>   (member d8x1_)
@@ -428,6 +453,10 @@ and @@ -458,6 +487,10 @@ struct vector_type<T, 16>  (member d16x1_)
@@ -520,6 +553,10 @@ and @@ -554,6 +591,10 @@ struct vector_type<T, 32>  (member d32x1_)
@@ -623,6 +664,10 @@ and @@ -662,6 +707,10 @@ struct vector_type<T, 64>  (member d64x1_)
@@ -737,6 +786,10 @@ and @@ -780,6 +833,10 @@ struct vector_type<T, 128> (member d128x1_)
@@ -861,6 +918,10 @@ and @@ -908,6 +969,10 @@ struct vector_type<T, 256> (member d256x1_)
include/ck/utility/is_known_at_compile_time.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

@@ -19,6 +19,12 @@ struct is_known_at_compile_time<index_t>
     static constexpr bool value = false;
 };

+template <>
+struct is_known_at_compile_time<unsigned int>
+{
+    static constexpr bool value = false;
+};
+
 template <>
 struct is_known_at_compile_time<long_index_t>
 {
include/ck/utility/tuple_helper.hpp

@@ -5,6 +5,7 @@
 #include "functional4.hpp"
 #include "tuple.hpp"
+#include "is_detected.hpp"

 namespace ck {

@@ -33,6 +34,28 @@ __host__ __device__ constexpr auto concat_tuple_of_reference(const Tuple<X&...>&
         ty);
 }

+template <typename... X, typename... Y>
+__host__ __device__ constexpr auto concat_tuple(const Tuple<X...>& tx, const Tuple<Y...>& ty)
+{
+    return unpack2(
+        [&](auto... zs) { return Tuple<decltype(zs)...>{std::forward<decltype(zs)>(zs)...}; },
+        tx,
+        ty);
+}
+
+// Support any number of tuples to concat (also 1)
+template <typename... X>
+__host__ __device__ constexpr auto concat_tuple(const Tuple<X...>& tx)
+{
+    return tx;
+}
+
+template <typename... X, typename... Tuples>
+__host__ __device__ constexpr auto concat_tuple(const Tuple<X...>& tx, const Tuples&... tuples)
+{
+    return concat_tuple(tx, concat_tuple(tuples...));
+}
+
 namespace detail {

 template <typename F, typename X, index_t... Is>

@@ -78,4 +101,92 @@ __host__ __device__ constexpr auto transform_tuples(F f, const X& x, const Y& y,
         f, x, y, z, typename arithmetic_sequence_gen<0, X::Size(), 1>::type{});
 }

+// By default unroll to the flatten
+template <index_t Depth = 0, index_t MaxDepth = -1>
+__host__ __device__ constexpr auto UnrollNestedTuple(const Tuple<>& element)
+{
+    return element;
+}
+
+template <index_t Depth = 0, index_t MaxDepth = -1, typename T>
+__host__ __device__ constexpr auto UnrollNestedTuple(const T& element)
+{
+    return make_tuple(element);
+}
+
+template <index_t Depth = 0, index_t MaxDepth = -1, typename... Ts>
+__host__ __device__ constexpr auto UnrollNestedTuple(const Tuple<Ts...>& tuple)
+{
+    if constexpr(Depth == MaxDepth)
+    {
+        return tuple;
+    }
+    else
+    {
+        return unpack(
+            [&](auto&&... ts) { return concat_tuple(UnrollNestedTuple<Depth + 1, MaxDepth>(ts)...); },
+            tuple);
+    }
+}
+
+template <typename... Ts>
+__host__ __device__ constexpr auto TupleReverse(const Tuple<Ts...>& tuple)
+{
+    return generate_tuple(
+        [&](auto i) {
+            using Idx = Number<Tuple<Ts...>::Size() - i - 1>;
+            return tuple.At(Idx{});
+        },
+        Number<Tuple<Ts...>::Size()>{});
+}
+
+// Reduce tuple values in specific range using Function
+template <index_t Idx, index_t End, typename F, typename... Ts>
+__host__ __device__ constexpr auto TupleReduce(F&& f, const Tuple<Ts...>& tuple)
+{
+    static_assert(Idx < End, "Wrong parameters for TupleReduce");
+    if constexpr(Idx + 1 == End)
+    {
+        return tuple.At(Number<Idx>{});
+    }
+    else
+    {
+        return f(tuple.At(Number<Idx>{}), TupleReduce<Idx + 1, End>(f, tuple));
+    }
+}
+
+template <typename T>
+using is_tuple = decltype(std::declval<T&>().IsTuple());
+
+template <typename... Ts>
+__host__ __device__ constexpr auto IsNestedTuple(const Tuple<Ts...>&)
+{
+    return (is_detected<is_tuple, Ts>::value || ...);
+}
+
+template <index_t depth = 0, typename T>
+__host__ __device__ constexpr auto TupleDepth(const T&)
+{
+    return depth;
+}
+
+template <index_t depth = 0, typename... Ts>
+__host__ __device__ constexpr auto TupleDepth(const Tuple<Ts...>&)
+{
+    return math::max(TupleDepth<depth + 1>(Ts{})...);
+}
+
+template <index_t from, index_t to, typename... Ts>
+__host__ __device__ constexpr auto TupleSlice(const Tuple<Ts...>& tuple)
+{
+    return generate_tuple(
+        [&](auto i) {
+            using Idx = Number<from + i>;
+            return tuple.At(Idx{});
+        },
+        Number<to - from>{});
+}
+
 } // namespace ck
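The helpers added above (concat_tuple, UnrollNestedTuple, TupleReverse, TupleReduce, TupleSlice, and friends) are compile-time utilities over CK's own Tuple type. As a rough behavioural analogue outside of CK, the same ideas can be sketched with std::tuple; this is a sketch only, not the CK API, and the CK versions are __host__ __device__ and operate on ck::Tuple.

// Behavioural analogue of some of the new ck tuple helpers using std::tuple.
#include <cstdio>
#include <tuple>
#include <utility>

// concat_tuple: join any number of tuples into one.
template <typename... Tuples>
constexpr auto concat_tuple_like(Tuples&&... ts)
{
    return std::tuple_cat(std::forward<Tuples>(ts)...);
}

// TupleReverse: rebuild the tuple with indices running backwards.
template <typename Tuple, std::size_t... Is>
constexpr auto tuple_reverse_impl(const Tuple& t, std::index_sequence<Is...>)
{
    constexpr std::size_t N = sizeof...(Is);
    return std::make_tuple(std::get<N - 1 - Is>(t)...);
}
template <typename... Ts>
constexpr auto tuple_reverse_like(const std::tuple<Ts...>& t)
{
    return tuple_reverse_impl(t, std::make_index_sequence<sizeof...(Ts)>{});
}

// TupleSlice<from, to>: keep elements in [from, to).
template <std::size_t From, std::size_t To, typename Tuple, std::size_t... Is>
constexpr auto tuple_slice_impl(const Tuple& t, std::index_sequence<Is...>)
{
    return std::make_tuple(std::get<From + Is>(t)...);
}
template <std::size_t From, std::size_t To, typename... Ts>
constexpr auto tuple_slice_like(const std::tuple<Ts...>& t)
{
    return tuple_slice_impl<From, To>(t, std::make_index_sequence<To - From>{});
}

int main()
{
    constexpr auto t = concat_tuple_like(std::make_tuple(1, 2), std::make_tuple(3, 4));
    constexpr auto r = tuple_reverse_like(t);     // (4, 3, 2, 1)
    constexpr auto s = tuple_slice_like<1, 3>(t); // (2, 3)
    std::printf("%d %d\n", std::get<0>(r), std::get<0>(s)); // prints "4 2"
}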
include/ck/utility/type_convert.hpp

@@ -8,6 +8,10 @@
 #include "ck/utility/random_gen.hpp"

 namespace ck {

+// Define the common macro for MI300 models
+#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
+#define __gfx94__
+#endif
+
 // Convert X to Y, both X and Y are non-const data types.
 template <typename Y,

@@ -105,7 +109,7 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, float>(float x)
 {
     constexpr int seed = 42;
     uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
-#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
+#if defined(__gfx94__)
     float max_fp8 = 240.0f;
     x             = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x);
     union

@@ -133,7 +137,7 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, float>(float x)
 template <>
 inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x)
 {
-#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
+#if defined(__gfx94__)
     // convert to float and use native converion
     return f8_convert_sr<f8_t>(type_convert<float>(x));
 #else

@@ -154,7 +158,7 @@ inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, float>(float x)
 {
     constexpr int seed = 42;
     uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
-#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
+#if defined(__gfx94__)
     union
     {
         float fval;

@@ -180,9 +184,9 @@ inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, float>(float x)
 template <>
 inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, half_t>(half_t x)
 {
-#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
+#if defined(__gfx94__)
     // convert to float and use native converion
-    return f8_convert_sr<f8_t>(type_convert<float>(x));
+    return f8_convert_sr<bf8_t>(type_convert<float>(x));
 #else
     constexpr bool negative_zero_nan = true;
     constexpr bool clip              = true;

@@ -203,7 +207,7 @@ __host__ __device__ constexpr Y f8_convert_rne(X x);
 template <>
 inline __host__ __device__ f8_t f8_convert_rne<f8_t, float>(float x)
 {
-#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
+#if defined(__gfx94__)
     float max_fp8 = 240.0f;
     x             = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x);
     union

@@ -232,7 +236,7 @@ inline __host__ __device__ f8_t f8_convert_rne<f8_t, float>(float x)
 template <>
 inline __host__ __device__ f8_t f8_convert_rne<f8_t, half_t>(half_t x)
 {
-#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
+#if defined(__gfx94__)
     // convert to float and use native converion
     return f8_convert_rne<f8_t>(type_convert<float>(x));
 #else

@@ -250,7 +254,7 @@ inline __host__ __device__ f8_t f8_convert_rne<f8_t, half_t>(half_t x)
 template <>
 inline __host__ __device__ bf8_t f8_convert_rne<bf8_t, float>(float x)
 {
-#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
+#if defined(__gfx94__)
     union
     {
         float fval;

@@ -277,7 +281,7 @@ inline __host__ __device__ bf8_t f8_convert_rne<bf8_t, float>(float x)
 template <>
 inline __host__ __device__ bf8_t f8_convert_rne<bf8_t, half_t>(half_t x)
 {
-#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
+#if defined(__gfx94__)
     // convert to float and use native converion
     return f8_convert_rne<bf8_t>(type_convert<float>(x));
 #else

@@ -295,7 +299,7 @@ inline __host__ __device__ bf8_t f8_convert_rne<bf8_t, half_t>(half_t x)
 template <>
 inline __host__ __device__ f8_t type_convert<f8_t, float>(float x)
 {
-#if defined CK_USE_SR_F8_CONVERSION
+#if CK_USE_SR_F8_CONVERSION
     return f8_convert_sr<f8_t>(x);
 #else
     return f8_convert_rne<f8_t>(x);

@@ -306,7 +310,7 @@ inline __host__ __device__ f8_t type_convert<f8_t, float>(float x)
 template <>
 inline __host__ __device__ float type_convert<float, f8_t>(f8_t x)
 {
-#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
+#if defined(__gfx94__)
     float fval;
     uint32_t i32val = static_cast<uint32_t>(x);
     fval            = __builtin_amdgcn_cvt_f32_fp8(i32val, 0);

@@ -321,7 +325,7 @@ inline __host__ __device__ float type_convert<float, f8_t>(f8_t x)
 template <>
 inline __host__ __device__ float2_t type_convert<float2_t, f8x2_t>(f8x2_t x)
 {
-#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
+#if defined(__gfx94__)
     const auto i16val = bit_cast<uint16_t>(x);
     return __builtin_amdgcn_cvt_pk_f32_fp8(i16val, 0);
 #else

@@ -352,10 +356,10 @@ inline __host__ __device__ half2_t type_convert<half2_t, float2_t>(float2_t x)
 template <>
 inline __host__ __device__ f8_t type_convert<f8_t, half_t>(half_t x)
 {
-#if defined CK_USE_SR_F8_CONVERSION
+#if CK_USE_SR_F8_CONVERSION
     return f8_convert_sr<f8_t>(x);
 #else
-    return f8_convert_nre<f8_t>(x);
+    return f8_convert_rne<f8_t>(x);
 #endif
 }

@@ -363,7 +367,7 @@ inline __host__ __device__ f8_t type_convert<f8_t, half_t>(half_t x)
 template <>
 inline __host__ __device__ half_t type_convert<half_t, f8_t>(f8_t x)
 {
-#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
+#if defined(__gfx94__)
     // use native conversion to float and convert to fp16
     return type_convert<half_t>(type_convert<float>(x));
 #else

@@ -376,7 +380,7 @@ inline __host__ __device__ half_t type_convert<half_t, f8_t>(f8_t x)
 template <>
 inline __host__ __device__ bf8_t type_convert<bf8_t, float>(float x)
 {
-#if defined CK_USE_SR_F8_CONVERSION
+#if CK_USE_SR_F8_CONVERSION
     return f8_convert_sr<bf8_t>(x);
 #else
     return f8_convert_rne<bf8_t>(x);

@@ -387,7 +391,7 @@ inline __host__ __device__ bf8_t type_convert<bf8_t, float>(float x)
 template <>
 inline __host__ __device__ float type_convert<float, bf8_t>(bf8_t x)
 {
-#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
+#if defined(__gfx94__)
     float fval;
     uint32_t i32val = static_cast<uint32_t>(x);
     fval            = __builtin_amdgcn_cvt_f32_bf8(i32val, 0);

@@ -403,7 +407,7 @@ inline __host__ __device__ float type_convert<float, bf8_t>(bf8_t x)
 template <>
 inline __host__ __device__ bf8_t type_convert<bf8_t, half_t>(half_t x)
 {
-#if defined CK_USE_SR_F8_CONVERSION
+#if CK_USE_SR_F8_CONVERSION
     return f8_convert_sr<bf8_t>(x);
 #else
     return f8_convert_rne<bf8_t>(x);

@@ -414,7 +418,7 @@ inline __host__ __device__ bf8_t type_convert<bf8_t, half_t>(half_t x)
 template <>
 inline __host__ __device__ half_t type_convert<half_t, bf8_t>(bf8_t x)
 {
-#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
+#if defined(__gfx94__)
     // use native conversion to float and convert to fp16
     return type_convert<half_t>(type_convert<float>(x));
 #else
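Besides the guard consolidation, this file fixes two real bugs: `f8_convert_sr<bf8_t, half_t>` now dispatches to the bf8 converter instead of the f8 one, and `#if defined CK_USE_SR_F8_CONVERSION` becomes `#if CK_USE_SR_F8_CONVERSION`, so a build that defines the macro to 0 no longer silently selects the stochastic-rounding path. A tiny standalone sketch of that preprocessor difference (the macro is redefined locally here purely for illustration):

// Why `#if defined X` and `#if X` differ when the build defines X to 0
// (e.g. -DCK_USE_SR_F8_CONVERSION=0).
#include <cstdio>

#define CK_USE_SR_F8_CONVERSION 0

int main()
{
#if defined CK_USE_SR_F8_CONVERSION
    std::printf("'#if defined' branch taken: the macro exists, even though its value is 0\n");
#endif

#if CK_USE_SR_F8_CONVERSION
    std::printf("'#if' branch taken: value is non-zero, stochastic rounding selected\n");
#else
    std::printf("'#if' branch not taken: value is 0, round-to-nearest-even path used\n");
#endif
}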
include/ck/wrapper/layout.hpp  (new file, 0 → 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/wrapper/utils/layout_utils.hpp"

namespace ck {
namespace wrapper {

/**
 * \brief Layout wrapper that performs the tensor descriptor logic.
 *
 * \tparam Shape Tuple of Number<> (for compile-time layout) or index_t
 *         (dynamic layout). It is possible to pass nested shapes
 *         (e.g. ((4, 2), 2)), nested dimensions are merged.
 * \tparam UnrolledDescriptorType Tensor descriptor for unnested shape dims.
 */
template <typename Shape, typename UnrolledDescriptorType>
struct Layout
{
    private:
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};

    /**
     * \brief Generate default indices tuple (idx with all merged nested shapes)
     *
     * \param shape Shape to align.
     * \return Multi idx tuple with zeros.
     */
    template <typename... Ts>
    __host__ __device__ constexpr static auto
    GenerateDefaultIdxsTuple([[maybe_unused]] const Tuple<Ts...>& shape)
    {
        return generate_tuple(
            [&](auto) {
                if constexpr(!remove_cvref_t<UnrolledDescriptorType>::IsKnownAtCompileTime())
                {
                    // runtime layout
                    return index_t(0);
                }
                else
                {
                    // compiletime layout
                    return I0;
                }
            },
            Number<Tuple<Ts...>::Size()>{});
    }

    /**
     * \brief Generate lower dims in compile-time for the Merge transform using
     * provided type. If element of nested Tuple<Ts...> is also a tuple, then
     * merge (generate sequence for merge). If tuple is element, then pass
     * through (sequence with one element).
     *
     * \param shape Shape to align.
     * \return LowerDims for MergeTrasform.
     */
    template <typename Idx, typename... Ts>
    __host__ __device__ constexpr static auto
    GenerateLowerDim([[maybe_unused]] const Tuple<Ts...>& shape)
    {
        if constexpr(Idx::value == 0)
        {
            if constexpr(is_detected<is_tuple, tuple_element_t<Idx::value, Tuple<Ts...>>>::value)
            {
                // Return Sequence for the first tuple
                constexpr index_t merge_nelems =
                    decltype(UnrollNestedTuple(tuple_element_t<Idx::value, Tuple<Ts...>>{}))::Size();
                using LowerDimsSequence = typename arithmetic_sequence_gen<0, merge_nelems, 1>::type;
                return LowerDimsSequence::Reverse();
            }
            else
            {
                // Return first element
                return Sequence<0>{};
            }
        }
        else
        {
            // Get previous element using recurence (in compile-time)
            using PreviousSeqT = decltype(GenerateLowerDim<Number<Idx::value - 1>>(Tuple<Ts...>{}));
            const auto next_seq_val = PreviousSeqT::At(I0) + 1;
            if constexpr(is_detected<is_tuple, tuple_element_t<Idx::value, Tuple<Ts...>>>::value)
            {
                constexpr index_t merge_nelems =
                    decltype(UnrollNestedTuple(tuple_element_t<Idx::value, Tuple<Ts...>>{}))::Size();
                using LowerDimsSequence =
                    typename arithmetic_sequence_gen<next_seq_val, next_seq_val + merge_nelems, 1>::type;
                return LowerDimsSequence::Reverse();
            }
            else
            {
                return Sequence<next_seq_val>{};
            }
        }
    }

    /**
     * \brief Iterate over the nested tuples in the shape.
     * Unroll nested tuples to align Tuple<ShapeDims...> to Tuple<IdxDims...>
     * Example idx:    (1, 1), 1, 1
     * Example shape:  (2, (2, 2)), 2, (2, 2)
     * Unrolled shape: 2, (2, 2), 2, (2, 2)
     *
     * \param shape Layout shape.
     * \param idx Idx to align.
     * \return Aligned shape.
     */
    template <typename... ShapeDims, typename... IdxDims>
    __host__ __device__ constexpr static auto
    AlignShapeToIdx(const Tuple<ShapeDims...>& shape, const Tuple<IdxDims...>& idx)
    {
        if constexpr(!IsNestedTuple(Tuple<IdxDims...>{}))
        {
            // Index unrolled to flatten, return shape
            return shape;
        }
        else
        {
            // Iterate over shape tuple elements:
            // 1. If corresponding idx element is tuple then return (will be unrolled)
            // 2. If no, pack in tuple. It will be restored during unroll.
            auto aligned_shape = generate_tuple(
                [&](auto i) {
                    if constexpr(is_detected<is_tuple, tuple_element_t<i, Tuple<IdxDims...>>>::value)
                    {
                        return shape.At(i);
                    }
                    else
                    {
                        return make_tuple(shape.At(i));
                    }
                },
                Number<Tuple<IdxDims...>::Size()>{});
            // Unroll and process next step
            return AlignShapeToIdx(UnrollNestedTuple<0, 1>(aligned_shape),
                                   UnrollNestedTuple<0, 1>(idx));
        }
    }

    /**
     * \brief Merge descriptor to 1D.
     *
     * \param shape Layout shape.
     * \param desc Descriptor to merge.
     * \return 1D descriptor.
     */
    template <typename... ShapeDims, typename DescriptorToMerge>
    __host__ __device__ constexpr static auto MakeMerge1d(const Tuple<ShapeDims...>& shape,
                                                          const DescriptorToMerge& desc)
    {
        // Reverse each element in tuple
        const auto merge_elems = TupleReverse(UnrollNestedTuple(shape));
        // Generate reverted indexes (column major traverse)
        using MergeElemsSequence = typename arithmetic_sequence_gen<0, merge_elems.Size(), 1>::type;
        const auto lower_dims = make_tuple(MergeElemsSequence::Reverse());
        const auto upper_dims = make_tuple(Sequence<0>{});
        // Merge to 1d
        if constexpr(!remove_cvref_t<UnrolledDescriptorType>::IsKnownAtCompileTime())
        {
            return transform_tensor_descriptor(
                desc, make_tuple(make_merge_transform(merge_elems)), lower_dims, upper_dims);
        }
        else
        {
            // If the descriptor is known at the compilation time,
            // use `make_merge_transform_v1_carry_check` because it doesn't use memcpy.
            return transform_tensor_descriptor(
                desc,
                make_tuple(make_merge_transform_v1_carry_check(merge_elems)),
                lower_dims,
                upper_dims);
        }
    }

    /**
     * \brief Merge nested shape dims when corresponding index is also merged.
     * Input desc shape: 2, 2, 2, 2, 2, 2
     * Example idx:      1, 1, 1, (1, 1)
     * Example shape:    2, (2, 2), 2, (2, 2)
     * Merged shape:     2, 4, 2, 2, 2
     *
     * \param shape Layout shape.
     * \param idxs Indexes to align descriptor.
     * \param desc Descriptor to merge.
     * \return Aligned descriptor to idx.
     */
    template <typename... ShapeDims, typename... IdxDims, typename DescriptorToMerge>
    __host__ __device__ constexpr static auto
    CreateMergedDescriptor(const Tuple<ShapeDims...>& shape,
                           [[maybe_unused]] const Tuple<IdxDims...>& idxs,
                           DescriptorToMerge& desc)
    {
        const auto transforms = generate_tuple(
            [&](auto i) {
                // Compare Idx with shape
                if constexpr(is_detected<is_tuple, tuple_element_t<i, Tuple<ShapeDims...>>>::value &&
                             !is_detected<is_tuple, tuple_element_t<i, Tuple<IdxDims...>>>::value)
                {
                    // If shape element is tuple and idx element is Number, then merge
                    // Unroll and reverse tuple to traverse column-major
                    const auto merge_elems = TupleReverse(UnrollNestedTuple(shape.At(i)));
                    if constexpr(!remove_cvref_t<UnrolledDescriptorType>::IsKnownAtCompileTime())
                    {
                        return make_merge_transform(merge_elems);
                    }
                    else
                    {
                        // If the descriptor is known at the compilation time,
                        // use `make_merge_transform_v1_carry_check` because
                        // it doesn't use memcpy.
                        return make_merge_transform_v1_carry_check(merge_elems);
                    }
                }
                else
                {
                    // If shape element is integer and idx element is tuple, passed idx is wrong
                    static_assert(
                        !(!is_detected<is_tuple, tuple_element_t<i, Tuple<ShapeDims...>>>::value &&
                          is_detected<is_tuple, tuple_element_t<i, Tuple<IdxDims...>>>::value),
                        "Wrong Idx for layout()");
                    // If shape element has the same type as idx element, then pass through
                    return make_pass_through_transform(shape.At(i));
                }
            },
            Number<Tuple<ShapeDims...>::Size()>{});

        const auto lower_dims = generate_tuple(
            [&](auto i) { return GenerateLowerDim<Number<i>>(shape); },
            Number<Tuple<ShapeDims...>::Size()>{});
        const auto upper_dims = generate_tuple(
            [&](auto i) { return Sequence<i.value>{}; }, Number<Tuple<ShapeDims...>::Size()>{});

        return transform_tensor_descriptor(desc, transforms, lower_dims, upper_dims);
    }

    using Descriptor1dType = remove_cvref_t<decltype(MakeMerge1d(Shape{}, UnrolledDescriptorType{}))>;
    using DefaultIdxsTupleType = remove_cvref_t<decltype(GenerateDefaultIdxsTuple(Shape{}))>;

    public:
    using LayoutShape                  = Shape;
    using LayoutUnrolledDescriptorType = UnrolledDescriptorType;

    /**
     * \brief Transform descriptor to align to passed indexes.
     *
     * \param shape Layout shape.
     * \param idxs Indexes to align descriptor.
     * \param naive_descriptor Descriptor to merge.
     * \return Aligned descriptor to idx.
     */
    template <typename... ShapeDims, typename... IdxDims>
    __host__ __device__ constexpr static auto
    TransformDesc(const Tuple<ShapeDims...>& shape,
                  const Tuple<IdxDims...>& idxs,
                  const UnrolledDescriptorType& naive_descriptor)
    {
        if constexpr(Tuple<IdxDims...>::Size() == I1)
        {
            // 1d idx path
            return MakeMerge1d(shape, naive_descriptor);
        }
        else
        {
            // Merge nested shape dims
            // Example idx:   (1, 1), 1, 1
            // Example shape: (2, (2, 2)), 2, (2, 2)
            // Merged shape:  (2, 4), 2, 4
            static_assert(Tuple<ShapeDims...>::Size
()
==
Tuple
<
IdxDims
...
>::
Size
(),
"Idx rank and Shape rank must be the same (except 1d)."
);
// Unroll while IdxDims is nested
const
auto
aligned_shape
=
AlignShapeToIdx
(
shape
,
idxs
);
// Transform correct form of shape
return
CreateMergedDescriptor
(
aligned_shape
,
UnrollNestedTuple
(
idxs
),
naive_descriptor
);
}
}
using
MergedNestsDescriptorType
=
remove_cvref_t
<
decltype
(
TransformDesc
(
Shape
{},
DefaultIdxsTupleType
{},
UnrolledDescriptorType
{}))
>
;
__host__
__device__
constexpr
auto
GetElementSpaceSize
()
const
{
return
unrolled_descriptor_
.
GetElementSpaceSize
();
}
__host__
__device__
Layout
()
=
delete
;
/**
* \brief Layout constructor.
*
* \param shape Shape for layout.
* \param unnested_descriptor Descriptor
*/
__host__
__device__
constexpr
Layout
(
const
Shape
&
shape
,
const
UnrolledDescriptorType
&
unnested_descriptor
)
:
unrolled_descriptor_
(
unnested_descriptor
),
shape_
(
shape
)
{
// Construct if runtime mode
if
constexpr
(
!
remove_cvref_t
<
UnrolledDescriptorType
>::
IsKnownAtCompileTime
())
{
descriptor_1d_
=
MakeMerge1d
(
shape_
,
unrolled_descriptor_
);
merged_nests_descriptor_
=
TransformDesc
(
shape_
,
DefaultIdxsTupleType
{},
unrolled_descriptor_
);
}
}
/**
* \brief Returns real offset to element in runtime.
*
* \tparam Idxs Tuple of indexes.
* \return Calculated offset.
*/
template
<
typename
Idxs
>
__host__
__device__
constexpr
index_t
operator
()()
const
{
static_assert
(
remove_cvref_t
<
UnrolledDescriptorType
>::
IsKnownAtCompileTime
(),
"Compiletime operator used on runtime layout."
);
using
TransformedDesc
=
decltype
(
TransformDesc
(
Shape
{},
Idxs
{},
UnrolledDescriptorType
{}));
using
UnrolledIdx
=
decltype
(
UnrollNestedTuple
(
Idxs
{}));
return
TransformedDesc
{}.
CalculateOffset
(
UnrolledIdx
{});
}
/**
* \brief Returns real offset to element in compile time.
*
* \param Idx Tuple of indexes.
* \return Calculated offset.
*/
template
<
typename
...
Ts
>
__host__
__device__
index_t
operator
()(
const
Tuple
<
Ts
...
>&
Idx
)
const
{
if
constexpr
(
!
IsNestedTuple
(
Tuple
<
Ts
...
>
{})
&&
Tuple
<
Ts
...
>::
Size
()
==
1
)
{
// if 1d access
return
descriptor_1d_
.
CalculateOffset
(
Idx
);
}
else
if
constexpr
(
!
IsNestedTuple
(
Tuple
<
Ts
...
>
{})
&&
Tuple
<
Ts
...
>::
Size
()
==
Shape
::
Size
())
{
// if Shape::Size() access (merged nested shapes)
return
merged_nests_descriptor_
.
CalculateOffset
(
UnrollNestedTuple
(
Idx
));
}
else
{
// Custom index, need to transform descriptor
const
auto
transformed_desc
=
TransformDesc
(
shape_
,
Idx
,
unrolled_descriptor_
);
return
transformed_desc
.
CalculateOffset
(
UnrollNestedTuple
(
Idx
));
}
}
/**
* \brief Length getter (product if tuple).
*
* \tparam IDim Tuple of indexes or index.
* \return Calculated size.
*/
template
<
index_t
IDim
>
__host__
__device__
constexpr
auto
GetLength
()
const
{
const
auto
elem
=
shape_
.
At
(
Number
<
IDim
>
{});
if
constexpr
(
is_detected
<
is_tuple
,
tuple_element_t
<
IDim
,
Shape
>>::
value
)
{
const
auto
unrolled_element
=
UnrollNestedTuple
(
elem
);
return
TupleReduce
<
I0
.
value
,
unrolled_element
.
Size
()
>
(
[](
auto
x
,
auto
y
)
{
return
x
*
y
;
},
unrolled_element
);
}
else
{
return
elem
;
}
}
/**
* \brief Layout size getter (product of shape).
*
* \return Calculated size.
*/
__host__
__device__
constexpr
auto
GetLengths
()
const
{
const
auto
unrolled_shape
=
UnrollNestedTuple
(
shape_
);
return
TupleReduce
<
I0
.
value
,
unrolled_shape
.
Size
()
>
([](
auto
x
,
auto
y
)
{
return
x
*
y
;
},
unrolled_shape
);
}
/**
* \brief Shape getter.
*
* \return Shape.
*/
__host__
__device__
constexpr
const
Shape
&
GetShape
()
const
{
return
shape_
;
}
/**
* \brief Get default lengths (tuple filled with Shape length elements).
*
* \return Default lengths.
*/
__host__
__device__
constexpr
auto
GetDefaultLengthsTuple
()
const
{
return
generate_tuple
([
&
](
auto
i
)
{
return
GetLength
<
i
>
();
},
Number
<
Shape
::
Size
()
>
{});
}
/**
* \brief Get default start idx (tuple filled with 0s of the same size as Shape).
*
* \return Default start idx.
*/
__host__
__device__
constexpr
auto
GetDefaultStartIdxs
()
const
{
return
GenerateDefaultIdxsTuple
(
shape_
);
}
/**
* \brief Get descriptor with all nested dimensions merged.
* Example, shape: ((2, 2), 2)
* Descriptor lengths: (4, 2)
*
* \note The size of merged descriptor is the same as Layout's shape.
*
* \return Merged nests descriptor.
*/
__host__
__device__
constexpr
const
MergedNestsDescriptorType
&
GetMergedNestingDescriptor
()
const
{
return
merged_nests_descriptor_
;
}
/**
* \brief Get descriptor with all dimensions are merged (1D).
* Example, shape: ((2, 2), 2)
* Descriptor lengths: (8)
*
* \return 1D descriptor.
*/
__host__
__device__
constexpr
const
Descriptor1dType
&
Get1DDescriptor
()
const
{
return
descriptor_1d_
;
}
/**
* \brief Get unnested descriptor (with unrolled dims)
* Example, shape: ((2, 2), 2)
* Descriptor lengths: (2, 2, 2)
*
* \return Flattened descriptor.
*/
__host__
__device__
constexpr
const
UnrolledDescriptorType
&
GetUnrolledDescriptor
()
const
{
return
unrolled_descriptor_
;
}
private:
// All dimensions are unrolled
UnrolledDescriptorType
unrolled_descriptor_
;
// 1D descriptor
Descriptor1dType
descriptor_1d_
;
// All nesting are merged
MergedNestsDescriptorType
merged_nests_descriptor_
;
// Example, shape: ((2, 2), 2)
// UnrolledDescriptorType lengths: (2, 2, 2)
// Descriptor1dType lengths: (8)
// MergedNestsDescriptorType lengths: (4, 2)
const
Shape
shape_
;
};
}
// namespace wrapper
}
// namespace ck
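For illustration, a minimal sketch (not part of this commit) of how a Layout built from a nested shape accepts indices at different nesting granularities. It assumes the wrapper headers above together with layout_utils.hpp are included, and that `make_layout` generates packed column-major strides as defined there; the shape and indices are placeholders.

// Sketch only: one Layout, three indexing granularities for shape ((2, 2), 2).
__host__ __device__ inline void layout_views_sketch()
{
    using namespace ck::wrapper;
    const auto nested_layout = make_layout(ck::make_tuple(ck::make_tuple(2, 2), 2));
    // 1D path: single index into the fully merged (8-element) descriptor.
    const ck::index_t off_1d = nested_layout(ck::make_tuple(5));
    // Shape::Size() path: one index per top-level dim, nested dims merged to (4, 2).
    const ck::index_t off_2d = nested_layout(ck::make_tuple(1, 1));
    // Custom path: a nested index transforms the descriptor to match its structure.
    const ck::index_t off_3d = nested_layout(ck::make_tuple(ck::make_tuple(1, 0), 1));
    (void)off_1d;
    (void)off_2d;
    (void)off_3d;
}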
include/ck/wrapper/operations/copy.hpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/wrapper/utils/tensor_utils.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_description/tensor_space_filling_curve.hpp"
namespace ck {
namespace wrapper {

/**
 * \brief Perform optimized copy between two tensor partitions (threadwise copy).
 * Tensors must have the same size.
 *
 * \tparam DimAccessOrderTuple Tuple with dimension access order.
 * \tparam VectorDim Dimension for vectorized read and write.
 * \tparam ScalarPerVector Number of scalars per vectorized read and write.
 * \param src_tensor Source tensor.
 * \param dst_tensor Destination tensor.
 */
template <typename DimAccessOrderTuple,
          index_t VectorDim,
          index_t ScalarPerVector,
          typename SrcTensorType,
          typename DstTensorType>
__device__ void copy(const SrcTensorType& src_tensor, DstTensorType& dst_tensor)
{
    static_assert(is_detected<is_tuple, DimAccessOrderTuple>::value);
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};

    const auto& in_grid_desc  = layout(src_tensor).GetUnrolledDescriptor();
    const auto& out_grid_desc = layout(dst_tensor).GetUnrolledDescriptor();

    using SrcShapeType          = remove_cvref_t<decltype(shape(src_tensor))>;
    constexpr index_t num_dims  = SrcShapeType::Size();

    constexpr auto thread_slice_lengths =
        generate_sequence_v2([](auto I) { return size(SrcShapeType{}.At(I)); }, Number<num_dims>{});
    constexpr auto dim_access_order =
        generate_sequence_v2([](auto I) { return DimAccessOrderTuple{}.At(I); }, Number<num_dims>{});

    if constexpr(SrcTensorType::IsDynamicBuffer && DstTensorType::IsDynamicBuffer)
    {
        // Perform a copy between DynamicBuffers
        auto transfer = ThreadwiseTensorSliceTransfer_v7<
            Tuple<typename SrcTensorType::TensorElementType>,
            Tuple<typename DstTensorType::TensorElementType>,
            decltype(tie(in_grid_desc)),
            decltype(tie(out_grid_desc)),
            tensor_operation::element_wise::PassThrough,
            Sequence<static_cast<index_t>(InMemoryDataOperationEnum::Set)>,
            decltype(thread_slice_lengths),
            decltype(dim_access_order),
            VectorDim,
            ScalarPerVector,
            Sequence<false>,
            Sequence<false>>{in_grid_desc,
                             make_tuple(src_tensor.GetMultiIdxOffsets()),
                             out_grid_desc,
                             make_tuple(dst_tensor.GetMultiIdxOffsets()),
                             tensor_operation::element_wise::PassThrough{}};

        transfer.Run(tie(in_grid_desc),
                     tie(src_tensor.GetBuffer()),
                     tie(out_grid_desc),
                     tie(dst_tensor.GetBuffer()));
    }
    else if constexpr(!SrcTensorType::IsDynamicBuffer && DstTensorType::IsDynamicBuffer)
    {
        // Perform copy from StaticBuffer to DynamicBuffer
        const auto src_slice_origin_idxs =
            generate_tuple([&](auto) { return I0; }, Number<num_dims>{});

        auto transfer = ThreadwiseTensorSliceTransfer_v1r3<
            typename SrcTensorType::TensorElementType,
            typename DstTensorType::TensorElementType,
            remove_cvref_t<decltype(in_grid_desc)>,
            remove_cvref_t<decltype(out_grid_desc)>,
            tensor_operation::element_wise::PassThrough,
            decltype(thread_slice_lengths),
            decltype(dim_access_order),
            VectorDim,
            ScalarPerVector,
            InMemoryDataOperationEnum::Set,
            I1,
            true>{out_grid_desc,
                  dst_tensor.GetMultiIdxOffsets(),
                  tensor_operation::element_wise::PassThrough{}};

        transfer.Run(in_grid_desc,
                     src_slice_origin_idxs,
                     src_tensor.GetBuffer(),
                     out_grid_desc,
                     dst_tensor.GetBuffer());
    }
    else if constexpr(SrcTensorType::IsDynamicBuffer && !DstTensorType::IsDynamicBuffer)
    {
        // Perform copy from DynamicBuffer to StaticBuffer
        const auto src_dst_slice_origin =
            generate_tuple([&](auto) { return I0; }, Number<num_dims>{});

        constexpr auto src_vector_tensor_lengths = generate_sequence_v2(
            [&](auto I) {
                if constexpr(I == VectorDim)
                {
                    return Number<ScalarPerVector>{};
                }
                else
                {
                    return I1;
                }
            },
            Number<num_dims>{});

        auto transfer = ThreadwiseTensorSliceTransfer_v4r1<
            typename SrcTensorType::TensorElementType,
            typename DstTensorType::TensorElementType,
            remove_cvref_t<decltype(in_grid_desc)>,
            remove_cvref_t<decltype(out_grid_desc)>,
            decltype(thread_slice_lengths),
            decltype(dim_access_order),
            decltype(src_vector_tensor_lengths),
            decltype(dim_access_order)>{src_tensor.GetMultiIdxOffsets()};

        transfer.Run(in_grid_desc,
                     src_dst_slice_origin,
                     src_tensor.GetBuffer(),
                     out_grid_desc,
                     src_dst_slice_origin,
                     dst_tensor.GetBuffer());
    }
    else
    {
        // Perform copy between StaticBuffers
        static_for<0, SrcShapeType::Size(), 1>{}([&](auto i) { dst_tensor(i) = src_tensor(i); });
    }
}

/**
 * \brief Perform generic copy between two tensor partitions (threadwise copy).
 * Tensors must have the same size.
 *
 * \param src_tensor Source tensor.
 * \param dst_tensor Destination tensor.
 */
template <typename SrcTensorType, typename DstTensorType>
__host__ __device__ void copy(const SrcTensorType& src_tensor, DstTensorType& dst_tensor)
{
    // Generate default params
    using SrcShapeType         = remove_cvref_t<decltype(shape(src_tensor))>;
    constexpr index_t num_dims = SrcShapeType::Size();
    // Incrementing dims 0, 1, 2 ... num_dims - 1
    constexpr auto dim_access_order_tuple =
        generate_tuple([](auto i) { return Number<i>{}; }, Number<num_dims>{});
    constexpr index_t vector_dim        = num_dims - 1;
    constexpr index_t scalar_per_vector = 1;
    copy<decltype(dim_access_order_tuple), vector_dim, scalar_per_vector>(src_tensor, dst_tensor);
}

/**
 * \brief Perform optimized blockwise copy between two tensors. Tensors must have the
 * same size.
 *
 * \note Currently Vgpr and Sgpr tensors are not supported.
 *
 * \tparam DimAccessOrderTuple Tuple with dimension access order.
 * \tparam VectorDim Dimension for vectorized read and write.
 * \tparam ScalarPerVector Number of scalars per vectorized read and write.
 * \param src_tensor Source tensor.
 * \param dst_tensor Destination tensor.
 * \param thread_layout Thread layout per each dimension for copy.
 */
template <typename DimAccessOrderTuple,
          index_t VectorDim,
          index_t ScalarPerVector,
          typename SrcTensorType,
          typename DstTensorType,
          typename ThreadLayoutTuple>
__device__ void blockwise_copy(const SrcTensorType& src_tensor,
                               DstTensorType& dst_tensor,
                               [[maybe_unused]] ThreadLayoutTuple& thread_layout)
{
    static_assert(SrcTensorType::IsDynamicBuffer && DstTensorType::IsDynamicBuffer);
    static_assert(is_detected<is_tuple, DimAccessOrderTuple>::value);

    const auto& in_grid_desc  = layout(src_tensor).GetUnrolledDescriptor();
    const auto& out_grid_desc = layout(dst_tensor).GetUnrolledDescriptor();

    using SrcShapeType         = remove_cvref_t<decltype(shape(src_tensor))>;
    constexpr index_t num_dims = SrcShapeType::Size();

    constexpr auto tile_lengths_seq =
        generate_sequence_v2([](auto I) { return size(SrcShapeType{}.At(I)); }, Number<num_dims>{});
    constexpr auto thread_layout_seq = generate_sequence_v2(
        [](auto I) { return size(ThreadLayoutTuple{}.At(I)); }, Number<num_dims>{});
    constexpr auto dim_access_order =
        generate_sequence_v2([](auto I) { return DimAccessOrderTuple{}.At(I); }, Number<num_dims>{});

    using ThisThreadBlock = ThisThreadBlock<size(ThreadLayoutTuple{})>;

    // Perform copy between DynamicBuffers
    auto transfer = ThreadGroupTensorSliceTransfer_v7<
        ThisThreadBlock,
        Tuple<typename SrcTensorType::TensorElementType>,
        Tuple<typename DstTensorType::TensorElementType>,
        decltype(tie(in_grid_desc)),
        decltype(tie(out_grid_desc)),
        tensor_operation::element_wise::PassThrough,
        Sequence<static_cast<index_t>(InMemoryDataOperationEnum::Set)>,
        std::remove_const_t<decltype(tile_lengths_seq)>,
        std::remove_const_t<decltype(thread_layout_seq)>,
        std::remove_const_t<decltype(dim_access_order)>,
        std::remove_const_t<decltype(dim_access_order)>,
        VectorDim,
        ScalarPerVector,
        Sequence<true>,
        Sequence<true>>{in_grid_desc,
                        make_tuple(src_tensor.GetMultiIdxOffsets()),
                        out_grid_desc,
                        make_tuple(dst_tensor.GetMultiIdxOffsets()),
                        tensor_operation::element_wise::PassThrough{}};

    transfer.Run(tie(in_grid_desc),
                 tie(src_tensor.GetBuffer()),
                 tie(out_grid_desc),
                 tie(dst_tensor.GetBuffer()));
}

} // namespace wrapper
} // namespace ck
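For illustration, a hedged sketch (not part of this commit) of the typical call pattern: a thread block stages a tile into LDS with `blockwise_copy` and then synchronizes. The tile tensors and the thread layout are assumed to be created by the surrounding kernel; the access order, vector dimension and vector width below are arbitrary example choices.

// Sketch only: cooperative staging of a 2D tile into LDS.
template <typename GlobalTileType, typename LdsTileType, typename ThreadLayoutType>
__device__ void stage_tile_to_lds(const GlobalTileType& global_tile,
                                  LdsTileType& lds_tile,
                                  ThreadLayoutType& thread_layout)
{
    // Traverse dim 0 then dim 1, vectorize along dim 1 with 4 scalars per access
    // (illustrative choices; they must divide the tile lengths).
    using DimAccessOrder                    = ck::Tuple<ck::Number<0>, ck::Number<1>>;
    constexpr ck::index_t vector_dim        = 1;
    constexpr ck::index_t scalar_per_vector = 4;
    ck::wrapper::blockwise_copy<DimAccessOrder, vector_dim, scalar_per_vector>(
        global_tile, lds_tile, thread_layout);
    // Make the staged tile visible to the whole block before it is read back.
    ck::block_sync_lds();
}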
include/ck/wrapper/operations/gemm.hpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/wrapper/utils/tensor_utils.hpp"
#include "ck/wrapper/traits/blockwise_gemm_xdl_traits.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
namespace ck {
namespace wrapper {

namespace {
namespace detail {
/**
 * \brief Create block descriptor (K0, MPerBlock or NPerBlock, K1).
 *
 * \tparam K1 The number of K-dim elements that are packed together as a separate logical dimension.
 * \tparam TileLayout Tensor data tile layout (M,K) or (N,K).
 *
 * \return Block descriptor (K0, MPerBlock or NPerBlock, K1)
 */
template <index_t K1, typename TileLayout>
__device__ constexpr auto GetBlockDescriptor()
{
    using TileLayoutShape      = typename TileLayout::LayoutShape;
    using TileLayoutDescriptor = typename TileLayout::LayoutUnrolledDescriptorType;

    constexpr auto K0PerBlock = Number<size<1>(TileLayoutShape{})>{} / Number<K1>{};
    // MPerBlock or NPerBlock
    constexpr auto Dim0 = Number<size<0>(TileLayoutShape{})>{};

    constexpr auto a_block_desc_k0_m_k1 = transform_tensor_descriptor(
        TileLayoutDescriptor{},
        make_tuple(make_unmerge_transform(make_tuple(K0PerBlock, Number<K1>{})),
                   make_pass_through_transform(Dim0)),
        make_tuple(Sequence<1>{}, Sequence<0>{}),
        make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
    return a_block_desc_k0_m_k1;
}
} // namespace detail
} // namespace

/**
 * \brief Perform blockwise gemm xdl on tensors stored in lds. Result will be
 * stored in Vgpr register. A data layout must be (MPerBlock, KPerBlock) and B
 * data layout must be (NPerBlock, KPerBlock).
 *
 * \note C output Vgpr register layout (8D):
 * - MXdlPerWave - The number of MFMA instructions run by single wave in M
 * dimension per tile.
 * - NXdlPerWave - The number of MFMA instructions run by single wave in N
 * dimension per tile.
 * - MWave - Equals to 1 since this is for single wave.
 * - NWave - Equals to 1 since this is for single wave.
 * - NumGroupsPerBlock - Mfma instruction internal layout (depends on the
 * instruction size).
 * - NumInputsBlock - Mfma instruction internal layout (depends on the
 * instruction size).
 * - GroupSize - Mfma instruction internal layout (depends on the
 * instruction size).
 * - NumThreadsPerBlock - Mfma instruction internal layout (depends on the
 * instruction size).
 *
 * \tparam DataType Input data types.
 * \tparam BlockSize Number of threads in block.
 * \tparam GemmTraits Traits of gemm xdl operation.
 * \param a_local_tile_tensor A tensor in LDS memory for blockwise gemm
 * (MPerBlock, KPerBlock) layout.
 * \param b_local_tile_tensor B tensor in LDS memory for blockwise gemm
 * (NPerBlock, KPerBlock) layout.
 * \param c_reg_tensor C tensor VGPR memory for blockwise gemm.
 */
template <typename DataType,
          index_t BlockSize,
          typename GemmTraits,
          typename ATensorType,
          typename BTensorType,
          typename CTensorType>
__device__ void blockwise_gemm_xdl(const ATensorType& a_local_tile_tensor,
                                   const BTensorType& b_local_tile_tensor,
                                   CTensorType& c_reg_tensor)
{
    static_assert(ATensorType::TensorBufferAddressSpace == MemoryTypeEnum::Lds);
    static_assert(BTensorType::TensorBufferAddressSpace == MemoryTypeEnum::Lds);
    static_assert(CTensorType::TensorBufferAddressSpace == MemoryTypeEnum::Vgpr);
    static_assert(is_same_v<DataType, typename ATensorType::TensorElementType>);
    static_assert(is_same_v<DataType, typename BTensorType::TensorElementType>);

    constexpr bool is_integer = is_same_v<DataType, int8_t> || is_same_v<DataType, int16_t> ||
                                is_same_v<DataType, int32_t>;
    using GemmAccDataType = std::conditional_t<is_integer, int32_t, float>;

    using ATileLayout = remove_cvref_t<decltype(layout(a_local_tile_tensor))>;
    using BTileLayout = remove_cvref_t<decltype(layout(b_local_tile_tensor))>;

    using ABlockDesc_K0_M_K1_Type =
        decltype(detail::GetBlockDescriptor<GemmTraits::K1, ATileLayout>());
    using BBlockDesc_K0_N_K1_Type =
        decltype(detail::GetBlockDescriptor<GemmTraits::K1, BTileLayout>());

    BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
                                                        DataType,
                                                        DataType,
                                                        GemmAccDataType,
                                                        ABlockDesc_K0_M_K1_Type,
                                                        BBlockDesc_K0_N_K1_Type,
                                                        GemmTraits::MPerXDL,
                                                        GemmTraits::NPerXDL,
                                                        GemmTraits::MXdlPerWave,
                                                        GemmTraits::NXdlPerWave,
                                                        GemmTraits::K1>
        blockwise_gemm_xdl_op{};

    blockwise_gemm_xdl_op.Run(
        a_local_tile_tensor.GetBuffer(), b_local_tile_tensor.GetBuffer(), c_reg_tensor.GetBuffer());
}

/**
 * \brief Create local partition per thread for C tensor.
 *
 * \note C output global memory layout (8D):
 * - MXdlPerWave - The number of MFMA instructions run by single wave in M
 * dimension.
 * - NXdlPerWave - The number of MFMA instructions run by single wave in N
 * dimension.
 * - MWave - The number of waves in single tile M dimension per tile.
 * - NWave - The number of waves in single tile N dimension per tile.
 * - NumGroupsPerBlock - Mfma instruction internal layout (depends on the
 * instruction size).
 * - NumInputsBlock - Mfma instruction internal layout (depends on the
 * instruction size).
 * - GroupSize - Mfma instruction internal layout (depends on the
 * instruction size).
 * - NumThreadsPerBlock - Mfma instruction internal layout (depends on the
 * instruction size).
 *
 * \tparam DataType Input data types.
 * \tparam ATileLayout A tensor layout.
 * \tparam BTileLayout B tensor layout.
 * \tparam BlockSize Number of threads in block.
 * \tparam GemmTraits Traits of gemm xdl operation.
 * \param c_local_tile_tensor C tensor in LDS memory for blockwise gemm
 * (MPerBlock, NPerBlock) layout.
 *
 * \return Partition c tensor for blockwise gemm.
 */
template <typename DataType,
          typename ATileLayout,
          typename BTileLayout,
          index_t BlockSize,
          typename GemmTraits,
          typename CTensorType>
__host__ __device__ constexpr auto
make_blockwise_gemm_xdl_c_local_partition(CTensorType& c_local_tile_tensor)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};
    constexpr auto I4 = Number<4>{};
    constexpr auto I5 = Number<5>{};
    constexpr auto I6 = Number<6>{};
    constexpr auto I7 = Number<7>{};

    constexpr bool is_integer = is_same_v<DataType, int8_t> || is_same_v<DataType, int16_t> ||
                                is_same_v<DataType, int32_t>;
    using GemmAccDataType = std::conditional_t<is_integer, int32_t, float>;

    using ABlockDesc_K0_M_K1_Type =
        decltype(detail::GetBlockDescriptor<GemmTraits::K1, ATileLayout>());
    using BBlockDesc_K0_N_K1_Type =
        decltype(detail::GetBlockDescriptor<GemmTraits::K1, BTileLayout>());

    using BlockwiseGemmXdlops =
        BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
                                                            DataType,
                                                            DataType,
                                                            GemmAccDataType,
                                                            ABlockDesc_K0_M_K1_Type,
                                                            BBlockDesc_K0_N_K1_Type,
                                                            GemmTraits::MPerXDL,
                                                            GemmTraits::NPerXDL,
                                                            GemmTraits::MXdlPerWave,
                                                            GemmTraits::NXdlPerWave,
                                                            GemmTraits::K1>;

    constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
        BlockwiseGemmXdlops::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

    constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I0);
    constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I1);
    constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I2);
    constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I3);
    constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I4);
    constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I5);
    constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I6);
    constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I7);

    // Calculate offset on grid
    const auto c_thread_mtx_on_block =
        BlockwiseGemmXdlops::CalculateCThreadOriginDataIndex(I0, I0, I0, I0);

    const index_t m_thread_data_on_grid =
        c_local_tile_tensor.GetMultiIdxOffsets()[I0] + c_thread_mtx_on_block[I0];
    const index_t n_thread_data_on_grid =
        c_local_tile_tensor.GetMultiIdxOffsets()[I1] + c_thread_mtx_on_block[I1];

    const auto m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor = make_single_stage_tensor_adaptor(
        make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
        make_tuple(Sequence<0, 1, 2, 3, 4>{}),
        make_tuple(Sequence<0>{}));
    const auto m_thread_data_on_grid_idx =
        m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
            make_multi_index(m_thread_data_on_grid));

    const auto n_thread_data_on_grid_to_n0_n1_n2_adaptor = make_single_stage_tensor_adaptor(
        make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
        make_tuple(Sequence<0, 1, 2>{}),
        make_tuple(Sequence<0>{}));
    const auto n_thread_data_on_grid_idx =
        n_thread_data_on_grid_to_n0_n1_n2_adaptor.CalculateBottomIndex(
            make_multi_index(n_thread_data_on_grid));

    // Create partition shape based on descriptor dims.
    const auto partition_shape = make_tuple(M0, N0, I1, I1, M2, I1, M4, I1);
    const auto partition_desc  = BlockwiseGemmXdlops::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(
        layout(c_local_tile_tensor).GetUnrolledDescriptor());
    const auto partition_layout =
        Layout<remove_reference_t<decltype(partition_shape)>, decltype(partition_desc)>(
            partition_shape, partition_desc);

    auto partition_tensor = make_tensor<CTensorType::TensorBufferAddressSpace>(
        c_local_tile_tensor.GetPointer(), partition_layout);
    partition_tensor.SetMultiIdxOffset(make_multi_index(m_thread_data_on_grid_idx[I0],
                                                        n_thread_data_on_grid_idx[I0],
                                                        m_thread_data_on_grid_idx[I1],
                                                        n_thread_data_on_grid_idx[I1],
                                                        m_thread_data_on_grid_idx[I2],
                                                        m_thread_data_on_grid_idx[I3],
                                                        m_thread_data_on_grid_idx[I4],
                                                        n_thread_data_on_grid_idx[I2]));
    return partition_tensor;
}

/**
 * \brief Create Vgpr C tensor for blockwise gemm.
 *
 * \note C output Vgpr register layout (8D):
 * - MXdlPerWave - The number of MFMA instructions run by single wave in M
 * dimension per tile.
 * - NXdlPerWave - The number of MFMA instructions run by single wave in N
 * dimension per tile.
 * - MWave - Equals to 1 since this is for single wave.
 * - NWave - Equals to 1 since this is for single wave.
 * - NumGroupsPerBlock - Mfma instruction internal layout (depends on the
 * instruction size).
 * - NumInputsBlock - Mfma instruction internal layout (depends on the
 * instruction size).
 * - GroupSize - Mfma instruction internal layout (depends on the
 * instruction size).
 * - NumThreadsPerBlock - Mfma instruction internal layout (depends on the
 * instruction size).
 *
 * \tparam DataType Input data types.
 * \tparam ATileLayout A tensor layout.
 * \tparam BTileLayout B tensor layout.
 * \tparam BlockSize Number of threads in block.
 * \tparam GemmTraits Traits of gemm xdl operation.
 *
 * \return Vgpr c tensor for blockwise gemm.
 */
template <typename DataType,
          typename ATileLayout,
          typename BTileLayout,
          index_t BlockSize,
          typename GemmTraits>
__host__ __device__ constexpr auto make_blockwise_gemm_xdl_c_vgpr()
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};
    constexpr auto I4 = Number<4>{};
    constexpr auto I5 = Number<5>{};
    constexpr auto I6 = Number<6>{};
    constexpr auto I7 = Number<7>{};

    constexpr bool is_integer = is_same_v<DataType, int8_t> || is_same_v<DataType, int16_t> ||
                                is_same_v<DataType, int32_t>;
    using GemmAccDataType = std::conditional_t<is_integer, int32_t, float>;

    using ABlockDesc_K0_M_K1_Type =
        decltype(detail::GetBlockDescriptor<GemmTraits::K1, ATileLayout>());
    using BBlockDesc_K0_N_K1_Type =
        decltype(detail::GetBlockDescriptor<GemmTraits::K1, BTileLayout>());

    using BlockwiseGemmXdlops =
        BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
                                                            DataType,
                                                            DataType,
                                                            GemmAccDataType,
                                                            ABlockDesc_K0_M_K1_Type,
                                                            BBlockDesc_K0_N_K1_Type,
                                                            GemmTraits::MPerXDL,
                                                            GemmTraits::NPerXDL,
                                                            GemmTraits::MXdlPerWave,
                                                            GemmTraits::NXdlPerWave,
                                                            GemmTraits::K1>;

    // Calculate descriptor, shape and layout
    constexpr auto vgpr_desc = BlockwiseGemmXdlops::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
    const auto vgpr_shape    = make_tuple(vgpr_desc.GetLengths()[I0],
                                          vgpr_desc.GetLengths()[I1],
                                          vgpr_desc.GetLengths()[I2],
                                          vgpr_desc.GetLengths()[I3],
                                          vgpr_desc.GetLengths()[I4],
                                          vgpr_desc.GetLengths()[I5],
                                          vgpr_desc.GetLengths()[I6],
                                          vgpr_desc.GetLengths()[I7]);
    const auto vgpr_layout =
        Layout<remove_reference_t<decltype(vgpr_shape)>, decltype(vgpr_desc)>(vgpr_shape, vgpr_desc);

    // Get vector type for Vgpr
    using BlockwiseGemmCThreadBufferType =
        remove_reference_t<decltype(BlockwiseGemmXdlops{}.GetCThreadBuffer())>;
    using VgprVectorType = typename BlockwiseGemmCThreadBufferType::V;

    return ck::wrapper::make_register_tensor<ck::wrapper::MemoryTypeEnum::Vgpr, VgprVectorType>(
        vgpr_layout);
}

} // namespace wrapper
} // namespace ck
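For illustration, a hedged sketch (not part of this commit) of how the three entry points above are typically chained in a block-level GEMM: accumulate into VGPRs, then take the per-thread view of the C tile and write it out with `copy`. The tile tensors, the block size and the chosen traits struct are assumptions of the sketch, as is the use of the generic `copy` overload for the write-out.

// Sketch only: one K-tile accumulation followed by the C write-out.
template <typename ADataType, typename ATileType, typename BTileType, typename CTileType>
__device__ void blockwise_gemm_sketch(const ATileType& a_lds_tile,
                                      const BTileType& b_lds_tile,
                                      CTileType& c_global_local_tile)
{
    using namespace ck::wrapper;
    constexpr ck::index_t block_size = 256; // illustrative block size
    using Traits  = BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_8K1; // illustrative traits
    using ALayout = ck::remove_cvref_t<decltype(layout(a_lds_tile))>;
    using BLayout = ck::remove_cvref_t<decltype(layout(b_lds_tile))>;

    // Accumulator registers shaped to match the MFMA output layout.
    auto c_vgpr = make_blockwise_gemm_xdl_c_vgpr<ADataType, ALayout, BLayout, block_size, Traits>();
    // Accumulate A * B for this K tile.
    blockwise_gemm_xdl<ADataType, block_size, Traits>(a_lds_tile, b_lds_tile, c_vgpr);
    // Per-thread view of the C tile, then copy the accumulated results out.
    auto c_partition =
        make_blockwise_gemm_xdl_c_local_partition<ADataType, ALayout, BLayout, block_size, Traits>(
            c_global_local_tile);
    copy(c_vgpr, c_partition);
}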
include/ck/wrapper/tensor.hpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "utils/tensor_utils.hpp"
#include "utils/tensor_partition.hpp"
#include "utils/layout_utils.hpp"
namespace ck {
namespace wrapper {

namespace {
namespace detail {
/**
 * \brief Check if Tuple contains Slice object.
 *
 * \return True if tuple contains Slice object.
 */
template <typename T>
__host__ __device__ constexpr bool HasSlice(T&&)
{
    return is_detected<is_slice, T>::value;
}
template <typename... Ts>
__host__ __device__ constexpr bool HasSlice(Tuple<Ts...>&&)
{
    return (HasSlice(Ts{}) || ...);
}

/**
 * \brief Calculate new shape after slice from parent shape.
 *
 * \param idxs Tuple of indexes defining slice ranges.
 * \param shape Shape which will be sliced.
 * \return New tensor shape.
 */
template <typename... Ts, typename SlicedShape>
__host__ __device__ constexpr auto GetSlicedShape(const Tuple<Ts...>& idxs,
                                                  const SlicedShape& shape)
{
    // Pack each value in tuple to remove empty tuples after generation
    auto new_shape = generate_tuple(
        [&](auto i) {
            constexpr auto num_i = Number<i>{};
            if constexpr(is_detected<is_tuple, tuple_element_t<i.value, Tuple<Ts...>>>::value)
            {
                if constexpr(!detail::HasSlice(tuple_element_t<i.value, Tuple<Ts...>>{}))
                {
                    // if tuple does not have any slice then we can remove dimension
                    return Tuple<>{};
                }
                else
                {
                    // if tuple then recurrence
                    return make_tuple(GetSlicedShape(idxs.At(num_i), shape.At(num_i)));
                }
            }
            else if constexpr(is_detected<is_slice, tuple_element_t<i.value, Tuple<Ts...>>>::value)
            {
                // calculate new dimension
                const auto& dim = size(shape.At(num_i));
                const auto val  = idxs.At(num_i).range(dim);
                return make_tuple(val);
            }
            else
            {
                // remove dimension for just value
                return Tuple<>{};
            }
        },
        Number<Tuple<Ts...>::Size()>{});
    // Remove empty tuples (deleted elements) and return
    return UnrollNestedTuple<0, 1>(new_shape);
}

/**
 * \brief Generate Freeze for each of nested shape.
 *
 * \param idx Tuple of start indices for slice.
 * \param shape Shape which will be frozen.
 * \return Generated freeze transforms.
 */
template <typename T, typename Shape>
__host__ __device__ constexpr auto GenerateMultipleFreeze(T idx, const Shape& shape)
{
    const auto unrolled_shape = UnrollNestedTuple(shape);
    return generate_tuple(
        [&](auto i) {
            // dimension offset from idx
            const auto dim     = unrolled_shape.At(Number<i>{});
            const auto dim_idx = idx % dim;
            idx /= dim;
            return make_freeze_transform(dim_idx);
        },
        Number<decltype(unrolled_shape)::Size()>{});
}

/**
 * \brief Generate transforms for slice tensor.
 *
 * \param idx Tuple of start indices for slice.
 * \param shape Shape which will be sliced.
 * \return Generated transforms.
 */
template <typename... Ts, typename Shape>
__host__ __device__ constexpr auto GenerateSliceTransforms(const Tuple<Ts...>& idx,
                                                           const Shape& shape)
{
    // Pack each value in tuple to remove empty tuples after generation
    auto transforms = generate_tuple(
        [&](auto i) {
            constexpr auto num_i = Number<i>{};
            if constexpr(is_detected<is_tuple, tuple_element_t<i.value, Tuple<Ts...>>>::value)
            {
                return GenerateSliceTransforms(idx.At(num_i), shape.At(num_i));
            }
            else if constexpr(is_detected<is_slice, tuple_element_t<i.value, Tuple<Ts...>>>::value)
            {
                const auto from  = idx.At(num_i).from_;
                const auto dim   = size<num_i>(shape);
                const auto range = idx.At(num_i).range(dim);
                return make_slice_transform(range, from, from + range);
            }
            else
            {
                // remove dimension for just value
                return GenerateMultipleFreeze(idx.At(num_i), shape.At(num_i));
            }
        },
        Number<Tuple<Ts...>::Size()>{});
    // Remove empty tuples (deleted elements) and return
    return UnrollNestedTuple(transforms);
}

template <index_t i, typename LowerIndex>
__host__ __device__ constexpr auto GetSequenceVal(const ck::Freeze<LowerIndex>&)
{
    // There is no output for Freeze transform
    return Sequence<>{};
}

template <index_t i, typename LowLength, typename SliceBegin, typename SliceEnd>
__host__ __device__ constexpr auto GetSequenceVal(const ck::Slice<LowLength, SliceBegin, SliceEnd>&)
{
    return Sequence<i>{};
}

template <index_t i>
__host__ __device__ constexpr auto GenerateUpperDims(const Tuple<>&)
{
    return Tuple<>{};
}

template <index_t i, typename... Transforms>
__host__ __device__ constexpr auto GenerateUpperDims(const Tuple<Transforms...>& transforms)
{
    constexpr auto num_transforms = Tuple<Transforms...>::Size();
    // Deduce Sequence element for specific transform
    const auto current_elem = GetSequenceVal<i>(transforms.At(Number<0>{}));
    if constexpr(is_same_v<decltype(current_elem), const Sequence<>>)
    {
        const auto next_tuple = GenerateUpperDims<i>(TupleSlice<1, num_transforms>(transforms));
        return concat_tuple(make_tuple(current_elem), next_tuple);
    }
    else
    {
        // Increase i if current_elem is Slice transform
        const auto next_tuple = GenerateUpperDims<i + 1>(TupleSlice<1, num_transforms>(transforms));
        return concat_tuple(make_tuple(current_elem), next_tuple);
    }
}

template <typename... Ts, typename Shape, typename FlattenDescriptor>
__host__ __device__ constexpr auto GenerateSlicedDescriptor(const Tuple<Ts...>& idx,
                                                            const Shape& shape,
                                                            const FlattenDescriptor& flatten_desc)
{
    constexpr auto old_shape_dims = decltype(UnrollNestedTuple(shape))::Size();

    const auto transforms     = GenerateSliceTransforms(idx, shape);
    using TransformsTupleType = decltype(transforms);

    const auto lower_dims =
        generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<old_shape_dims>{});
    const auto upper_dims = decltype(GenerateUpperDims<0>(TransformsTupleType{})){};
    return transform_tensor_descriptor(flatten_desc, transforms, lower_dims, upper_dims);
}
} // namespace detail
} // namespace

/**
 * \brief Tensor wrapper that performs static and dynamic buffer logic.
 * The tensor is based on a descriptor stored in the Layout. Additionally,
 * tensor can be sliced or shifted using multi-index offset.
 *
 * \tparam BufferAddressSpace Memory type (Generic, Global, LDS, VGPR, SGPR).
 * \tparam ElementType Element data type.
 * \tparam Shape Tensor shape (layout component).
 * \tparam UnrolledDescriptorType Flatten descriptor (layout component).
 */
template <MemoryTypeEnum BufferAddressSpace,
          typename ElementType,
          typename Shape,
          typename UnrolledDescriptorType>
struct Tensor
{
    public:
    using ElementSpaceSize = decltype(Layout<Shape, UnrolledDescriptorType>{
        Shape{}, UnrolledDescriptorType{}}.GetElementSpaceSize()); // SpaceSize type for buffer
    using TensorElementType =
        std::conditional_t<is_scalar_type<ElementType>::value,
                           ElementType,
                           typename scalar_type<std::remove_const_t<ElementType>>::type>; // DataType

    static constexpr MemoryTypeEnum TensorBufferAddressSpace = BufferAddressSpace;
    static constexpr bool IsDynamicBuffer = !(BufferAddressSpace == MemoryTypeEnum::Sgpr ||
                                              BufferAddressSpace == MemoryTypeEnum::Vgpr);

    __host__ __device__ Tensor() = delete;

    __host__ __device__ constexpr Tensor(ElementType* pointer,
                                         const Layout<Shape, UnrolledDescriptorType>& layout)
        : layout_(layout),
          buffer_(make_dynamic_buffer<BufferAddressSpace>(pointer, layout.GetElementSpaceSize())),
          multi_idx_offset_(make_zero_multi_index<Shape::Size()>()),
          base_offset_(0)
    {
        static_assert(IsDynamicBuffer, "Wrong BufferAddressSpace for register.");
    }

    __host__ __device__ constexpr Tensor(const Layout<Shape, UnrolledDescriptorType>& layout)
        : layout_(layout),
          multi_idx_offset_(make_zero_multi_index<Shape::Size()>()),
          base_offset_(0)
    {
        static_assert(!IsDynamicBuffer, "Wrong BufferAddressSpace for register.");
    }

    __host__ __device__ constexpr const Layout<Shape, UnrolledDescriptorType>& GetLayout() const
    {
        return layout_;
    }

    /**
     * \brief Get the new sliced tensor.
     *
     * \param idx Tuple of indices: slice(from,to) or scalar.
     * \return Sliced tensor.
     */
    template <typename... Ts, enable_if_t<detail::HasSlice(Tuple<Ts...>{}), bool> = false>
    __host__ __device__ auto operator[](const Tuple<Ts...>& idx)
    {
        static_assert(IsDynamicBuffer, "Register slice is not supported");
        const auto& shape = layout_.GetShape();
        auto new_shape    = detail::GetSlicedShape(idx, shape);

        const auto& flatten_desc = layout_.GetUnrolledDescriptor();
        auto new_desc            = detail::GenerateSlicedDescriptor(idx, shape, flatten_desc);
        const auto new_layout = Layout<decltype(new_shape), decltype(new_desc)>(new_shape, new_desc);
        // Update embed offset
        base_offset_ -= new_layout(make_tuple(Number<0>{}));
        return make_tensor<BufferAddressSpace>(buffer_.p_data_, new_layout);
    }

    template <typename... Ts, enable_if_t<detail::HasSlice(Tuple<Ts...>{}), bool> = false>
    __host__ __device__ auto operator()(const Tuple<Ts...>& idx)
    {
        return this->operator[](idx);
    }

    template <typename... Idxs, enable_if_t<detail::HasSlice(Tuple<Idxs...>{}), bool> = false>
    __host__ __device__ auto operator()(Idxs... idxs)
    {
        return this->operator[](make_tuple(idxs...));
    }

    /**
     * \brief Getter of the tensor's const value reference.
     *
     * \param idx Tuple of indices.
     * \return Requested value.
     */
    template <typename... Ts, enable_if_t<!detail::HasSlice(Tuple<Ts...>{}), bool> = false>
    __host__ __device__ const TensorElementType& operator[](const Tuple<Ts...>& idx) const
    {
        if constexpr(IsDynamicBuffer)
        {
            const index_t offset = layout_(idx) + base_offset_;
            return buffer_[offset];
        }
        else
        {
            constexpr index_t index_offset =
                Layout<Shape, UnrolledDescriptorType>{Shape{}, UnrolledDescriptorType{}}
                    .template operator()<Tuple<Ts...>>();
            // Calculate and apply base offset in compile time
            constexpr index_t base_offset =
                Layout<Shape, UnrolledDescriptorType>{Shape{}, UnrolledDescriptorType{}}
                    .template operator()<MultiIndex<Shape::Size()>>();
            return buffer_[Number<index_offset + base_offset>{}];
        }
    }

    template <typename... Ts, enable_if_t<!detail::HasSlice(Tuple<Ts...>{}), bool> = false>
    __host__ __device__ const TensorElementType& operator()(const Tuple<Ts...>& idx) const
    {
        return this->operator[](idx);
    }

    template <typename... Idxs, enable_if_t<!detail::HasSlice(Tuple<Idxs...>{}), bool> = false>
    __host__ __device__ const TensorElementType& operator()(Idxs... idxs) const
    {
        return this->operator[](make_tuple(idxs...));
    }

    /**
     * \brief Getter of tensor value reference.
     *
     * \param idx Tuple of indices.
     * \return Requested value.
     */
    template <typename... Ts, enable_if_t<!detail::HasSlice(Tuple<Ts...>{}), bool> = false>
    __host__ __device__ TensorElementType& operator[](const Tuple<Ts...>& idx)
    {
        if constexpr(IsDynamicBuffer)
        {
            const index_t offset = layout_(idx) + base_offset_;
            return buffer_(offset);
        }
        else
        {
            constexpr index_t index_offset =
                Layout<Shape, UnrolledDescriptorType>{Shape{}, UnrolledDescriptorType{}}
                    .template operator()<Tuple<Ts...>>();
            // Apply embed offset (calculated in compile time)
            constexpr index_t base_offset =
                Layout<Shape, UnrolledDescriptorType>{Shape{}, UnrolledDescriptorType{}}
                    .template operator()<MultiIndex<Shape::Size()>>();
            return buffer_(Number<index_offset + base_offset>{});
        }
    }

    template <typename... Ts, enable_if_t<!detail::HasSlice(Tuple<Ts...>{}), bool> = false>
    __host__ __device__ TensorElementType& operator()(const Tuple<Ts...>& idx)
    {
        return this->operator[](idx);
    }

    template <typename... Idxs, enable_if_t<!detail::HasSlice(Tuple<Idxs...>{}), bool> = false>
    __host__ __device__ TensorElementType& operator()(Idxs... idxs)
    {
        return this->operator[](make_tuple(idxs...));
    }

    /**
     * \brief Get descriptor with all nested dimensions merged.
     *
     * \return Merged nests descriptor.
     */
    __host__ __device__ constexpr auto GetMergedNestingDescriptor()
    {
        return layout_.GetMergedNestingDescriptor();
    }

    /**
     * \brief Get pointer to the data.
     *
     * \return Pointer.
     */
    __host__ __device__ TensorElementType* GetPointer() const { return buffer_.p_data_; }

    __host__ __device__ constexpr auto& GetBuffer() { return buffer_; }
    __host__ __device__ constexpr auto& GetBuffer() const { return buffer_; }

    /**
     * \brief Get multi index offset to the data.
     *
     * \return Multi index offset.
     */
    __host__ __device__ constexpr auto& GetMultiIdxOffsets() const { return multi_idx_offset_; }

    /**
     * \brief Apply multi index offset on the tensor.
     *
     * \param multi_idx_offset Multi index offset.
     */
    template <typename MultiIdxOffsets>
    __host__ __device__ constexpr void SetMultiIdxOffset(const MultiIdxOffsets multi_idx_offset)
    {
        multi_idx_offset_ = multi_idx_offset;
        base_offset_ += layout_(multi_idx_offset);
    }

    private:
    using DynamicBufferType = DynamicBuffer<BufferAddressSpace,
                                            ElementType,
                                            ElementSpaceSize,
                                            true /*InvalidElementUseNumericalZeroValue*/>;
    using StaticBufferType =
        std::conditional_t<is_scalar_type<ElementType>::value,
                           StaticBuffer<BufferAddressSpace,
                                        ElementType,
                                        size(Shape{}),
                                        true /*InvalidElementUseNumericalZeroValue*/>,
                           StaticBufferTupleOfVector<
                               BufferAddressSpace,
                               TensorElementType,
                               size(Shape{}) /
                                   scalar_type<std::remove_const_t<ElementType>>::vector_size,
                               scalar_type<std::remove_const_t<ElementType>>::vector_size,
                               true /*InvalidElementUseNumericalZeroValue*/>>;
    // If register use static buffer, else use dynamic buffer
    using Buffer = std::conditional_t<IsDynamicBuffer, DynamicBufferType, StaticBufferType>;

    const Layout<Shape, UnrolledDescriptorType> layout_;
    Buffer buffer_;
    // We use multi_idx_offset_ to enable the creation of a descriptor in
    // compile time for partitions or tiles if tile shape and thread layout
    // is known at compile time (we can use the same descriptor for each
    // thread). Additionally, the copy between the static and dynamic buffer
    // requires a descriptor known at compile time, so we can shift data using
    // such multi_idx_offset_.
    MultiIndex<Shape::Size()> multi_idx_offset_;
    // Base offset and multi index offset correspond to exactly the same
    // element in the tensor (and in physical memory). The multi index offset
    // is a multi-dimensional index, while the base offset is calculated using
    // the tensor descriptor (thus all its transforms) and is linear (1D).
    // We store base_offset_ to avoid multiple recalculations.
    index_t base_offset_;
};

} // namespace wrapper
} // namespace ck
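For illustration, a minimal sketch (not part of this commit) of wrapping a raw global pointer in a Tensor and accessing elements through the multi-index `operator()`. The shape and the pointer are placeholders, and `MemoryTypeEnum::Global` is used as suggested by the struct documentation above.

// Sketch only: a (16, 32) global-memory tensor with packed column-major strides.
template <typename DataType>
__device__ void tensor_access_sketch(DataType* p_data)
{
    using namespace ck::wrapper;
    const auto tensor_layout = make_layout(ck::make_tuple(16, 32));
    auto tensor              = make_tensor<MemoryTypeEnum::Global>(p_data, tensor_layout);
    // operator() accepts either a tuple of indices or a plain index list.
    tensor(ck::make_tuple(3, 4)) = DataType{1};
    const DataType value         = tensor(3, 4);
    (void)value;
}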
include/ck/wrapper/traits/blockwise_gemm_xdl_traits.hpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/ck.hpp"
namespace ck {
namespace wrapper {

/**
 * \brief Traits for blockwise gemm xdl.
 *
 * \tparam MPerXDLValue The MFMA instruction size in M dimension.
 * \tparam NPerXDLValue The MFMA instruction size in N dimension.
 * \tparam MXdlPerWaveValue The number of MFMA instructions run by single
 * wave in M dimension.
 * \tparam NXdlPerWaveValue The number of MFMA instructions run by single
 * wave in N dimension.
 * \tparam K1Value The number of K-dim elements that are packed together as
 * a separate logical dimension. Usually aligns with vector load size.
 */
template <index_t MPerXDLValue,
          index_t NPerXDLValue,
          index_t MXdlPerWaveValue,
          index_t NXdlPerWaveValue,
          index_t K1Value>
struct BlockwisGemmXdlTraits
{
    static constexpr index_t MPerXDL     = MPerXDLValue;
    static constexpr index_t NPerXDL     = NPerXDLValue;
    static constexpr index_t MXdlPerWave = MXdlPerWaveValue;
    static constexpr index_t NXdlPerWave = NXdlPerWaveValue;
    static constexpr index_t K1          = K1Value;
};

// K1 = 4
struct BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_4K1 : BlockwisGemmXdlTraits<32, 32, 4, 2, 4>
{
};
struct BlockwisGemmXdlTraits_32x32Xdl_2x4XdlPerWave_4K1 : BlockwisGemmXdlTraits<32, 32, 2, 4, 4>
{
};
struct BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_4K1 : BlockwisGemmXdlTraits<32, 32, 2, 2, 4>
{
};
// K1 = 8
struct BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_8K1 : BlockwisGemmXdlTraits<32, 32, 4, 2, 8>
{
};
struct BlockwisGemmXdlTraits_32x32Xdl_2x4XdlPerWave_8K1 : BlockwisGemmXdlTraits<32, 32, 2, 4, 8>
{
};
struct BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_8K1 : BlockwisGemmXdlTraits<32, 32, 2, 2, 8>
{
};
// K1 = 16
struct BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_16K1 : BlockwisGemmXdlTraits<32, 32, 4, 2, 16>
{
};
struct BlockwisGemmXdlTraits_32x32Xdl_2x4XdlPerWave_16K1 : BlockwisGemmXdlTraits<32, 32, 2, 4, 16>
{
};
struct BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_16K1 : BlockwisGemmXdlTraits<32, 32, 2, 2, 16>
{
};

} // namespace wrapper
} // namespace ck
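For reference, a small constexpr helper (not part of the library) showing how the trait values combine: a single wave covers MPerXDL * MXdlPerWave rows and NPerXDL * NXdlPerWave columns of the C tile. The helper name and the checked value are illustrative.

// Sketch only: per-wave C-tile footprint implied by a traits struct.
template <typename GemmTraits>
__host__ __device__ constexpr auto per_wave_c_tile()
{
    constexpr ck::index_t m = GemmTraits::MPerXDL * GemmTraits::MXdlPerWave;
    constexpr ck::index_t n = GemmTraits::NPerXDL * GemmTraits::NXdlPerWave;
    return ck::make_tuple(m, n);
}

// e.g. a 32x32 MFMA with 2x2 XdlPerWave covers a 64 x 64 C tile per wave.
static_assert(per_wave_c_tile<ck::wrapper::BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_8K1>()
                      .At(ck::Number<0>{}) == 64,
              "illustrative check");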
include/ck/wrapper/utils/layout_utils.hpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/ck.hpp"
#include "ck/utility/number.hpp"
#include "ck/utility/tuple.hpp"
#include "ck/utility/tuple_helper.hpp"
#include "ck/utility/sequence.hpp"
#include "ck/utility/sequence_helper.hpp"
#include "ck/utility/is_detected.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
namespace ck {
namespace wrapper {

// Disable from doxygen docs generation
/// @cond
// forward declaration
template <typename Shape, typename UnrolledDescriptorType>
struct Layout;

template <typename T>
using is_tuple = decltype(std::declval<T&>().IsTuple());

namespace {
/**
 * \brief Generate packed (column-major) strides if not passed.
 *
 * \param shape Tensor shape.
 * \return Generated column-major strides.
 */
template <typename... Ts>
__host__ __device__ constexpr static auto
GenerateColumnMajorPackedStrides(const Tuple<Ts...>& shape)
{
    const auto unrolled_shape = UnrollNestedTuple(shape);
    return generate_tuple(
        [&](auto i) {
            if constexpr(i.value == 0)
            {
                return Number<1>{};
            }
            else
            {
                return TupleReduce<Number<0>{}.value, i.value>(
                    [](auto x, auto y) { return x * y; }, unrolled_shape);
            }
        },
        Number<decltype(unrolled_shape)::Size()>{});
}

/**
 * \brief Create naive tensor descriptor from nested shape.
 *
 * \param shape Tensor shape.
 * \param strides Tensor strides.
 * \return Unrolled descriptor.
 */
template <typename LayoutShape, typename LayoutStrides>
__host__ __device__ constexpr auto MakeUnrolledDescriptor(const LayoutShape& shape,
                                                          const LayoutStrides& strides)
{
    const auto unrolled_shape = UnrollNestedTuple(shape);
    if constexpr(is_same_v<LayoutStrides, Tuple<>>)
    {
        // if not passed, then generate
        const auto unrolled_strides = GenerateColumnMajorPackedStrides(unrolled_shape);
        static_assert(unrolled_shape.Size() == unrolled_strides.Size(),
                      "Size of strides and shape are not consistent.");
        return make_naive_tensor_descriptor(unrolled_shape, unrolled_strides);
    }
    else
    {
        const auto unrolled_strides = UnrollNestedTuple(strides);
        static_assert(unrolled_shape.Size() == unrolled_strides.Size(),
                      "Size of strides and shape are not consistent.");
        return make_naive_tensor_descriptor(unrolled_shape, unrolled_strides);
    }
}
} // namespace
/// @endcond

// make_*
/**
 * \brief Make layout function.
 *
 * \tparam Shape Shape for layout.
 * \tparam Strides Strides for layout.
 * \return Constructed layout.
 */
template <typename Shape, typename Strides>
__host__ __device__ constexpr auto make_layout(const Shape& shape, const Strides& strides)
{
    using UnrolledDescriptorType = decltype(MakeUnrolledDescriptor(Shape{}, Strides{}));
    return Layout<Shape, UnrolledDescriptorType>(shape, MakeUnrolledDescriptor(shape, strides));
}

/**
 * \brief Make layout function with packed strides
 * (column-major).
 *
 * \tparam Shape Shape for layout.
 * \return Constructed layout.
 */
template <typename Shape>
__host__ __device__ constexpr auto make_layout(const Shape& shape)
{
    using UnrolledDescriptorType = decltype(MakeUnrolledDescriptor(Shape{}, Tuple<>{}));
    return Layout<Shape, UnrolledDescriptorType>(shape, MakeUnrolledDescriptor(shape, Tuple<>{}));
}

// Layout helpers
// get
/**
 * \private
 * \brief Get dim.
 *
 * \param dim Dimension.
 * \return Returned the same dimension.
 */
template <typename T>
__host__ __device__ constexpr T get(const T& dim)
{
    return dim;
}

/**
 * \brief Get element from tuple (Shape/Strides/Idxs).
 *
 * \tparam idx Index to lookup.
 * \param tuple Tuple to lookup.
 * \return Requested element.
 */
template <index_t idx, typename... Dims>
__host__ __device__ constexpr auto get(const Tuple<Dims...>& tuple)
{
    return tuple.At(Number<idx>{});
}

/**
 * \brief Get sub layout.
 *
 * \tparam idx Index to lookup.
 * \param layout Layout to create sub layout from.
 * \return Requested sub layout.
 */
template <index_t idx, typename Shape, typename FlattenDesc>
__host__ __device__ constexpr auto get(const Layout<Shape, FlattenDesc>& layout)
{
    const auto& shape    = layout.GetShape();
    const auto new_shape = get<idx>(shape);
    static_assert(is_detected<is_tuple, decltype(new_shape)>::value,
                  "Shape of sub layout must be tuple");

    constexpr auto old_shape_dims = decltype(UnrollNestedTuple(shape))::Size();
    constexpr auto new_shape_dims = decltype(UnrollNestedTuple(new_shape))::Size();
    constexpr auto shape_offset   = decltype(UnrollNestedTuple(TupleSlice<0, idx>(shape)))::Size();

    const auto unrolled_shape = UnrollNestedTuple(shape);

    const auto transforms = generate_tuple(
        [&](auto i) {
            // Compare Idx with shape
            if constexpr(i < shape_offset || i >= shape_offset + new_shape_dims)
            {
                // Remove dimension
                return make_freeze_transform(Number<0>{});
            }
            else
            {
                return make_pass_through_transform(unrolled_shape.At(i));
            }
        },
        Number<old_shape_dims>{});
    const auto lower_dims =
        generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<old_shape_dims>{});
    const auto upper_dims = generate_tuple(
        [&](auto i) {
            if constexpr(i < shape_offset || i >= shape_offset + new_shape_dims)
            {
                return Sequence<>{};
            }
            else
            {
                return Sequence<i.value - shape_offset>{};
            }
        },
        Number<old_shape_dims>{});

    const auto& flatten_desc = layout.GetUnrolledDescriptor();
    auto new_desc = transform_tensor_descriptor(flatten_desc, transforms, lower_dims, upper_dims);
    return Layout<decltype(new_shape), decltype(new_desc)>(new_shape, new_desc);
}

/**
 * \brief Hierarchical get.
 *
 * \tparam Idxs Indexes to lookup.
 * \param elem Element to lookup.
 * \return Requested element.
 */
template <index_t Idx, index_t... Idxs, typename T>
__host__ __device__ constexpr auto get(const T& elem)
{
    return get<Idxs...>(get<Idx>(elem));
}

// size
/**
 * \private
 * \brief Get size.
 *
 * \param dim Size.
 * \return Returned the same size.
 */
template <typename T>
__host__ __device__ constexpr T size(const T& dim)
{
    return dim;
}

/**
 * \brief Length get (product if tuple).
 *
 * \tparam idx Index to lookup.
 * \param layout Layout to get Shape of.
 * \return Requested length.
 */
template <index_t idx, typename Shape, typename UnrolledDescriptorType>
__host__ __device__ constexpr auto size(const Layout<Shape, UnrolledDescriptorType>& layout)
{
    return layout.template GetLength<idx>();
}

/**
 * \brief Shape size (product of dims).
 *
 * \param shape Shape to lookup.
 * \return Requested size.
 */
template <typename... ShapeDims>
__host__ __device__ constexpr auto size(const Tuple<ShapeDims...>& shape)
{
    const auto unrolled_shape = UnrollNestedTuple(shape);
    return TupleReduce<0, unrolled_shape.Size()>([](auto x, auto y) { return x * y; },
                                                 unrolled_shape);
}

/**
 * \brief Layout size (product of dims).
 *
 * \param layout Layout to calculate shape size.
 * \return Requested size.
 */
template <typename Shape, typename UnrolledDescriptorType>
__host__ __device__ constexpr auto size(const Layout<Shape, UnrolledDescriptorType>& layout)
{
    return layout.GetLengths();
}

/**
 * \brief Length get from tuple (product if tuple).
 *
 * \tparam idx Index to lookup.
 * \param tuple Tuple to lookup.
 * \return Requested length.
 */
template <index_t idx, typename... Ts>
__host__ __device__ constexpr auto size(const Tuple<Ts...>& tuple)
{
    return size(tuple.At(Number<idx>{}));
}

/**
 * \brief Hierarchical size.
 *
 * \tparam Idx First index to lookup (to avoid empty Idxs).
 * \tparam Idxs Next indexes to lookup.
 * \param elem Element to lookup.
 * \return Requested element.
 */
template <index_t Idx, index_t... Idxs, typename T>
__host__ __device__ constexpr auto size(const T& elem)
{
    return size(get<Idx, Idxs...>(elem));
}

// rank
/**
 * \brief Get layout rank (num elements in shape).
 *
 * \param layout Layout to calculate rank.
 * \return Requested rank.
 */
template <typename Shape, typename UnrolledDescriptorType>
__host__ __device__ constexpr auto
rank([[maybe_unused]] const Layout<Shape, UnrolledDescriptorType>& layout)
{
    return Shape::Size();
}

/**
 * \brief Get tuple rank (num elements in tuple).
 * Return 1 if scalar passed.
 *
 * \param tuple Tuple to calculate rank.
 * \return Requested rank.
 */
template <typename... Dims>
__host__ __device__ constexpr auto rank([[maybe_unused]] const Tuple<Dims...>& tuple)
{
    return Tuple<Dims...>::Size();
}

/**
 * \private
 * \brief Rank for scalar.
 *
 * \param dim Dimension scalar.
 * \return Returned 1.
 */
template <index_t IDim>
__host__ __device__ constexpr index_t rank([[maybe_unused]] const Number<IDim>& dim)
{
    return 1;
}

/**
 * \private
 * \brief Rank for scalar.
 *
 * \param dim Dimension scalar.
 * \return Returned 1.
 */
__host__ __device__ constexpr index_t rank([[maybe_unused]] const index_t& dim) { return 1; }

/**
 * \brief Hierarchical rank.
 *
 * \tparam Idxs Indexes to lookup.
 * \param elem Element to lookup.
 * \return Requested rank.
 */
template <index_t... Idxs, typename T>
__host__ __device__ constexpr auto rank(const T& elem)
{
    return rank(get<Idxs...>(elem));
}

// depth
/**
 * \brief Get depth of the layout shape (return 0 if scalar).
 *
 * \param layout Layout to calculate depth.
 * \return Requested depth.
 */
template <typename Shape, typename UnrolledDescriptorType>
__host__ __device__ constexpr auto depth(const Layout<Shape, UnrolledDescriptorType>& layout)
{
    const auto& shape = layout.GetShape();
    return TupleDepth(shape);
}

/**
 * \brief Get depth of the tuple (return 0 if scalar).
 *
 * \param tuple Tuple to calculate depth.
 * \return Requested depth.
 */
template <typename... Dims>
__host__ __device__ constexpr auto depth(const Tuple<Dims...>& tuple)
{
    return TupleDepth(tuple);
}

/**
 * \private
 * \brief Depth for scalar.
 *
 * \param dim Scalar.
 * \return Returned 0.
 */
template <index_t IDim>
__host__ __device__ constexpr index_t depth([[maybe_unused]] const Number<IDim>& dim)
{
    return 0;
}

/**
 * \private
 * \brief Depth for scalar.
 *
 * \param dim Scalar.
 * \return Returned 0.
 */
__host__ __device__ constexpr index_t depth([[maybe_unused]] const index_t& dim) { return 0; }

/**
 * \brief Hierarchical depth.
 *
 * \tparam Idxs Indexes to lookup.
 * \param elem Element to lookup.
 * \return Requested depth.
 */
template <index_t... Idxs, typename T>
__host__ __device__ constexpr auto depth(const T& elem)
{
    return depth(get<Idxs...>(elem));
}

/**
 * \brief Get Layout shape.
 *
 * \param layout Layout to get shape from.
 * \return Requested shape.
 */
template <typename LayoutType>
__host__ __device__ constexpr const auto& shape(const LayoutType& layout)
{
    return layout.GetShape();
}

} // namespace wrapper
} // namespace ck
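For illustration, a short sketch (not part of this commit) of the helper queries above applied to a nested layout; the shape is a placeholder and the commented values follow from the definitions of rank, size and depth.

// Sketch only: rank/size/depth queries on a ((4, 2), 8) layout.
__host__ __device__ inline void layout_helpers_sketch()
{
    using namespace ck::wrapper;
    const auto nested_layout = make_layout(ck::make_tuple(ck::make_tuple(4, 2), 8));
    const auto r  = rank(nested_layout);    // 2: two top-level dimensions
    const auto s  = size(nested_layout);    // 64: product of all dimensions
    const auto s0 = size<0>(nested_layout); // 8: product of the nested (4, 2) dimension
    const auto d  = depth(nested_layout);   // 2: one level of nesting in the shape
    (void)r;
    (void)s;
    (void)s0;
    (void)d;
}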