gaoqiong / composable_kernel · Commits · 7a3b49e5

Commit 7a3b49e5, authored Jun 25, 2022 by Chao Liu

    Merge remote-tracking branch 'origin/develop' into contraction

Parents: e07b3d8e, d3051d75

Changes: 592 files in the full commit; the 20 files on this page account for 929 additions and 107 deletions (+929, -107).
include/ck/tensor_operation/gpu/grid/gridwise_softmax.hpp                          (+383, -0)
include/ck/tensor_operation/gpu/grid/gridwise_unary_elementwise_1d.hpp             (+132, -0)
include/ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp          (+19, -44)
include/ck/tensor_operation/gpu/thread/threadwise_contraction_dl.hpp               (+6, -2)
include/ck/tensor_operation/gpu/thread/threadwise_gemm_dlops_v3.hpp                (+3, -0)
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_set.hpp             (+7, -6)
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp        (+8, -7)
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp   (+8, -7)
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r3.hpp   (+3, -0)
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v4r1.hpp   (+7, -6)
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v5r1.hpp   (+7, -3)
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1.hpp   (+9, -8)
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r2.hpp   (+8, -7)
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r3.hpp   (+8, -7)
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7.hpp     (+298, -0)
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp                               (+7, -6)
include/ck/utility/amd_address_space.hpp                                           (+5, -4)
include/ck/utility/amd_buffer_addressing.hpp                                       (+5, -0)
include/ck/utility/amd_inline_asm.hpp                                              (+3, -0)
include/ck/utility/amd_llvm_intrinsic.hpp                                          (+3, -0)
include/ck/tensor_operation/gpu/grid/gridwise_softmax.hpp (new file, 0 → 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/data_type.hpp"
#include "ck/utility/reduction_common.hpp"
#include "ck/utility/reduction_operator.hpp"
#include "ck/utility/reduction_functions_accumulate.hpp"
#include "ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp"
#include "ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

namespace ck {

template <typename GridwiseReduction,
          typename InDataType,
          typename OutDataType,
          typename AccDataType,
          typename GridDesc_M_K>
__global__ void kernel_softmax(const GridDesc_M_K in_grid_desc_m_k,
                               const GridDesc_M_K out_grid_desc_m_k,
                               index_t block_group_size,
                               index_t num_k_block_tile_iteration,
                               AccDataType alpha,
                               const InDataType* const __restrict__ p_in_value_global,
                               AccDataType beta,
                               OutDataType* const __restrict__ p_out_value_global)
{
    GridwiseReduction::Run(in_grid_desc_m_k,
                           out_grid_desc_m_k,
                           block_group_size,
                           num_k_block_tile_iteration,
                           alpha,
                           p_in_value_global,
                           beta,
                           p_out_value_global);
};

template <typename InDataType,
          typename OutDataType,
          typename AccDataType,
          typename GridDesc_M_K,
          index_t BlockSize,
          index_t MThreadClusterSize,
          index_t KThreadClusterSize,
          index_t MThreadSliceSize,
          index_t KThreadSliceSize,
          index_t InSrcVectorDim,
          index_t InSrcVectorSize,
          index_t OutDstVectorSize>
struct GridwiseSoftmax_mk_to_mk
{
    static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
                   (InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) &&
                      (KThreadSliceSize % OutDstVectorSize == 0),
                  "Invalid thread slice sizes and/or vector sizes configuration, please check!");

    static constexpr bool reorder_thread_cluster = (InSrcVectorDim == 0);

    using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;

    using ThreadBufferDimAccessOrder =
        typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;

    using ThreadClusterArrangeOrder =
        typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;

    static constexpr auto thread_cluster_desc =
        make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});

    using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
        make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})));
    using ThreadReduceDstDesc_M =
        decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));

    using BlockwiseMaxReduce = PartitionedBlockwiseReduction<AccDataType,
                                                             BlockSize,
                                                             ThreadClusterLengths_M_K,
                                                             ThreadClusterArrangeOrder,
                                                             reduce::Max,
                                                             false>; // PropagateNan

    using ThreadwiseMaxReduce = ThreadwiseReduction<AccDataType,
                                                    ThreadReduceSrcDesc_M_K,
                                                    ThreadReduceDstDesc_M,
                                                    reduce::Max,
                                                    false>; // PropagateNan

    using PassThroughOp = tensor_operation::element_wise::PassThrough;

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};

    static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
    static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;

    __device__ static void Run(const GridDesc_M_K& in_grid_desc_m_k,
                               const GridDesc_M_K& out_grid_desc_m_k,
                               index_t block_group_size,
                               index_t num_k_block_tile_iteration,
                               AccDataType alpha,
                               const InDataType* const __restrict__ p_in_value_global,
                               AccDataType beta,
                               OutDataType* const __restrict__ p_out_value_global)
    {
        // LDS
        __shared__ AccDataType p_reduce_work_buffer[BlockSize];

        auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_out_value_global, out_grid_desc_m_k.GetElementSpaceSize());

        auto reduce_work_buf =
            make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_buffer, BlockSize);

        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
            in_thread_buf;

        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
            out_thread_buf;

        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> max_value_buf;
        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            max_value_buf(I) = reduce::Max::template GetIdentityValue<AccDataType>();
        });

        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            accu_value_buf(I) = reduce::Add::template GetIdentityValue<AccDataType>();
        });

        const index_t thread_local_id = get_thread_local_1d_id();
        const index_t block_global_id = get_block_1d_id();
        const index_t blkgroup_id     = block_global_id / block_group_size;
        const index_t block_local_id  = block_global_id % block_group_size;

        const auto thread_cluster_idx =
            thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));

        const auto thread_m_cluster_id = thread_cluster_idx[I0];
        const auto thread_k_cluster_id = thread_cluster_idx[I1];

        const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration;

        using ThreadBufferLengths = Sequence<MThreadSliceSize, KThreadSliceSize>;
        constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
            make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));

        auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<InDataType,
                                                                    AccDataType,
                                                                    GridDesc_M_K,
                                                                    decltype(thread_buffer_desc),
                                                                    ThreadBufferLengths,
                                                                    ThreadBufferDimAccessOrder,
                                                                    InSrcVectorDim,
                                                                    InSrcVectorSize,
                                                                    1,
                                                                    false>(
            in_grid_desc_m_k,
            make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
                             block_local_id * reduceSizePerBlock +
                                 thread_k_cluster_id * KThreadSliceSize));

        auto threadwise_dst_load = ThreadwiseTensorSliceTransfer_v2<OutDataType,
                                                                    AccDataType,
                                                                    GridDesc_M_K,
                                                                    decltype(thread_buffer_desc),
                                                                    ThreadBufferLengths,
                                                                    ThreadBufferDimAccessOrder,
                                                                    InSrcVectorDim,
                                                                    InSrcVectorSize,
                                                                    1,
                                                                    false>(
            out_grid_desc_m_k,
            make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
                             block_local_id * reduceSizePerBlock +
                                 thread_k_cluster_id * KThreadSliceSize));

        auto threadwise_dst_store =
            ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
                                               OutDataType,
                                               decltype(thread_buffer_desc),
                                               GridDesc_M_K,
                                               PassThroughOp,
                                               ThreadBufferLengths,
                                               ThreadBufferDimAccessOrder,
                                               InSrcVectorDim,
                                               OutDstVectorSize,
                                               InMemoryDataOperationEnum::Set,
                                               1,
                                               true>(
                out_grid_desc_m_k,
                make_multi_index(blkgroup_id * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize,
                                 block_local_id * reduceSizePerBlock +
                                     thread_k_cluster_id * KThreadSliceSize),
                PassThroughOp{});

        constexpr auto in_thread_copy_fwd_step = make_multi_index(0, K_BlockTileSize);
        constexpr auto in_thread_copy_bwd_step = make_multi_index(0, -K_BlockTileSize);

        ///
        /// max(x)
        ///
        const auto in_global_val_buf_oob_non_zero = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_in_value_global,
            in_grid_desc_m_k.GetElementSpaceSize(),
            reduce::Max::template GetIdentityValue<InDataType>());

        index_t reducedTiles = 0;
        do
        {
            threadwise_src_load.Run(in_grid_desc_m_k,
                                    in_global_val_buf_oob_non_zero,
                                    thread_buffer_desc,
                                    make_tuple(I0, I0),
                                    in_thread_buf);

            ThreadwiseMaxReduce::Reduce(in_thread_buf, max_value_buf);

            threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_fwd_step);

            reducedTiles++;
        } while(reducedTiles < num_k_block_tile_iteration);

        static_for<0, MThreadSliceSize, 1>{}(
            [&](auto I) { BlockwiseMaxReduce::Reduce(reduce_work_buf, max_value_buf(I)); });

        threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_bwd_step);

        ///
        /// sum(exp(x - max(x)))
        ///
        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            accu_value_buf(I) = reduce::Add::template GetIdentityValue<AccDataType>();
        });

        // Normally, 0 is adequate as the invalid element value, since 0 makes no contribution to
        // the accumulated result. In stable softmax, however, every value, zero or not, has
        // value_max subtracted from it. Because those numbers become non-zero, invalid values
        // could effectively slip through and contribute to the accumulated result.
        //
        // The trick here is to exploit the fact that many math functions (add, sub, exp, ...)
        // propagate NaN when an operand is NaN. By initializing the invalid element value with
        // NaN, an invalid value remains NaN through any math manipulation and can still be
        // identified as invalid. We can then discard, during accumulation, the values that
        // originally failed the bound check, even after multiple math manipulations.
        const auto in_global_val_buf_oob_nan = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_in_value_global,
            in_grid_desc_m_k.GetElementSpaceSize(),
            NumericLimits<InDataType>::QuietNaN());

        using BlockwiseSumReduce = PartitionedBlockwiseReduction<
            AccDataType,
            BlockSize,
            ThreadClusterLengths_M_K,
            ThreadClusterArrangeOrder,
            reduce::Add,
            false, // ignored
            detail::AccumulateWithNanIgnore<reduce::Add, AccDataType>>;

        using ThreadwiseSumReduce =
            ThreadwiseReduction<AccDataType,
                                ThreadReduceSrcDesc_M_K,
                                ThreadReduceDstDesc_M,
                                reduce::Add,
                                false, // ignored
                                detail::AccumulateWithNanIgnore<reduce::Add, AccDataType>>;

        reducedTiles = 0;
        do
        {
            threadwise_src_load.Run(in_grid_desc_m_k,
                                    in_global_val_buf_oob_nan,
                                    thread_buffer_desc,
                                    make_tuple(I0, I0),
                                    in_thread_buf);

            // do element-wise pre-reduction operation
            static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
                static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
                    constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
                    in_thread_buf(Number<offset>{}) =
                        math::exp(in_thread_buf(Number<offset>{}) - max_value_buf(iM));
                });
            });

            ThreadwiseSumReduce::Reduce(in_thread_buf, accu_value_buf);

            threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_bwd_step);

            reducedTiles++;
        } while(reducedTiles < num_k_block_tile_iteration);

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            BlockwiseSumReduce::Reduce(reduce_work_buf, accu_value_buf(I));
            // block_sync_lds();
        });

        threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_fwd_step);

        ///
        /// softmax
        ///
        reducedTiles = 0;
        if(float_equal_zero{}(beta))
        {
            do
            {
                threadwise_src_load.Run(in_grid_desc_m_k,
                                        in_global_val_buf_oob_nan,
                                        thread_buffer_desc,
                                        make_tuple(I0, I0),
                                        in_thread_buf);

                static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
                    // out = alpha * exp(x - max(x)) / sum(exp(x - max(x)))
                    static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
                        constexpr auto offset =
                            thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
                        out_thread_buf(Number<offset>{}) =
                            alpha * math::exp(in_thread_buf(Number<offset>{}) - max_value_buf(iM)) /
                            accu_value_buf(iM);
                    });
                });

                threadwise_dst_store.Run(thread_buffer_desc,
                                         make_tuple(I0, I0),
                                         out_thread_buf,
                                         out_grid_desc_m_k,
                                         out_global_val_buf);

                threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_fwd_step);
                threadwise_dst_store.MoveDstSliceWindow(out_grid_desc_m_k, in_thread_copy_fwd_step);

                reducedTiles++;
            } while(reducedTiles < num_k_block_tile_iteration);
        }
        else
        {
            do
            {
                threadwise_src_load.Run(in_grid_desc_m_k,
                                        in_global_val_buf_oob_nan,
                                        thread_buffer_desc,
                                        make_tuple(I0, I0),
                                        in_thread_buf);

                threadwise_dst_load.Run(out_grid_desc_m_k,
                                        out_global_val_buf,
                                        thread_buffer_desc,
                                        make_tuple(I0, I0),
                                        out_thread_buf);

                static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
                    // out = alpha * exp(x - max(x)) / sum(exp(x - max(x))) + beta * prior_out
                    static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
                        constexpr auto offset =
                            thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
                        out_thread_buf(Number<offset>{}) =
                            alpha * math::exp(in_thread_buf(Number<offset>{}) - max_value_buf(iM)) /
                                accu_value_buf(iM) +
                            beta * out_thread_buf(Number<offset>{});
                    });
                });

                threadwise_dst_store.Run(thread_buffer_desc,
                                         make_tuple(I0, I0),
                                         out_thread_buf,
                                         out_grid_desc_m_k,
                                         out_global_val_buf);

                threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_fwd_step);
                threadwise_dst_store.MoveDstSliceWindow(out_grid_desc_m_k, in_thread_copy_fwd_step);
                threadwise_dst_load.MoveSrcSliceWindow(out_grid_desc_m_k, in_thread_copy_fwd_step);

                reducedTiles++;
            } while(reducedTiles < num_k_block_tile_iteration);
        }
    }
};

} // namespace ck
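For orientation, the kernel above is easier to follow against a scalar reference. The sketch below is not part of this commit; it is plain host C++ with illustrative names, showing the same three-pass stable-softmax computation that GridwiseSoftmax_mk_to_mk performs per M-row, including the alpha/beta scaling handled by the two branches at the end of Run().

// Reference sketch only (assumes one row, AccDataType = float).
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

void softmax_row_reference(const std::vector<float>& x, std::vector<float>& y, float alpha, float beta)
{
    // Pass 1: row maximum, mirrors the ThreadwiseMaxReduce / BlockwiseMaxReduce sweep.
    float x_max = -std::numeric_limits<float>::infinity();
    for(float v : x)
        x_max = std::max(x_max, v);

    // Pass 2: sum of exp(x - max(x)), mirrors the ThreadwiseSumReduce / BlockwiseSumReduce sweep.
    float sum = 0.f;
    for(float v : x)
        sum += std::exp(v - x_max);

    // Pass 3: out = alpha * exp(x - max(x)) / sum(exp(x - max(x))) + beta * prior_out.
    // The beta == 0 branch in the kernel simply skips reading the prior output.
    for(std::size_t i = 0; i < x.size(); ++i)
        y[i] = alpha * std::exp(x[i] - x_max) / sum + beta * y[i];
}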
include/ck/tensor_operation/gpu/grid/gridwise_unary_elementwise_1d.hpp (new file, 0 → 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/data_type.hpp"
#include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

namespace ck {

template <typename GridwiseUEltwise,
          typename ADataType,
          typename BDataType,
          typename GridDesc_M0,
          typename ElementwiseFunctor>
__global__ void kernel_unary_elementwise_1d(const ADataType* __restrict__ p_a_global,
                                            BDataType* __restrict__ p_b_global,
                                            const GridDesc_M0 a_grid_desc_m0,
                                            const GridDesc_M0 b_grid_desc_m0,
                                            const ElementwiseFunctor functor)
{
    GridwiseUEltwise::Run(p_a_global, p_b_global, a_grid_desc_m0, b_grid_desc_m0, functor);
}

template <typename ADataType,
          typename BDataType,
          typename GridDesc_M0,
          typename ElementwiseFunctor,
          index_t ScalarPerVector>
struct GridwiseUnaryElementwise_1D
{
    static constexpr auto I0 = Number<0>{};

    static constexpr auto thread_desc_m0 =
        make_naive_tensor_descriptor_packed(make_tuple(Number<ScalarPerVector>{}));

    using PassThrough = tensor_operation::element_wise::PassThrough;

    static __device__ auto CalculateElementwiseIndex()
    {
        const index_t global_thread_id = get_thread_global_1d_id();
        return make_multi_index(global_thread_id * ScalarPerVector);
    }

    __host__ __device__ static constexpr bool CheckValidity(const GridDesc_M0 a_grid_desc_m0,
                                                            const GridDesc_M0 b_grid_desc_m0)
    {
        return a_grid_desc_m0.GetLength(I0) == b_grid_desc_m0.GetLength(I0);
    }

    __host__ __device__ static constexpr index_t CalculateGridSize(const index_t tensor_size)
    {
        const index_t grid_size = math::integer_divide_ceil(tensor_size, 256 * ScalarPerVector);
        return grid_size;
    }

    __device__ static void Run(const ADataType* __restrict__ p_a_global,
                               BDataType* __restrict__ p_b_global,
                               const GridDesc_M0 a_grid_desc_m0,
                               const GridDesc_M0 b_grid_desc_m0,
                               const ElementwiseFunctor functor)
    {
        const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_a_global, a_grid_desc_m0.GetElementSpaceSize());
        auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_b_global, b_grid_desc_m0.GetElementSpaceSize());

        StaticBuffer<AddressSpaceEnum::Vgpr, ADataType, ScalarPerVector, true> a_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, BDataType, ScalarPerVector, true> b_thread_buf;

        const auto thread_store_global_offset = CalculateElementwiseIndex();

        auto a_global_load =
            ThreadwiseTensorSliceTransfer_v2<ADataType,
                                             ADataType,
                                             GridDesc_M0,
                                             decltype(thread_desc_m0),
                                             Sequence<ScalarPerVector>, // SliceLengths
                                             Sequence<0>,               // DimAccessOrder
                                             0,                         // SrcVectorDim
                                             ScalarPerVector,
                                             1, // SrcScalarStrideInVector
                                             false>{a_grid_desc_m0, thread_store_global_offset};

        auto b_global_write =
            ThreadwiseTensorSliceTransfer_v1r3<BDataType,
                                               BDataType,
                                               decltype(thread_desc_m0),
                                               GridDesc_M0,
                                               PassThrough,
                                               Sequence<ScalarPerVector>, // SliceLengths
                                               Sequence<0>,               // DimAccessOrder
                                               0,                         // DstVectorDim
                                               ScalarPerVector,
                                               InMemoryDataOperationEnum::Set,
                                               1, // DstScalarStrideInVector
                                               false>{
                b_grid_desc_m0, thread_store_global_offset, PassThrough{}};

        const index_t blockSize    = get_block_size();
        const index_t blockPerGrid = get_grid_size();
        const auto m0              = b_grid_desc_m0.GetLength(I0);
        const index_t loop_step    = blockPerGrid * blockSize * ScalarPerVector;
        const auto loop_step_index = make_multi_index(loop_step);

        index_t num_iter = m0 / (loop_step);
        do
        {
            // read and process ScalarPerVector elements
            a_global_load.Run(
                a_grid_desc_m0, a_global_buf, thread_desc_m0, make_tuple(I0), a_thread_buf);

            static_for<0, ScalarPerVector, 1>{}([&](auto m) {
                constexpr auto offset = thread_desc_m0.CalculateOffset(make_tuple(m));
                functor(b_thread_buf(Number<offset>{}), a_thread_buf(Number<offset>{}));
            });

            b_global_write.Run(thread_desc_m0,
                               make_tuple(I0), // SrcSliceOriginIdx
                               b_thread_buf,
                               b_grid_desc_m0,
                               b_global_buf);

            a_global_load.MoveSrcSliceWindow(a_grid_desc_m0, loop_step_index);
            b_global_write.MoveDstSliceWindow(b_grid_desc_m0, loop_step_index);
        } while(--num_iter);
    }
};

} // namespace ck
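For reference, the ElementwiseFunctor template parameter only needs to be callable as functor(b, a) for one output element and one input element, as in the Run() loop above. A minimal sketch of such a functor follows; it is illustrative only and not part of this commit, and the name UnarySquare is hypothetical.

// Illustrative functor sketch: writes the transformed value of `a` into `b`.
struct UnarySquare
{
    template <typename B, typename A>
    __host__ __device__ void operator()(B& b, const A& a) const
    {
        b = static_cast<B>(a) * static_cast<B>(a); // b = a * a
    }
};

// Hypothetical instantiation with the grid-wise struct above:
//   using Gridwise = GridwiseUnaryElementwise_1D<float, float, GridDesc_M0, UnarySquare, 8>;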
include/ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp

-/*******************************************************************************
- *
- * MIT License
- *
- * Copyright (c) 2020 Advanced Micro Devices, Inc.
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to deal
- * in the Software without restriction, including without limitation the rights
- * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
- * copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- *
- *******************************************************************************/
-#ifndef CK_REDUCTION_FUNCTIONS_THREADWISE_HPP
-#define CK_REDUCTION_FUNCTIONS_THREADWISE_HPP
-
-#include "reduction_functions_accumulate.hpp"
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once
+
+#include "ck/utility/reduction_functions_accumulate.hpp"

 namespace ck {
 ...

@@ -39,7 +16,9 @@ template <typename AccDataType,
           typename SrcThreadDesc_M_K,
           typename DstThreadDesc_M,
           typename OpReduce,
-          bool PropagateNan>
+          bool PropagateNan,
+          typename Accumulation =
+              detail::AccumulateWithNanCheck<PropagateNan, OpReduce, AccDataType>>
 struct ThreadwiseReduction
 {
     static constexpr auto src_thread_desc_m_k = SrcThreadDesc_M_K{};
 ...

@@ -51,8 +30,6 @@ struct ThreadwiseReduction
     static_assert(src_length_m == dst_length_m, "lengths of source and dst buffer must match!");

-    using Accumulation = detail::AccumulateWithNanCheck<PropagateNan, OpReduce, AccDataType>;
-
     template <typename SrcBufferType, typename DstBufferType>
     __device__ static void Reduce(const SrcBufferType& src_buf, DstBufferType& dst_buf)
     {
 ...

@@ -73,12 +50,15 @@ struct ThreadwiseReduction
 // 2) DstDesc is known at compile-time
 // 3) SrcBuffer is static buffer
 // 4) DstBuffer is static buffer
-template <typename AccDataType,
-          typename IndexDataType,
-          typename SrcThreadDesc_M_K,
-          typename DstThreadDesc_M,
-          typename OpReduce,
-          bool PropagateNan>
+template <typename AccDataType,
+          typename IndexDataType,
+          typename SrcThreadDesc_M_K,
+          typename DstThreadDesc_M,
+          typename OpReduce,
+          bool PropagateNan,
+          typename Accumulation = detail::
+              AccumulateWithIndexAndNanCheck<PropagateNan, OpReduce, AccDataType, IndexDataType>>
 struct ThreadwiseReductionWithIndex
 {
     static constexpr auto src_thread_desc_m_k = SrcThreadDesc_M_K{};
 ...

@@ -90,9 +70,6 @@ struct ThreadwiseReductionWithIndex
     static_assert(src_length_m == dst_length_m, "lengths of source and dst buffer must match!");

-    using Accumulation =
-        detail::AccumulateWithIndexAndNanCheck<PropagateNan, OpReduce, AccDataType, IndexDataType>;
-
     template <typename SrcValueBufferType,
               typename SrcIndexBufferType,
               typename DstValueBufferType,
 ...

@@ -117,6 +94,4 @@ struct ThreadwiseReductionWithIndex
     };
 };
-}; // end of namespace ck
-
-#endif
+} // namespace ck
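The net effect of this diff is that Accumulation becomes a template parameter with a NaN-checking default rather than a hard-coded member alias, so callers such as GridwiseSoftmax_mk_to_mk above can substitute detail::AccumulateWithNanIgnore. A stripped-down sketch of that policy pattern follows; it is illustrative only (plain host C++, hypothetical names), not the library's actual implementation.

#include <cmath>

// Default policy: plain accumulation.
struct AccumulateAdd
{
    static void Calculate(float& acc, float v) { acc += v; }
};

// Alternative policy: skip NaNs, e.g. values deliberately initialized to NaN because they
// failed an out-of-bounds check (the trick used by the softmax kernel earlier on this page).
struct AccumulateAddIgnoreNan
{
    static void Calculate(float& acc, float v)
    {
        if(!std::isnan(v))
            acc += v;
    }
};

// The reduction is written once; the accumulation behavior is swapped via the template parameter.
template <typename Accumulation = AccumulateAdd>
struct ThreadwiseSumSketch
{
    template <int N>
    static float Reduce(const float (&buf)[N])
    {
        float acc = 0.f;
        for(int i = 0; i < N; ++i)
            Accumulation::Calculate(acc, buf[i]);
        return acc;
    }
};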
include/ck/tensor_operation/gpu/thread/threadwise_contraction_dl.hpp

+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
 #pragma once

-#include "common_header.hpp"
-#include "math.hpp"
+#include "ck/utility/common_header.hpp"
+#include "ck/utility/math.hpp"

 namespace ck {
 ...
include/ck/tensor_operation/gpu/thread/threadwise_gemm_dlops_v3.hpp

+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
 #ifndef CK_THREADWISE_GEMM_DLOPS_V3_HPP
 #define CK_THREADWISE_GEMM_DLOPS_V3_HPP
 ...
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_set.hpp

-#ifndef CK_THREADWISE_TENSOR_SET_HPP
-#define CK_THREADWISE_TENSOR_SET_HPP
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once

-#include "common_header.hpp"
-#include "tensor_descriptor.hpp"
-#include "tensor_descriptor_helper.hpp"
+#include "ck/utility/common_header.hpp"
+#include "ck/tensor_description/tensor_descriptor.hpp"
+#include "ck/tensor_description/tensor_descriptor_helper.hpp"

 namespace ck {
 ...

@@ -56,4 +58,3 @@ struct ThreadwiseTensorSliceSet_v1
 };
 } // namespace ck
-#endif
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp

-#ifndef CK_THREADWISE_TENSOR_SLICE_TRANSFER_HPP
-#define CK_THREADWISE_TENSOR_SLICE_TRANSFER_HPP
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once

-#include "common_header.hpp"
-#include "tensor_descriptor.hpp"
-#include "tensor_descriptor_helper.hpp"
-#include "tensor_space_filling_curve.hpp"
+#include "ck/utility/common_header.hpp"
+#include "ck/tensor_description/tensor_space_filling_curve.hpp"
+#include "ck/tensor_description/tensor_descriptor.hpp"
+#include "ck/tensor_description/tensor_descriptor_helper.hpp"

 namespace ck {
 ...

@@ -1168,4 +1170,3 @@ struct ThreadwiseTensorSliceTransfer_v4
 };
 } // namespace ck
-#endif
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp

-#ifndef CK_THREADWISE_TENSOR_SLICE_TRANSFER_V3R1_HPP
-#define CK_THREADWISE_TENSOR_SLICE_TRANSFER_V3R1_HPP
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once

-#include "common_header.hpp"
-#include "tensor_descriptor.hpp"
-#include "tensor_descriptor_helper.hpp"
-#include "static_tensor.hpp"
+#include "ck/utility/common_header.hpp"
+#include "ck/tensor_description/tensor_descriptor.hpp"
+#include "ck/tensor_description/tensor_descriptor_helper.hpp"
+#include "ck/tensor/static_tensor.hpp"

 namespace ck {
 ...

@@ -789,4 +791,3 @@ struct ThreadwiseTensorSliceTransfer_v3r1
 };
 } // namespace ck
-#endif
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r3.hpp

+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
 #ifndef CK_THREADWISE_TENSOR_SLICE_TRANSFER_V3R3_HPP
 #define CK_THREADWISE_TENSOR_SLICE_TRANSFER_V3R3_HPP
 ...
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v4r1.hpp

-#ifndef CK_THREADWISE_TENSOR_SLICE_TRANSFER_V4R1_HPP
-#define CK_THREADWISE_TENSOR_SLICE_TRANSFER_V4R1_HPP
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once

-#include "common_header.hpp"
-#include "tensor_descriptor.hpp"
-#include "tensor_descriptor_helper.hpp"
+#include "ck/utility/common_header.hpp"
+#include "ck/tensor_description/tensor_descriptor.hpp"
+#include "ck/tensor_description/tensor_descriptor_helper.hpp"

 namespace ck {

 // Assume:
 ...

@@ -171,4 +173,3 @@ struct ThreadwiseTensorSliceTransfer_v4r1
 };
 } // namespace ck
-#endif
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v5r1.hpp

+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
 #pragma once

-#include "common_header.hpp"
-#include "tensor_descriptor.hpp"
-#include "tensor_descriptor_helper.hpp"
+#include "ck/utility/common_header.hpp"
+#include "ck/tensor_description/tensor_descriptor.hpp"
+#include "ck/tensor_description/tensor_descriptor_helper.hpp"
+#include "ck/tensor_description/tensor_space_filling_curve.hpp"

 namespace ck {
 ...
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1.hpp

-#ifndef CK_THREADWISE_TENSOR_SLICE_TRANSFER_V6R1_HPP
-#define CK_THREADWISE_TENSOR_SLICE_TRANSFER_V6R1_HPP
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once

-#include "common_header.hpp"
-#include "tensor_descriptor.hpp"
-#include "tensor_descriptor_helper.hpp"
-#include "tensor_space_filling_curve.hpp"
+#include "ck/utility/common_header.hpp"
+#include "ck/tensor_description/tensor_descriptor.hpp"
+#include "ck/tensor_description/tensor_descriptor_helper.hpp"
+#include "ck/tensor_description/tensor_space_filling_curve.hpp"

 namespace ck {
 ...

@@ -206,7 +208,6 @@ struct ThreadwiseTensorSliceTransfer_v6r1
     SrcCoord src_coord_;
     DstCoord dst_coord_;
     const ElementwiseOperation element_op_;
 };
 } // namespace ck
-#endif
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r2.hpp

-#ifndef CK_THREADWISE_TENSOR_SLICE_TRANSFER_V6R2_HPP
-#define CK_THREADWISE_TENSOR_SLICE_TRANSFER_V6R2_HPP
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once

-#include "common_header.hpp"
-#include "tensor_descriptor.hpp"
-#include "tensor_descriptor_helper.hpp"
-#include "tensor_space_filling_curve.hpp"
+#include "ck/utility/common_header.hpp"
+#include "ck/tensor_description/tensor_descriptor.hpp"
+#include "ck/tensor_description/tensor_descriptor_helper.hpp"
+#include "ck/tensor_description/tensor_space_filling_curve.hpp"

 namespace ck {
 ...

@@ -256,4 +258,3 @@ struct ThreadwiseTensorSliceTransfer_v6r2
 };
 } // namespace ck
-#endif
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r3.hpp

-#ifndef CK_THREADWISE_TENSOR_SLICE_TRANSFER_V6R3_HPP
-#define CK_THREADWISE_TENSOR_SLICE_TRANSFER_V6R3_HPP
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once

-#include "common_header.hpp"
-#include "tensor_descriptor.hpp"
-#include "tensor_descriptor_helper.hpp"
-#include "tensor_space_filling_curve.hpp"
+#include "ck/utility/common_header.hpp"
+#include "ck/tensor_description/tensor_descriptor.hpp"
+#include "ck/tensor_description/tensor_descriptor_helper.hpp"
+#include "ck/tensor_description/tensor_space_filling_curve.hpp"

 namespace ck {
 ...

@@ -306,4 +308,3 @@ struct ThreadwiseTensorSliceTransfer_v6r3
 };
 } // namespace ck
-#endif
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7.hpp (new file, 0 → 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/tensor_space_filling_curve.hpp"

namespace ck {

// Thread-level multi-source, multi-destination tensor slice data movement
// Assume:
//   1. All sources and destinations are DynamicBuffer
//   2. Same VectorDim and ScalarPerVector for all sources and destinations
//   3. DstInMemOps are per destination tensor
//   4. ThreadTransferSrcResetCoordinateAfterRunFlags are per source tensor
//   5. ThreadTransferDstResetCoordinateAfterRunFlags are per destination tensor
//   6. Does not need to know src_descs and dst_descs at compile-time
//   7. Does not need to know src_slice_origins and dst_slice_origins at compile-time
//
// Does the following things to avoid the scratch memory issue:
//   1. Use StaticallyIndexedArray or vector_type instead of C array for thread buffer
//   2. Pass tensor descriptors by reference (or tuple of references)
//   3. Does not keep reference to tensor descriptor
//   4. Does not construct new tensor coordinate when calling Run()
template <typename SrcDatas,
          typename DstDatas,
          typename SrcDescs,
          typename DstDescs,
          typename ElementwiseOperation,
          typename DstInMemOps, // Sequence<InMemoryDataOperationEnum ...>
          typename SliceLengths,
          typename DimAccessOrder,
          index_t VectorDim,
          index_t ScalarPerVector,
          typename SrcResetCoordinateAfterRunFlags, // Sequence<bool ...>
          typename DstResetCoordinateAfterRunFlags> // Sequence<bool ...>
struct ThreadwiseTensorSliceTransfer_v7
{
    static constexpr auto I0 = Number<0>{};

    static constexpr index_t nDim = SliceLengths::Size();
    static constexpr index_t nSrc = SrcDescs::Size();
    static constexpr index_t nDst = DstDescs::Size();

    using Index = MultiIndex<nDim>;

    // return a tuple of coordinates for a tuple of tensors
    template <typename Descs,
              typename Indices,
              enable_if_t<Descs::Size() == Indices::Size(), bool> = false>
    static constexpr auto MakeCoordinates(const Descs& descs, const Indices& indices)
    {
        return generate_tuple([&](auto i) { return make_tensor_coordinate(descs[i], indices[i]); },
                              Number<Descs::Size()>{});
    }

    using SrcCoords = decltype(MakeCoordinates(SrcDescs{}, StaticallyIndexedArray<Index, nSrc>{}));
    using DstCoords = decltype(MakeCoordinates(DstDescs{}, StaticallyIndexedArray<Index, nDst>{}));

    // scalar per access on each dim
    // FIXME: don't use lambda_scalar_per_access
    static constexpr auto scalar_per_access = generate_sequence(
        detail::lambda_scalar_per_access<VectorDim, ScalarPerVector>{}, Number<nDim>{});

    using SpaceFillingCurve =
        SpaceFillingCurve<SliceLengths, DimAccessOrder, remove_cv_t<decltype(scalar_per_access)>>;

    __device__ constexpr ThreadwiseTensorSliceTransfer_v7(
        const SrcDescs& src_descs,
        const StaticallyIndexedArray<Index, nSrc>& src_slice_origins,
        const DstDescs& dst_descs,
        const StaticallyIndexedArray<Index, nDst>& dst_slice_origins,
        const ElementwiseOperation& element_op)
        : src_coords_(MakeCoordinates(src_descs, src_slice_origins)),
          dst_coords_(MakeCoordinates(dst_descs, dst_slice_origins)),
          element_op_(element_op)
    {
        static_assert(SliceLengths::At(Number<VectorDim>{}) % ScalarPerVector == 0,
                      "wrong! cannot evenly divide");
    }

    template <typename Indices, enable_if_t<SrcDescs::Size() == Indices::Size(), bool> = false>
    __device__ void SetSrcSliceOrigins(const SrcDescs& src_descs,
                                       const Indices& src_slice_origin_idxs)
    {
        static_for<0, nSrc, 1>{}([&](auto i) {
            src_coords_(i) = make_tensor_coordinate(src_descs[i], src_slice_origin_idxs[i]);
        });
    }

    template <typename Indices, enable_if_t<DstDescs::Size() == Indices::Size(), bool> = false>
    __device__ void SetDstSliceOrigins(const DstDescs& dst_descs,
                                       const Indices& dst_slice_origin_idxs)
    {
        static_for<0, nDst, 1>{}([&](auto i) {
            dst_coords_(i) = make_tensor_coordinate(dst_descs[i], dst_slice_origin_idxs[i]);
        });
    }

    // SrcDescs: Tuple<const SrcDesc0&, const SrcDesc1&, ...>
    // SrcBuffers: Tuple<const SrcBuffer0&, const SrcBuffer1&, ...>
    // DstDescs: Tuple<const DstDesc0&, const DstDesc1&, ...>
    // DstBuffers: Tuple<const DstBuffer0&, const DstBuffer1&, ...>
    template <typename SrcBuffers,
              typename DstBuffers,
              enable_if_t<SrcDescs::Size() == SrcBuffers::Size() &&
                              DstDescs::Size() == DstBuffers::Size(),
                          bool> = false>
    __device__ void Run(const SrcDescs& src_descs,
                        const SrcBuffers& src_bufs,
                        const DstDescs& dst_descs,
                        DstBuffers dst_bufs)
    {
        auto generate_vectors = [&](auto data_types) {
            constexpr index_t num = data_types.Size();

            return generate_tuple(
                [&](auto i) {
                    using DataType = remove_cvref_t<decltype(data_types[i])>;

                    return vector_type_maker_t<DataType, ScalarPerVector>{};
                },
                Number<num>{});
        };

        constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();

        // loop over space-filling curve
        static_for<0, num_access, 1>{}([&](auto iAccess) {
            auto src_vectors = generate_vectors(SrcDatas{});
            auto dst_vectors = generate_vectors(DstDatas{});

            // copy data from src_bufs into src_vectors
            static_for<0, nSrc, 1>{}([&](auto i) {
                using src_vector_t = typename remove_cvref_t<decltype(src_vectors[i])>::type;

                const bool is_src_valid =
                    coordinate_has_valid_offset_assuming_visible_index_is_valid(src_descs[i],
                                                                                src_coords_[i]);

                src_vectors(i).template AsType<src_vector_t>()(I0) =
                    src_bufs[i].template Get<src_vector_t>(src_coords_[i].GetOffset(),
                                                           is_src_valid);
            });

            // apply pointwise function
            static_for<0, ScalarPerVector, 1>{}([&](auto i) {
                // get reference to src data
                const auto src_data_refs = generate_tie(
                    // return type should be lvalue
                    [&](auto iSrc) -> const auto& {
                        using SrcData = remove_cvref_t<tuple_element_t<iSrc.value, SrcDatas>>;

                        return src_vectors[iSrc].template AsType<SrcData>()[i];
                    },
                    Number<nSrc>{});

                // get reference to dst data
                auto dst_data_refs = generate_tie(
                    // return type should be lvalue
                    [&](auto iDst) -> auto& {
                        using DstData = remove_cvref_t<tuple_element_t<iDst.value, DstDatas>>;

                        return dst_vectors(iDst).template AsType<DstData>()(i);
                    },
                    Number<nDst>{});

                // apply pointwise function
                // pointwise function signature:
                //   element_op_(dst_data_refs[I0],
                //               dst_data_refs[I1],
                //               ...,
                //               src_data_refs[I0],
                //               src_data_refs[I1],
                //               ...)
                unpack2(element_op_, dst_data_refs, src_data_refs);
            });

            // copy data from dst_vectors into dst_bufs
            static_for<0, nDst, 1>{}([&](auto i) {
                using dst_vector_t = typename remove_cvref_t<decltype(dst_vectors[i])>::type;

                const bool is_dst_valid =
                    coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_descs[i],
                                                                                dst_coords_[i]);

                constexpr InMemoryDataOperationEnum DstInMemOp =
                    static_cast<InMemoryDataOperationEnum>(DstInMemOps::At(i.value));

                dst_bufs(i).template Update<DstInMemOp, dst_vector_t>(
                    dst_coords_[i].GetOffset(),
                    is_dst_valid,
                    dst_vectors[i].template AsType<dst_vector_t>()[I0]);
            });

            // move coordinate
            if constexpr(iAccess.value != num_access - 1)
            {
                constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(iAccess);

                static_for<0, nSrc, 1>{}([&](auto i) {
                    move_tensor_coordinate(src_descs[i],
                                           src_coords_(i),
                                           make_tensor_coordinate_step(src_descs[i], forward_step));
                });

                static_for<0, nDst, 1>{}([&](auto i) {
                    move_tensor_coordinate(dst_descs[i],
                                           dst_coords_(i),
                                           make_tensor_coordinate_step(dst_descs[i], forward_step));
                });
            }
        });

        // move coordinate back to slice origin (or not)
        static_for<0, nSrc, 1>{}([&](auto i) {
            if constexpr(SrcResetCoordinateAfterRunFlags::At(i))
            {
                const auto src_reset_step =
                    make_tensor_coordinate_step(src_descs[i], GetCoordinateResetStep());

                move_tensor_coordinate(src_descs[i], src_coords_(i), src_reset_step);
            }
        });

        static_for<0, nDst, 1>{}([&](auto i) {
            if constexpr(DstResetCoordinateAfterRunFlags::At(i))
            {
                const auto dst_reset_step =
                    make_tensor_coordinate_step(dst_descs[i], GetCoordinateResetStep());

                move_tensor_coordinate(dst_descs[i], dst_coords_(i), dst_reset_step);
            }
        });
    }

    __device__ static constexpr auto GetCoordinateResetStep()
    {
        constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();

        if constexpr(num_access == 0)
        {
            return typename SpaceFillingCurve::Index{};
        }
        else
        {
            constexpr auto reset_step =
                SpaceFillingCurve::GetStepBetween(Number<num_access - 1>{}, Number<0>{});

            return reset_step;
        }
    }

    // src_slice_origin_step_idx needs to be known at compile-time, for performance reasons
    template <index_t ISrc>
    __device__ void MoveSrcSliceWindow(const SrcDescs& src_descs,
                                       Number<ISrc> iSrc,
                                       const Index& src_slice_origin_step_idx)
    {
        // if src coord was not reset by RunRead(), then the step needs to be adjusted here
        const auto adjusted_step_idx = SrcResetCoordinateAfterRunFlags::At(iSrc)
                                           ? src_slice_origin_step_idx
                                           : src_slice_origin_step_idx + GetCoordinateResetStep();

        // is it OK to construct a new step every time?
        const auto adjusted_step = make_tensor_coordinate_step(src_descs[iSrc], adjusted_step_idx);

        move_tensor_coordinate(src_descs[iSrc], src_coords_(iSrc), adjusted_step);
    }

    // dst_slice_origin_step_idx needs to be known at compile-time, for performance reasons
    template <index_t IDst>
    __device__ void MoveDstSliceWindow(const DstDescs& dst_descs,
                                       Number<IDst> iDst,
                                       const Index& dst_slice_origin_step_idx)
    {
        // if dst coord was not reset by Run(), then the step needs to be adjusted here
        const auto adjusted_step_idx = DstResetCoordinateAfterRunFlags::At(iDst)
                                           ? dst_slice_origin_step_idx
                                           : dst_slice_origin_step_idx + GetCoordinateResetStep();

        // is it OK to construct a new step every time?
        const auto adjusted_step = make_tensor_coordinate_step(dst_descs[iDst], adjusted_step_idx);

        move_tensor_coordinate(dst_descs[iDst], dst_coords_(iDst), adjusted_step);
    }

    private:
    SrcCoords src_coords_;
    DstCoords dst_coords_;
    const ElementwiseOperation element_op_;
};

} // namespace ck
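As the signature comment inside Run() notes, ElementwiseOperation is invoked as element_op_(dst_refs..., src_refs...), destination references first. A minimal operation matching that convention for one destination and two sources could look like the sketch below; it is illustrative only, not part of this commit, and the name AddTwoSources is hypothetical.

// Illustrative elementwise operation: d = s0 + s1, called as element_op_(d, s0, s1).
struct AddTwoSources
{
    template <typename D, typename S0, typename S1>
    __host__ __device__ void operator()(D& d, const S0& s0, const S1& s1) const
    {
        d = static_cast<D>(s0) + static_cast<D>(s1);
    }
};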
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp

-#ifndef CK_XDLOPS_GEMM_HPP
-#define CK_XDLOPS_GEMM_HPP
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once

-#include "common_header.hpp"
-#include "math.hpp"
-#include "amd_xdlops.hpp"
+#include "ck/utility/common_header.hpp"
+#include "ck/utility/math.hpp"
+#include "ck/utility/amd_xdlops.hpp"

 namespace ck {
 ...

@@ -786,4 +788,3 @@ struct XdlopsGemm
 };
 } // namespace ck
-#endif
include/ck/utility/amd_address_space.hpp

-#ifndef CK_AMD_ADDRESS_SPACE_HPP
-#define CK_AMD_ADDRESS_SPACE_HPP
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once

-#include "config.hpp"
+#include "ck/ck.hpp"
 #include "c_style_pointer_cast.hpp"

 // Address Space for AMDGCN
 ...

@@ -41,4 +43,3 @@ __host__ __device__ T CK_CONSTANT_ADDRESS_SPACE* cast_pointer_to_constant_addres
 }
 } // namespace ck
-#endif
include/ck/utility/amd_buffer_addressing.hpp

+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
 #pragma once

 #include "data_type.hpp"
 ...

@@ -6,6 +9,8 @@ namespace ck {
 template <typename T>
 union BufferResource
 {
+    __device__ constexpr BufferResource() : content{} {}
+
     // 128 bit SGPRs to supply buffer resource in buffer instructions
     // https://rocm-documentation.readthedocs.io/en/latest/GCN_ISA_Manuals/testdocbook.html#vector-memory-buffer-instructions
     int32x4_t content;
 ...
include/ck/utility/amd_inline_asm.hpp

+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
 #ifndef CK_AMD_INLINE_ASM_HPP
 #define CK_AMD_INLINE_ASM_HPP
 ...
include/ck/utility/amd_llvm_intrinsic.hpp

+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
 #ifndef CK_AMD_LLVM_INTRINSIC_HPP
 #define CK_AMD_LLVM_INTRINSIC_HPP
 ...