gaoqiong / composable_kernel / Commits / 8c4e33f1

Commit 8c4e33f1, authored Nov 15, 2021 by Chao Liu

    Merge remote-tracking branch 'origin/develop' into v5r1_add

Parents: 5aed38d4, 3737bb03

Changes: 36. Showing 20 changed files with 1786 additions and 390 deletions (+1786 -390).
Changed files:

    composable_kernel/include/tensor_description/multi_index_transform.hpp                 +4    -2
    composable_kernel/include/tensor_description/static_tensor.hpp                         +265  -0
    composable_kernel/include/tensor_description/tensor_adaptor.hpp                        +14   -0
    composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp                   +4    -1
    composable_kernel/include/tensor_operation/blockwise_tensor_slice_transfer.hpp         +17   -17
    composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v3r2.hpp   +802  -0
    composable_kernel/include/tensor_operation/xdlops_gemm.hpp                             +85   -146
    composable_kernel/include/utility/amd_buffer_addressing.hpp                            +100  -3
    composable_kernel/include/utility/amd_xdlops.hpp                                       +87   -105
    composable_kernel/include/utility/common_header.hpp                                    +4    -0
    composable_kernel/include/utility/config.hpp                                           +6    -1
    composable_kernel/include/utility/container_helper.hpp                                 +0    -13
    composable_kernel/include/utility/data_type.hpp                                        +12   -0
    composable_kernel/include/utility/ignore.hpp                                           +21   -0
    composable_kernel/include/utility/is_known_at_compile_time.hpp                         +49   -0
    composable_kernel/include/utility/static_buffer.hpp                                    +92   -94
    composable_kernel/include/utility/static_buffer_of_vector_type_v2.hpp                  +100  -0
    composable_kernel/include/utility/statically_indexed_array.hpp                         +26   -8
    composable_kernel/include/utility/transpose_vectors.hpp                                +87   -0
    composable_kernel/include/utility/tuple.hpp                                            +11   -0
composable_kernel/include/tensor_description/multi_index_transform.hpp  (+4 -2)

@@ -30,7 +30,8 @@ struct PassThrough
     __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }

     template <typename LowIdx, typename UpIdx>
-    __host__ __device__ static void CalculateLowerIndex(LowIdx& idx_low, const UpIdx& idx_up)
+    __host__ __device__ static constexpr void CalculateLowerIndex(LowIdx& idx_low,
+                                                                  const UpIdx& idx_up)
     {
         static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1,
                       "wrong! inconsistent # of dimension");

@@ -1708,7 +1709,8 @@ struct Vectorize
     __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }

     template <typename LowIdx, typename UpIdx>
-    __host__ __device__ void CalculateLowerIndex(LowIdx& idx_low, const UpIdx& idx_up) const
+    __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low,
+                                                           const UpIdx& idx_up) const
     {
         static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1,
                       "wrong! inconsistent # of dimension");
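The only change in both hunks is marking CalculateLowerIndex as constexpr, so the index mapping can fold into compile-time coordinate math. The following sketch is not part of the commit; it assumes the make_pass_through_transform and make_multi_index helpers that appear elsewhere in this diff.

// Hypothetical sketch only: exercise PassThrough::CalculateLowerIndex directly.
__host__ __device__ void pass_through_index_sketch()
{
    using namespace ck;

    // PassThrough transform over a dimension of length 8 (helper used elsewhere in this commit).
    constexpr auto pass = make_pass_through_transform(Number<8>{});

    auto idx_low      = make_multi_index(0); // 1-d lower index, written below
    const auto idx_up = make_multi_index(3); // 1-d upper index

    // Being static constexpr now, this call can also participate in constexpr coordinate math.
    decltype(pass)::CalculateLowerIndex(idx_low, idx_up); // PassThrough: idx_low becomes (3)
}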
composable_kernel/include/tensor_description/static_tensor.hpp  (new file, mode 100644, +265 -0)

#ifndef CK_STATIC_TENSOR_HPP
#define CK_STATIC_TENSOR_HPP

#include "ignore.hpp"

namespace ck {

// StaticTensor for Scalar
template <AddressSpaceEnum_t AddressSpace,
          typename T,
          typename TensorDesc,
          bool InvalidElementUseNumericalZeroValue,
          typename enable_if<TensorDesc::IsKnownAtCompileTime(), bool>::type = false>
struct StaticTensor
{
    static constexpr auto desc_                  = TensorDesc{};
    static constexpr index_t ndim_               = TensorDesc::GetNumOfDimension();
    static constexpr index_t element_space_size_ = desc_.GetElementSpaceSize();

    __host__ __device__ constexpr StaticTensor() : invalid_element_value_{0} {}

    __host__ __device__ constexpr StaticTensor(T invalid_element_value)
        : invalid_element_value_{invalid_element_value}
    {
    }

    // read access
    template <typename Idx,
              typename enable_if<is_known_at_compile_time<Idx>::value && Idx::Size() == ndim_,
                                 bool>::type = false>
    __host__ __device__ constexpr const T& operator[](Idx) const
    {
        constexpr auto coord     = make_tensor_coordinate(desc_, to_multi_index(Idx{}));
        constexpr index_t offset = coord.GetOffset();
        constexpr bool is_valid  = coordinate_has_valid_offset(desc_, coord);

        if constexpr(is_valid)
        {
            return data_[Number<offset>{}];
        }
        else
        {
            if constexpr(InvalidElementUseNumericalZeroValue)
            {
                return T{0};
            }
            else
            {
                return invalid_element_value_;
            }
        }
    }

    // write access
    template <typename Idx,
              typename enable_if<is_known_at_compile_time<Idx>::value && Idx::Size() == ndim_,
                                 bool>::type = false>
    __host__ __device__ constexpr T& operator()(Idx)
    {
        constexpr auto coord     = make_tensor_coordinate(desc_, to_multi_index(Idx{}));
        constexpr index_t offset = coord.GetOffset();
        constexpr bool is_valid  = coordinate_has_valid_offset(desc_, coord);

        if constexpr(is_valid)
        {
            return data_(Number<offset>{});
        }
        else
        {
            return ignore;
        }
    }

    StaticBuffer<AddressSpace, T, element_space_size_, true> data_;

    T invalid_element_value_ = T{0};
};

// StaticTensor for vector
template <AddressSpaceEnum_t AddressSpace,
          typename S,
          index_t ScalarPerVector,
          typename TensorDesc,
          bool InvalidElementUseNumericalZeroValue,
          typename enable_if<TensorDesc::IsKnownAtCompileTime(), bool>::type = false>
struct StaticTensorTupleOfVectorBuffer
{
    static constexpr auto desc_                  = TensorDesc{};
    static constexpr index_t ndim_               = TensorDesc::GetNumOfDimension();
    static constexpr index_t element_space_size_ = desc_.GetElementSpaceSize();

    static constexpr index_t num_of_vector_ =
        math::integer_divide_ceil(element_space_size_, ScalarPerVector);

    using V = vector_type<S, ScalarPerVector>;

    __host__ __device__ constexpr StaticTensorTupleOfVectorBuffer() : invalid_element_value_{0} {}

    __host__ __device__ constexpr StaticTensorTupleOfVectorBuffer(S invalid_element_value)
        : invalid_element_value_{invalid_element_value}
    {
    }

    // Get S
    // Idx is for S, not V
    template <typename Idx,
              typename enable_if<is_known_at_compile_time<Idx>::value && Idx::Size() == ndim_,
                                 bool>::type = false>
    __host__ __device__ constexpr const S& operator[](Idx) const
    {
        constexpr auto coord     = make_tensor_coordinate(desc_, to_multi_index(Idx{}));
        constexpr index_t offset = coord.GetOffset();
        constexpr bool is_valid  = coordinate_has_valid_offset(desc_, coord);

        if constexpr(is_valid)
        {
            return data_[Number<offset>{}];
        }
        else
        {
            if constexpr(InvalidElementUseNumericalZeroValue)
            {
                return S{0};
            }
            else
            {
                return invalid_element_value_;
            }
        }
    }

    // Set S
    // Idx is for S, not V
    template <typename Idx,
              typename enable_if<is_known_at_compile_time<Idx>::value && Idx::Size() == ndim_,
                                 bool>::type = false>
    __host__ __device__ constexpr S& operator()(Idx)
    {
        constexpr auto coord     = make_tensor_coordinate(desc_, to_multi_index(Idx{}));
        constexpr index_t offset = coord.GetOffset();
        constexpr bool is_valid  = coordinate_has_valid_offset(desc_, coord);

        if constexpr(is_valid)
        {
            return data_(Number<offset>{});
        }
        else
        {
            return ignore;
        }
    }

    // Get X
    // Idx is for S, not X. Idx should be aligned with X
    template <typename X,
              typename Idx,
              typename enable_if<has_same_scalar_type<S, X>::value &&
                                     is_known_at_compile_time<Idx>::value && Idx::Size() == ndim_,
                                 bool>::type = false>
    __host__ __device__ constexpr X GetAsType(Idx) const
    {
        constexpr auto coord     = make_tensor_coordinate(desc_, to_multi_index(Idx{}));
        constexpr index_t offset = coord.GetOffset();
        constexpr bool is_valid  = coordinate_has_valid_offset(desc_, coord);

        if constexpr(is_valid)
        {
            return data_.template GetAsType<X>(Number<offset>{});
        }
        else
        {
            if constexpr(InvalidElementUseNumericalZeroValue)
            {
                // TODO: is this right way to initialize a vector?
                return X{0};
            }
            else
            {
                // TODO: is this right way to initialize a vector?
                return X{invalid_element_value_};
            }
        }
    }

    // Set X
    // Idx is for S, not X. Idx should be aligned with X
    template <typename X,
              typename Idx,
              typename enable_if<has_same_scalar_type<S, X>::value &&
                                     is_known_at_compile_time<Idx>::value && Idx::Size() == ndim_,
                                 bool>::type = false>
    __host__ __device__ constexpr void SetAsType(Idx, X x)
    {
        constexpr auto coord     = make_tensor_coordinate(desc_, to_multi_index(Idx{}));
        constexpr index_t offset = coord.GetOffset();
        constexpr bool is_valid  = coordinate_has_valid_offset(desc_, coord);

        if constexpr(is_valid)
        {
            data_.template SetAsType<X>(Number<offset>{}, x);
        }
    }

    // Get read access to V. No is_valid check
    // Idx is for S, not V. Idx should be aligned with V
    template <typename Idx>
    __host__ __device__ constexpr const V& GetVectorTypeReference(Idx) const
    {
        constexpr auto coord     = make_tensor_coordinate(desc_, to_multi_index(Idx{}));
        constexpr index_t offset = coord.GetOffset();

        return data_.GetVectorTypeReference(Number<offset>{});
    }

    // Get read access to V. No is_valid check
    // Idx is for S, not V. Idx should be aligned with V
    template <typename Idx>
    __host__ __device__ constexpr V& GetVectorTypeReference(Idx)
    {
        constexpr auto coord     = make_tensor_coordinate(desc_, to_multi_index(Idx{}));
        constexpr index_t offset = coord.GetOffset();

        return data_.GetVectorTypeReference(Number<offset>{});
    }

    StaticBufferTupleOfVector<AddressSpace, S, num_of_vector_, ScalarPerVector, true> data_;

    S invalid_element_value_ = S{0};
};

template <AddressSpaceEnum_t AddressSpace,
          typename T,
          typename TensorDesc,
          typename enable_if<TensorDesc::IsKnownAtCompileTime(), bool>::type = false>
__host__ __device__ constexpr auto make_static_tensor(TensorDesc)
{
    return StaticTensor<AddressSpace, T, TensorDesc, true>{};
}

template <AddressSpaceEnum_t AddressSpace,
          typename T,
          typename TensorDesc,
          typename X,
          typename enable_if<TensorDesc::IsKnownAtCompileTime(), bool>::type = false,
          typename enable_if<is_same<remove_cvref_t<T>, remove_cvref_t<X>>::value, bool>::type =
              false>
__host__ __device__ constexpr auto make_static_tensor(TensorDesc, X invalid_element_value)
{
    return StaticTensor<AddressSpace, T, TensorDesc, true>{invalid_element_value};
}

} // namespace ck
#endif
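For orientation, here is a small usage sketch of the scalar StaticTensor added above. It is not part of the commit; the descriptor helper and index form are ones used elsewhere in this diff, and real call sites (for example the v3r2 transfer below) index the vector variant through Sequences and GetAsType/SetAsType instead.

// Hypothetical sketch only: a 2x4 compile-time tensor held in VGPRs.
__device__ void static_tensor_usage_sketch()
{
    using namespace ck;

    constexpr auto desc =
        make_naive_tensor_descriptor_packed(make_tuple(Number<2>{}, Number<4>{}));

    // InvalidElementUseNumericalZeroValue is set to true by make_static_tensor.
    auto t = make_static_tensor<AddressSpaceEnum_t::Vgpr, float>(desc);

    t(make_tuple(Number<1>{}, Number<2>{})) = 1.0f;    // write access, operator()
    float x = t[make_tuple(Number<1>{}, Number<2>{})]; // read access, operator[]
    (void)x;
}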
composable_kernel/include/tensor_description/tensor_adaptor.hpp  (+14 -0)

@@ -151,6 +151,20 @@ struct TensorAdaptor
     __host__ __device__ constexpr auto GetElementSize() const { return element_size_; }

+#if 0 // debug
+    template <index_t I>
+    __host__ __device__ constexpr index_t GetTopDimensionLength(Number<I> idim) const
+    {
+        // TODO: not implemented
+    }
+
+    template <index_t I>
+    __host__ __device__ constexpr index_t GetBottomDimensionLength(Number<I> idim) const
+    {
+        // TODO: not implemented
+    }
+#endif
+
     template <typename TopIdx>
     __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
     {
composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp  (+4 -1)

@@ -37,7 +37,10 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
     static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
     static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);

-    StaticBufferV2<AddressSpaceEnum_t::Vgpr, vector_type<FloatAcc, 16>, MRepeat * NRepeat, true>
+    StaticBufferOfVectorTypeV2<AddressSpaceEnum_t::Vgpr,
+                               vector_type<FloatAcc, 16>,
+                               MRepeat * NRepeat,
+                               true>
         c_thread_buf_;

     __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; }
composable_kernel/include/tensor_operation/blockwise_tensor_slice_transfer.hpp  (+17 -17)

@@ -5,7 +5,7 @@
 #include "tensor_descriptor.hpp"
 #include "tensor_descriptor_helper.hpp"
 #include "cluster_descriptor.hpp"
-#include "threadwise_tensor_slice_transfer.hpp"
+#include "threadwise_tensor_slice_transfer_v3r2.hpp"

 namespace ck {

@@ -146,7 +146,7 @@ struct BlockwiseTensorSliceTransfer_v4
         make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});

     using ThreadwiseTransfer =
-        ThreadwiseTensorSliceTransfer_v3<ThreadSliceLengths,
+        ThreadwiseTensorSliceTransfer_v3r2<ThreadSliceLengths,
                                          DstInMemOp,
                                          SrcData,
                                          DstData,
composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v3r2.hpp  (new file, mode 100644, +802 -0)

#ifndef CK_THREADWISE_TENSOR_SLICE_TRANSFER_V3R2_HPP
#define CK_THREADWISE_TENSOR_SLICE_TRANSFER_V3R2_HPP

#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "static_tensor.hpp"

namespace ck {
namespace detail {

// TODO: How to fix this? It uses an struct instead of lambda because lambda
// doesn't have constructor
template <index_t SrcVectorDim,
          index_t SrcScalarPerVector,
          index_t DstVectorDim,
          index_t DstScalarPerVector>
struct lambda_scalar_per_access_for_src_and_dst
{
    __host__ __device__ constexpr auto operator()(index_t i) const
    {
        if(i == SrcVectorDim && i == DstVectorDim)
        {
            return math::lcm(SrcScalarPerVector, DstScalarPerVector);
        }
        else if(i == SrcVectorDim)
        {
            return SrcScalarPerVector;
        }
        else if(i == DstVectorDim)
        {
            return DstScalarPerVector;
        }
        else
        {
            return 1;
        }
    }
};

} // namespace detail

// Assume:
//   1. src_desc and dst_desc are not known at compile-time
//   2. SrcBuffer and DstBuffer are DynamicBuffer
//   3. src_slice_origin and dst_slice_origin are not known at compile-time,
//   4. Use thread buffer
template <typename SliceLengths,
          InMemoryDataOperationEnum_t DstInMemOp,
          typename SrcData,
          typename DstData,
          typename SrcDesc,
          typename DstDesc,
          typename SrcDimAccessOrder,
          typename DstDimAccessOrder,
          index_t SrcVectorDim,
          index_t DstVectorDim,
          index_t SrcScalarPerVector,
          index_t DstScalarPerVector,
          index_t SrcScalarStrideInVector,
          index_t DstScalarStrideInVector,
          bool SrcResetCoordinateAfterRun, // control whether to move back src coordinate after each
                                           // RunRead(), will be fused with MoveSrcSliceWindow to
                                           // save addr computation
          bool DstResetCoordinateAfterRun> // control whether to move back dst coordinate after each
                                           // RunWrite(), will be fused with MoveDstSliceWindow to
                                           // save addr computation
struct ThreadwiseTensorSliceTransfer_v3r2
{
    static constexpr index_t nDim = SliceLengths::Size();

    using Index = MultiIndex<nDim>;

    using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
    using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{}));

    using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));
    using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{}));

    __device__ constexpr ThreadwiseTensorSliceTransfer_v3r2(const SrcDesc& src_desc,
                                                            const Index& src_slice_origin,
                                                            const DstDesc& dst_desc,
                                                            const Index& dst_slice_origin)
        : src_coord_(make_tensor_coordinate(src_desc, src_slice_origin)),
          dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin))
    {
    }

    __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
    {
        src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_idx);
    }

    __device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx)
    {
        dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx);
    }

    template <typename SrcBuffer, typename SrcStepHacks>
    __device__ void
    RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf, const SrcStepHacks& src_step_hacks)
    {
        static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Global or
                          SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds,
                      "wrong!");

        static_assert(
            is_same<remove_cvref_t<typename SrcBuffer::type>, remove_cvref_t<SrcData>>::value,
            "wrong! SrcBuffer and SrcData data type are inconsistent");

        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        // scalar per access on each dim
        // TODO: don't use lambda_scalar_per_access
        constexpr auto src_scalar_per_access = generate_sequence(
            detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});

        constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;

        constexpr auto src_dim_access_order = SrcDimAccessOrder{};

        constexpr auto ordered_src_access_lengths =
            container_reorder_given_new2old(src_access_lengths, src_dim_access_order);

        // make forward steps
        const auto src_forward_steps = generate_tuple(
            [&](auto i) {
                Index forward_step_idx;

                static_for<0, nDim, 1>{}([&](auto j) {
                    forward_step_idx(j) = (i.value == j.value) ? src_scalar_per_access[i] : 0;
                });

                return make_tensor_coordinate_step(src_desc, forward_step_idx, src_step_hacks[I0][i]);
            },
            Number<nDim>{});

        // make backward steps
        const auto src_backward_steps = generate_tuple(
            [&](auto i) {
                Index backward_step_idx;

                static_for<0, nDim, 1>{}([&](auto j) {
                    backward_step_idx(j) = (i.value == j.value) ? -src_scalar_per_access[i] : 0;
                });

                return make_tensor_coordinate_step(src_desc, backward_step_idx, src_step_hacks[I1][i]);
            },
            Number<nDim>{});

        // loop over tensor and copy
        static_ford<decltype(ordered_src_access_lengths)>{}([&](auto ordered_src_access_idx) {
            // judge move forward or move backward
            constexpr auto forward_sweep = [&]() {
                StaticallyIndexedArray<bool, nDim> forward_sweep_;

                forward_sweep_(I0) = true;

                static_for<1, nDim, 1>{}([&](auto i) {
                    index_t tmp = ordered_src_access_idx[I0];

                    static_for<0, i, 1>{}([&](auto j) {
                        tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_idx[j];
                    });

                    forward_sweep_(i) = tmp % 2 == 0;
                });

                return forward_sweep_;
            }();

            // calculate src data index
            constexpr auto src_data_idx = [&]() {
                Index ordered_idx;

                static_for<0, nDim, 1>{}([&](auto i) {
                    ordered_idx(i) = forward_sweep[i]
                                         ? ordered_src_access_idx[i]
                                         : ordered_src_access_lengths[i] - 1 - ordered_src_access_idx[i];
                });

                return container_reorder_given_old2new(ordered_idx, src_dim_access_order) *
                       src_scalar_per_access;
            }();

            constexpr auto src_data_idx_seq = generate_sequence_v2(
                [&](auto i) { return Number<src_data_idx[i]>{}; }, Number<src_data_idx.Size()>{});

            const bool is_src_valid =
                coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_);

            using src_vector_t = typename vector_type_maker_t<SrcData, SrcScalarPerVector>::type;

            // copy data from src_buf to src_thread_scratch_
            src_thread_scratch_.template SetAsType<src_vector_t>(
                src_data_idx_seq,
                src_buf.template Get<src_vector_t>(src_coord_.GetOffset(), is_src_valid));

            constexpr auto move_on_dim = [&]() constexpr
            {
                StaticallyIndexedArray<bool, nDim> move_on_dim_;

                static_for<0, nDim, 1>{}([&](auto i) {
                    move_on_dim_(i) = ordered_src_access_idx[i] < ordered_src_access_lengths[i] - 1;

                    static_for<i + 1, nDim, 1>{}([&](auto j) {
                        move_on_dim_(i) &=
                            ordered_src_access_idx[j] == ordered_src_access_lengths[j] - 1;
                    });
                });

                return move_on_dim_;
            }
            ();

            // move src coord
            static_for<0, nDim, 1>{}([&](auto i) {
                if constexpr(move_on_dim[i])
                {
                    if constexpr(forward_sweep[i])
                    {
                        move_tensor_coordinate(
                            src_desc, src_coord_, src_forward_steps[src_dim_access_order[i]]);
                    }
                    else
                    {
                        move_tensor_coordinate(
                            src_desc, src_coord_, src_backward_steps[src_dim_access_order[i]]);
                    }
                }
            });
        });

        // move src coordinate back to slice origin (or not)
        if constexpr(SrcResetCoordinateAfterRun)
        {
            const auto src_reset_step =
                make_tensor_coordinate_step(src_desc, GetSrcCoordinateResetStep());

            move_tensor_coordinate(src_desc, src_coord_, src_reset_step);
        }
    }

    __device__ void TransferDataFromSrcThreadScratchToDstThreadScratch()
    {
#if !CK_EXPERIMENTAL_USE_IN_REGISTER_SUB_DWORD_TRANSPOSE
        static_ford<SliceLengths>{}([&](auto idx) {
            // convert from SrcData to DstData here
            dst_thread_scratch_(idx) = type_convert<DstData>{}(src_thread_scratch_[idx]);
        });
#else
        // sub-dword transpose between src_thread_scratch_ and dst_thread_scratch_
        // TODO make this logic more generic for more sub-dword datatype
        if constexpr(SrcVectorDim != DstVectorDim &&
                     is_same<half_t, remove_cvref_t<SrcData>>::value &&
                     is_same<half_t, remove_cvref_t<DstData>>::value &&
                     SrcScalarPerVector % 2 == 0 && DstScalarPerVector % 2 == 0)
        {
            // each transpose does
            //   DstScalarPerVector # of src vectors in src_thread_scratch_
            //   SrcScalarPerVector # of dst vectors in dst_thread_scratch_
            constexpr index_t num_src_vector = Number<DstScalarPerVector>{};
            constexpr index_t num_dst_vector = Number<SrcScalarPerVector>{};

            // Assume SrcVectorDim is not the same as DstVectorDim, so we do transpose
            // TODO: make this logic generic for all scenario
            static_assert(SrcVectorDim != DstVectorDim, "wrong");

            constexpr auto src_scalar_step_in_vector = generate_sequence(
                detail::lambda_scalar_step_in_vector<SrcVectorDim>{}, Number<nDim>{});

            constexpr auto dst_scalar_step_in_vector = generate_sequence(
                detail::lambda_scalar_step_in_vector<DstVectorDim>{}, Number<nDim>{});

            constexpr auto scalar_per_access = generate_sequence(
                detail::lambda_scalar_per_access_for_src_and_dst<SrcVectorDim,
                                                                 SrcScalarPerVector,
                                                                 DstVectorDim,
                                                                 DstScalarPerVector>{},
                Number<nDim>{});

            constexpr auto access_lengths = SliceLengths{} / scalar_per_access;

            static_ford<decltype(access_lengths)>{}([&](auto access_idx) {
                constexpr auto data_idx = access_idx * scalar_per_access;

                constexpr auto data_idx_seq = generate_sequence_v2(
                    [&](auto i) { return Number<data_idx[i]>{}; }, Number<nDim>{});

                // TODO type_convert is not used yet!!!!!
                using src_vector_t = vector_type_maker_t<SrcData, SrcScalarPerVector>;
                using dst_vector_t = vector_type_maker_t<DstData, DstScalarPerVector>;

                // get DstScalarPerVector # of read-only references to src vectors from
                // src_thread_scratch_
                const auto src_vector_refs = generate_tie(
                    [&](auto i) -> const src_vector_t& {
                        // i increment corresponds to movement in DstVectorDim
                        return src_thread_scratch_.GetVectorTypeReference(
                            data_idx_seq + i * dst_scalar_step_in_vector);
                    },
                    Number<num_src_vector>{});

                // get SrcScalarPerVector # of references to dst vectors from dst_thread_scratch_
                auto dst_vector_refs = generate_tie(
                    [&](auto i) -> dst_vector_t& {
                        // i increment corresponds to movement in SrcVectorDim
                        return dst_thread_scratch_.GetVectorTypeReference(
                            data_idx_seq + i * src_scalar_step_in_vector);
                    },
                    Number<num_dst_vector>{});

                // do data transpose
                // TODO type_convert is not used yet!!!!!
                transpose_vectors<SrcData, DstScalarPerVector, SrcScalarPerVector>{}(
                    src_vector_refs, dst_vector_refs);
            });
        }
        else
        {
            static_ford<SliceLengths>{}([&](auto idx) {
                // convert from SrcData to DstData here
                dst_thread_scratch_(idx) = type_convert<DstData>{}(src_thread_scratch_[idx]);
            });
        }
#endif
    }

    template <typename DstBuffer, typename DstStepHacks>
    __device__ void
    RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf, const DstStepHacks& dst_step_hacks)
    {
        // if there is transpose, it's done here
        // TODO move this elsewhere
        TransferDataFromSrcThreadScratchToDstThreadScratch();

        static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Global or
                          DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds,
                      "wrong!");

        static_assert(
            is_same<remove_cvref_t<typename DstBuffer::type>, remove_cvref_t<DstData>>::value,
            "wrong! SrcBuffer or DstBuffer data type is wrong");

        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        // src scalar per access on each dim
        // TODO: don't use this
        constexpr auto dst_scalar_per_access = generate_sequence(
            detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});

        constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;

        constexpr auto dst_dim_access_order = DstDimAccessOrder{};

        constexpr auto ordered_dst_access_lengths =
            container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order);

        // make forward steps
        const auto dst_forward_steps = generate_tuple(
            [&](auto i) {
                Index forward_step_idx;

                static_for<0, nDim, 1>{}([&](auto j) {
                    forward_step_idx(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0;
                });

                return make_tensor_coordinate_step(dst_desc, forward_step_idx, dst_step_hacks[I0][i]);
            },
            Number<nDim>{});

        // make backward steps
        const auto dst_backward_steps = generate_tuple(
            [&](auto i) {
                Index backward_step_idx;

                static_for<0, nDim, 1>{}([&](auto j) {
                    backward_step_idx(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0;
                });

                return make_tensor_coordinate_step(dst_desc, backward_step_idx, dst_step_hacks[I1][i]);
            },
            Number<nDim>{});

        // loop over tensor and copy
        static_ford<decltype(ordered_dst_access_lengths)>{}([&](auto ordered_dst_access_idx) {
            // judge move forward or move backward
            constexpr auto forward_sweep = [&]() {
                StaticallyIndexedArray<bool, nDim> forward_sweep_;

                forward_sweep_(I0) = true;

                static_for<1, nDim, 1>{}([&](auto i) {
                    index_t tmp = ordered_dst_access_idx[I0];

                    static_for<0, i, 1>{}([&](auto j) {
                        tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_idx[j];
                    });

                    forward_sweep_(i) = tmp % 2 == 0;
                });

                return forward_sweep_;
            }();

            // calculate dst data index
            constexpr auto dst_data_idx = [&]() {
                Index ordered_idx;

                static_for<0, nDim, 1>{}([&](auto i) {
                    ordered_idx(i) = forward_sweep[i]
                                         ? ordered_dst_access_idx[i]
                                         : ordered_dst_access_lengths[i] - 1 - ordered_dst_access_idx[i];
                });

                return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) *
                       dst_scalar_per_access;
            }();

            constexpr auto dst_data_idx_seq = generate_sequence_v2(
                [&](auto i) { return Number<dst_data_idx[i]>{}; }, Number<dst_data_idx.Size()>{});

            const bool is_dst_valid =
                coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_);

            using dst_vector_t = typename vector_type_maker_t<DstData, DstScalarPerVector>::type;

            // copy data from dst_thread_scratch_ to dst_buf
            dst_buf.template Set<dst_vector_t>(
                dst_coord_.GetOffset(),
                is_dst_valid,
                dst_thread_scratch_.template GetAsType<dst_vector_t>(dst_data_idx_seq));

            constexpr auto move_on_dim = [&]() constexpr
            {
                StaticallyIndexedArray<bool, nDim> move_on_dim_;

                static_for<0, nDim, 1>{}([&](auto i) {
                    move_on_dim_(i) = ordered_dst_access_idx[i] < ordered_dst_access_lengths[i] - 1;

                    static_for<i + 1, nDim, 1>{}([&](auto j) {
                        move_on_dim_(i) &=
                            ordered_dst_access_idx[j] == ordered_dst_access_lengths[j] - 1;
                    });
                });

                return move_on_dim_;
            }
            ();

            // move dst coord
            static_for<0, nDim, 1>{}([&](auto i) {
                if constexpr(move_on_dim[i])
                {
                    if constexpr(forward_sweep[i])
                    {
                        move_tensor_coordinate(
                            dst_desc, dst_coord_, dst_forward_steps[dst_dim_access_order[i]]);
                    }
                    else
                    {
                        move_tensor_coordinate(
                            dst_desc, dst_coord_, dst_backward_steps[dst_dim_access_order[i]]);
                    }
                }
            });
        });

        // move dst coordinate back to slice origin (or not)
        if constexpr(DstResetCoordinateAfterRun)
        {
            const auto dst_reset_step =
                make_tensor_coordinate_step(dst_desc, GetDstCoordinateResetStep());

            move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step);
        }
    }

    template <typename SrcBuffer>
    __device__ void RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf)
    {
        constexpr index_t ntransform_src = SrcDesc::GetNumOfTransform();

        constexpr auto zeros = typename uniform_sequence_gen<ntransform_src, 0>::type{};

        constexpr auto src_step_hacks =
            make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),
                       generate_tuple([&](auto) { return zeros; }, Number<nDim>{}));

        RunRead(src_desc, src_buf, src_step_hacks);
    }

    template <typename DstBuffer>
    __device__ void RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf)
    {
        constexpr index_t ntransform_dst = DstDesc::GetNumOfTransform();

        constexpr auto zeros = typename uniform_sequence_gen<ntransform_dst, 0>::type{};

        constexpr auto dst_step_hacks =
            make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),
                       generate_tuple([&](auto) { return zeros; }, Number<nDim>{}));

        RunWrite(dst_desc, dst_buf, dst_step_hacks);
    }

    __device__ static constexpr auto GetSrcCoordinateResetStep()
    {
        constexpr auto I0 = Number<0>{};

        // scalar per access on each dim
        // TODO: don't use lambda_scalar_per_access
        constexpr auto src_scalar_per_access = generate_sequence(
            detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});

        constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;

        constexpr auto src_dim_access_order = SrcDimAccessOrder{};

        constexpr auto ordered_src_access_lengths =
            container_reorder_given_new2old(src_access_lengths, src_dim_access_order);

        // judge move forward or move backward during the last iteration
        constexpr auto forward_sweep = [&]() {
            StaticallyIndexedArray<bool, nDim> forward_sweep_;

            forward_sweep_(I0) = true;

            static_for<1, nDim, 1>{}([&](auto i) {
                index_t tmp = ordered_src_access_lengths[I0] - 1;

                static_for<0, i, 1>{}([&](auto j) {
                    tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_lengths[j] - 1;
                });

                forward_sweep_(i) = tmp % 2 == 0;
            });

            return forward_sweep_;
        }();

        // calculate src data index after last iteration in RunRead(), if it has not being reset by
        // RunRead()
        constexpr auto src_data_idx = [&]() {
            Index ordered_idx;

            static_for<0, nDim, 1>{}([&](auto i) {
                ordered_idx(i) = forward_sweep[i] ? ordered_src_access_lengths[i] - 1 : 0;
            });

            return container_reorder_given_old2new(ordered_idx, src_dim_access_order) *
                   src_scalar_per_access;
        }();

        //
        constexpr auto reset_src_data_step = [&]() {
            Index reset_src_data_step_;

            static_for<0, nDim, 1>{}([&](auto i) { reset_src_data_step_(i) = -src_data_idx[i]; });

            return reset_src_data_step_;
        }();

        return reset_src_data_step;
    }

    __device__ static constexpr auto GetDstCoordinateResetStep()
    {
        constexpr auto I0 = Number<0>{};

        // scalar per access on each dim
        // TODO: don't use lambda_scalar_per_access
        constexpr auto dst_scalar_per_access = generate_sequence(
            detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});

        constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;

        constexpr auto dst_dim_access_order = DstDimAccessOrder{};

        constexpr auto ordered_dst_access_lengths =
            container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order);

        // judge move forward or move backward during the last iteration
        constexpr auto forward_sweep = [&]() {
            StaticallyIndexedArray<bool, nDim> forward_sweep_;

            forward_sweep_(I0) = true;

            static_for<1, nDim, 1>{}([&](auto i) {
                index_t tmp = ordered_dst_access_lengths[I0] - 1;

                static_for<0, i, 1>{}([&](auto j) {
                    tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_lengths[j] - 1;
                });

                forward_sweep_(i) = tmp % 2 == 0;
            });

            return forward_sweep_;
        }();

        // calculate dst data index after last iteration in RunWrite(), if it has not being reset by
        // RunWrite()
        constexpr auto dst_data_idx = [&]() {
            Index ordered_idx;

            static_for<0, nDim, 1>{}([&](auto i) {
                ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_lengths[i] - 1 : 0;
            });

            return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) *
                   dst_scalar_per_access;
        }();

        //
        constexpr auto reset_dst_data_step = [&]() {
            Index reset_dst_data_step_;

            static_for<0, nDim, 1>{}([&](auto i) { reset_dst_data_step_(i) = -dst_data_idx[i]; });

            return reset_dst_data_step_;
        }();

        return reset_dst_data_step;
    }

    // src_slice_origin_step_idx need to be known at compile-time, for performance reason
    __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc,
                                       const Index& src_slice_origin_step_idx)
    {
        // if src coord was not reset by RunRead(), then need to adjust the step here
        const auto adjusted_step_idx =
            SrcResetCoordinateAfterRun ? src_slice_origin_step_idx
                                       : src_slice_origin_step_idx + GetSrcCoordinateResetStep();

        // is it OK to construct a new step every time?
        const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx);

        move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
    }

    // src_slice_origin_step_idx need to be known at compile-time, for performance reason
    template <typename SrcMoveSliceWindowStepHack>
    __device__ void
    MoveSrcSliceWindow(const SrcDesc& src_desc,
                       const Index& src_slice_origin_step_idx,
                       const SrcMoveSliceWindowStepHack& src_move_slice_window_step_hack)
    {
        // if src coord was not reset by RunRead(), then need to adjust the step here
        const auto adjusted_step_idx =
            SrcResetCoordinateAfterRun ? src_slice_origin_step_idx
                                       : src_slice_origin_step_idx + GetSrcCoordinateResetStep();

        // is it OK to construct a new step every time?
        const auto adjusted_step = make_tensor_coordinate_step(
            src_desc, adjusted_step_idx, src_move_slice_window_step_hack);

        move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
    }

    // dst_slice_origin_step_idx need to be known at compile-time, for performance reason
    __device__ void MoveDstSliceWindow(const DstDesc& dst_desc,
                                       const Index& dst_slice_origin_step_idx)
    {
        // if dst coord was not reset by RunWrite(), then need to adjust the step here
        const auto adjusted_step_idx =
            DstResetCoordinateAfterRun ? dst_slice_origin_step_idx
                                       : dst_slice_origin_step_idx + GetDstCoordinateResetStep();

        // is it OK to construct a new step every time?
        const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx);

        move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step);
    }

    __device__ static constexpr auto GetSrcThreadScratchDescriptor()
    {
        constexpr auto src_scalar_per_access = generate_sequence(
            detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});

        constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;

        constexpr auto src_access_lengths_and_vector_length = container_push_back(
            sequence_to_tuple_of_number(src_access_lengths), Number<SrcScalarPerVector>{});

        // 1st stage of transforms
        constexpr auto desc0 =
            make_naive_tensor_descriptor_packed(src_access_lengths_and_vector_length);

        // 2nd stage of transforms
        constexpr auto transforms = generate_tuple(
            [&](auto i) {
                if constexpr(i == SrcVectorDim)
                {
                    return make_merge_transform_v3_division_mod(
                        make_tuple(src_access_lengths_and_vector_length[i],
                                   src_access_lengths_and_vector_length[Number<nDim>{}]));
                }
                else
                {
                    return make_pass_through_transform(src_access_lengths_and_vector_length[i]);
                }
            },
            Number<nDim>{});

        constexpr auto low_dim_idss = generate_tuple(
            [&](auto i) {
                if constexpr(i == SrcVectorDim)
                {
                    return Sequence<i.value, nDim>{};
                }
                else
                {
                    return Sequence<i.value>{};
                }
            },
            Number<nDim>{});

        constexpr auto up_dim_idss =
            generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});

        return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
    }

    __device__ static constexpr auto GetDstThreadScratchDescriptor()
    {
        // 1st stage of transforms
        constexpr auto dst_scalar_per_access = generate_sequence(
            detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});

        constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;

        constexpr auto dst_access_lengths_and_vector_length = container_push_back(
            sequence_to_tuple_of_number(dst_access_lengths), Number<DstScalarPerVector>{});

        constexpr auto desc0 =
            make_naive_tensor_descriptor_packed(dst_access_lengths_and_vector_length);

        // 2nd stage of transforms
        constexpr auto transforms = generate_tuple(
            [&](auto i) {
                if constexpr(i == DstVectorDim)
                {
                    return make_merge_transform_v3_division_mod(
                        make_tuple(dst_access_lengths_and_vector_length[i],
                                   dst_access_lengths_and_vector_length[Number<nDim>{}]));
                }
                else
                {
                    return make_pass_through_transform(dst_access_lengths_and_vector_length[i]);
                }
            },
            Number<nDim>{});

        constexpr auto low_dim_idss = generate_tuple(
            [&](auto i) {
                if constexpr(i == DstVectorDim)
                {
                    return Sequence<i.value, nDim>{};
                }
                else
                {
                    return Sequence<i.value>{};
                }
            },
            Number<nDim>{});

        constexpr auto up_dim_idss =
            generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});

        return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
    }

    private:
    static constexpr auto src_thread_scratch_desc_ = decltype(GetSrcThreadScratchDescriptor()){};
    static constexpr auto dst_thread_scratch_desc_ = decltype(GetDstThreadScratchDescriptor()){};

    StaticTensorTupleOfVectorBuffer<AddressSpaceEnum_t::Vgpr,
                                    SrcData,
                                    SrcScalarPerVector,
                                    decltype(src_thread_scratch_desc_),
                                    true>
        src_thread_scratch_;

    StaticTensorTupleOfVectorBuffer<AddressSpaceEnum_t::Vgpr,
                                    DstData,
                                    DstScalarPerVector,
                                    decltype(dst_thread_scratch_desc_),
                                    true>
        dst_thread_scratch_;

    SrcCoord src_coord_;
    DstCoord dst_coord_;
};

} // namespace ck
#endif
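A short driving sketch follows. It is not part of the commit and its names (function, pointers, tile shape, template choices) are invented; it only mirrors the RunRead/RunWrite/MoveSrcSliceWindow pattern used by the blockwise transfer above, under the assumption that the ck:: descriptor and dynamic-buffer helpers behave as they do elsewhere in this codebase.

// Hypothetical sketch only: copy a 1x8 float slice from global memory to LDS through the
// new v3r2 transfer. p_dst_lds is assumed to point into __shared__ memory.
__device__ void v3r2_usage_sketch(const float* p_src_global, float* p_dst_lds)
{
    using namespace ck;

    constexpr auto desc =
        make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}, Number<8>{}));
    using Desc = remove_cvref_t<decltype(desc)>;

    auto src_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(p_src_global,
                                                                   desc.GetElementSpaceSize());
    auto dst_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(p_dst_lds,
                                                                desc.GetElementSpaceSize());

    auto transfer =
        ThreadwiseTensorSliceTransfer_v3r2<Sequence<1, 8>,                  // SliceLengths
                                           InMemoryDataOperationEnum_t::Set,
                                           float, float,                    // SrcData, DstData
                                           Desc, Desc,                      // SrcDesc, DstDesc
                                           Sequence<0, 1>, Sequence<0, 1>,  // dim access orders
                                           1, 1,                            // Src/DstVectorDim
                                           4, 4,                            // Src/DstScalarPerVector
                                           1, 1,                            // scalar strides in vector
                                           true, true>{                     // reset coords after run
            desc, make_multi_index(0, 0), desc, make_multi_index(0, 0)};

    transfer.RunRead(desc, src_buf);  // global -> src_thread_scratch_ (vectorized loads)
    transfer.RunWrite(desc, dst_buf); // convert/transpose, then dst_thread_scratch_ -> LDS
    // In a K-loop, transfer.MoveSrcSliceWindow(desc, step) would advance the window between reads.
}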
composable_kernel/include/tensor_operation/xdlops_gemm.hpp  (+85 -146)

@@ -12,18 +12,19 @@ enum struct MfmaInstr
     mfma_f32_32x32x1xf32 = 0,
     mfma_f32_16x16x1xf32,
     mfma_f32_4x4x1xf32,
     mfma_f32_32x32x2xf32, // k reduction
     mfma_f32_16x16x4xf32, // k reduction
     mfma_f32_32x32x4f16,
     mfma_f32_16x16x4f16,
     mfma_f32_4x4x4f16,
     mfma_f32_32x32x8f16,  // k reduction
     mfma_f32_16x16x16f16, // k reduction
-    mfma_f32_32x32x2bf16,
-    mfma_f32_16x16x2bf16,
-    mfma_f32_4x4x2bf16,
-    mfma_f32_32x32x4bf16,
-    mfma_f32_16x16x8bf16,
+    mfma_f32_32x32x8bf16_1k,
+    mfma_f32_16x16x16bf16_1k,
+    mfma_f32_32x32x4bf16, // k reduction
+    mfma_f32_16x16x8bf16, // k reduction
+    mfma_i32_32x32x8i8,
+    mfma_i32_16x16x16i8,
 };

 template <MfmaInstr instr>

@@ -250,9 +251,8 @@ struct mfma_type<MfmaInstr::mfma_f32_4x4x4f16>
     }
 };

-#if 0
 template <>
-struct mfma_type<MfmaInstr::mfma_f32_32x32x2bf16>
+struct mfma_type<MfmaInstr::mfma_f32_32x32x8bf16_1k>
 {
     static constexpr index_t group_size          = 4;
     static constexpr index_t num_groups_per_blk  = 4;

@@ -260,26 +260,38 @@ struct mfma_type<MfmaInstr::mfma_f32_32x32x2bf16>
     static constexpr index_t num_threads_per_blk = 32;
     static constexpr index_t wave_size           = 64;
     static constexpr index_t num_input_blks      = 2;
-    static constexpr index_t num_output_blks     = 2;
+    static constexpr index_t num_output_blks     = 1;
     static constexpr index_t m_per_blk           = 32;
     static constexpr index_t n_per_blk           = 32;
-    static constexpr index_t k_per_blk           = 2;
-    static constexpr bool is_k_reduction         = false;
+    static constexpr index_t k_per_blk           = 4;
+    static constexpr bool is_k_reduction         = true;

-    template <index_t MPerXdlops,
-              index_t NPerXdlops,
-              index_t AStride,
-              index_t BStride,
-              class FloatA,
-              class FloatB,
-              class FloatC>
-    __device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
-    {
-        const auto p_a = c_style_pointer_cast<const ushort2_t*>(a);
-        const auto p_b = c_style_pointer_cast<const ushort2_t*>(b);
-
-        return intrin_mfma_f32_32x32x2bf16<MPerXdlops, NPerXdlops, AStride, BStride>::run(
-            p_a, p_b, reg_c);
-    }
-};
+    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
+    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
+    {
+        intrin_mfma_f32_32x32x8bf16_1k<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
+    }
+};
+
+template <>
+struct mfma_type<MfmaInstr::mfma_f32_16x16x16bf16_1k>
+{
+    static constexpr index_t group_size          = 4;
+    static constexpr index_t num_groups_per_blk  = 1;
+    static constexpr index_t num_regs_per_blk    = 4;
+    static constexpr index_t num_threads_per_blk = 16;
+    static constexpr index_t wave_size           = 64;
+    static constexpr index_t num_input_blks      = 4;
+    static constexpr index_t num_output_blks     = 1;
+    static constexpr index_t m_per_blk           = 16;
+    static constexpr index_t n_per_blk           = 16;
+    static constexpr index_t k_per_blk           = 4;
+    static constexpr bool is_k_reduction         = true;
+
+    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
+    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
+    {
+        intrin_mfma_f32_16x16x16bf16_1k<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
+    }
+};

@@ -298,19 +310,10 @@ struct mfma_type<MfmaInstr::mfma_f32_32x32x4bf16>
     static constexpr index_t k_per_blk           = 2;
     static constexpr bool is_k_reduction         = true;

-    template <index_t MPerXdlops,
-              index_t NPerXdlops,
-              index_t AStride,
-              index_t BStride,
-              class FloatA,
-              class FloatB,
-              class FloatC>
-    __device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
-    {
-        const auto p_a = c_style_pointer_cast<const ushort2_t*>(a);
-        const auto p_b = c_style_pointer_cast<const ushort2_t*>(b);
-
-        return intrin_mfma_f32_32x32x4bf16(p_a, p_b, reg_c);
-    }
+    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
+    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
+    {
+        intrin_mfma_f32_32x32x4bf16<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
+    }
 };

@@ -329,84 +332,56 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x8bf16>
     static constexpr index_t k_per_blk           = 2;
     static constexpr bool is_k_reduction         = true;

-    template <index_t MPerXdlops,
-              index_t NPerXdlops,
-              index_t AStride,
-              index_t BStride,
-              class FloatA,
-              class FloatB,
-              class FloatC>
-    __device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
-    {
-        const auto p_a = c_style_pointer_cast<const ushort2_t*>(a);
-        const auto p_b = c_style_pointer_cast<const ushort2_t*>(b);
-
-        return intrin_mfma_f32_16x16x8bf16(p_a, p_b, reg_c);
-    }
-};
-
-template <>
-struct mfma_type<MfmaInstr::mfma_f32_16x16x2bf16>
-{
-    static constexpr index_t group_size          = 4;
-    static constexpr index_t num_groups_per_blk  = 1;
-    static constexpr index_t num_regs_per_blk    = 4;
-    static constexpr index_t num_threads_per_blk = 16;
-    static constexpr index_t wave_size           = 64;
-    static constexpr index_t num_input_blks      = 4;
-    static constexpr index_t num_output_blks     = 4;
-    static constexpr index_t m_per_blk           = 16;
-    static constexpr index_t n_per_blk           = 16;
-    static constexpr index_t k_per_blk           = 2;
-    static constexpr bool is_k_reduction         = false;
-
-    template <index_t MPerXdlops,
-              index_t NPerXdlops,
-              index_t AStride,
-              index_t BStride,
-              class FloatA,
-              class FloatB,
-              class FloatC>
-    __device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
-    {
-        const auto p_a = c_style_pointer_cast<const ushort2_t*>(a);
-        const auto p_b = c_style_pointer_cast<const ushort2_t*>(b);
-
-        return intrin_mfma_f32_16x16x2bf16<MPerXdlops, NPerXdlops>(p_a, p_b, reg_c);
-    }
-};
-
-template <>
-struct mfma_type<MfmaInstr::mfma_f32_4x4x2bf16>
-{
-    static constexpr index_t group_size          = 4;
-    static constexpr index_t num_groups_per_blk  = 1;
-    static constexpr index_t num_regs_per_blk    = 4;
-    static constexpr index_t num_threads_per_blk = 64;
-    static constexpr index_t wave_size           = 64;
-    static constexpr index_t num_input_blks      = 1;
-    static constexpr index_t num_output_blks     = 1;
-    static constexpr index_t m_per_blk           = 4;
-    static constexpr index_t n_per_blk           = 64;
-    static constexpr index_t k_per_blk           = 2;
-    static constexpr bool is_k_reduction         = false;
-
-    template <index_t MPerXdlops,
-              index_t NPerXdlops,
-              index_t AStride,
-              index_t BStride,
-              class FloatA,
-              class FloatB,
-              class FloatC>
-    __device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
-    {
-        const auto p_a = c_style_pointer_cast<const ushort2_t*>(a);
-        const auto p_b = c_style_pointer_cast<const ushort2_t*>(b);
-
-        return intrin_mfma_f32_4x4x2bf16<MPerXdlops, NPerXdlops>::run(p_a, p_b, reg_c);
-    }
-};
-#endif
+    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
+    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
+    {
+        intrin_mfma_f32_16x16x8bf16<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
+    }
+};
+
+template <>
+struct mfma_type<MfmaInstr::mfma_i32_32x32x8i8>
+{
+    static constexpr index_t group_size          = 4;
+    static constexpr index_t num_groups_per_blk  = 4;
+    static constexpr index_t num_regs_per_blk    = 16;
+    static constexpr index_t num_threads_per_blk = 32;
+    static constexpr index_t wave_size           = 64;
+    static constexpr index_t num_input_blks      = 2;
+    static constexpr index_t num_output_blks     = 1;
+    static constexpr index_t m_per_blk           = 32;
+    static constexpr index_t n_per_blk           = 32;
+    static constexpr index_t k_per_blk           = 4;
+    static constexpr bool is_k_reduction         = true;
+
+    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
+    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
+    {
+        intrin_mfma_i32_32x32x8i8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
+    }
+};
+
+template <>
+struct mfma_type<MfmaInstr::mfma_i32_16x16x16i8>
+{
+    static constexpr index_t group_size          = 4;
+    static constexpr index_t num_groups_per_blk  = 1;
+    static constexpr index_t num_regs_per_blk    = 4;
+    static constexpr index_t num_threads_per_blk = 16;
+    static constexpr index_t wave_size           = 64;
+    static constexpr index_t num_input_blks      = 4;
+    static constexpr index_t num_output_blks     = 1;
+    static constexpr index_t m_per_blk           = 16;
+    static constexpr index_t n_per_blk           = 16;
+    static constexpr index_t k_per_blk           = 4;
+    static constexpr bool is_k_reduction         = true;
+
+    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
+    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
+    {
+        intrin_mfma_i32_16x16x16i8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
+    }
+};

 template <typename base_type, index_t MPerXdlops, index_t NPerXdlops>
 struct MfmaSelector

@@ -498,73 +473,37 @@ struct MfmaSelector
         return MfmaInstr::mfma_f32_4x4x4f16;
     }

-#if 0
-    template <>
-    static constexpr auto GetMfma<ushort, 128, 64>()
-    {
-        return xdlops_info<MfmaInstr::mfma_f32_32x32x2bf16, 64, 64, 2, 1, c_vec32_4_t>{};
-    }
-
-    template <>
-    static constexpr auto GetMfma<ushort, 64, 128>()
-    {
-        return xdlops_info<MfmaInstr::mfma_f32_32x32x2bf16, 64, 64, 1, 2, c_vec32_4_t>{};
-    }
-
-    template <>
-    static constexpr auto GetMfma<ushort, 64, 64>()
-    {
-        return xdlops_info<MfmaInstr::mfma_f32_32x32x2bf16, 64, 64, 1, 1, c_vec32_2_t>{};
-    }
-
-    template <>
-    static constexpr auto GetMfma<ushort, 64, 32>()
-    {
-        return xdlops_info<MfmaInstr::mfma_f32_32x32x2bf16, 64, 32, 1, 1, c_vec32_1_t>{};
-    }
-
-    template <>
-    static constexpr auto GetMfma<ushort, 32, 64>()
-    {
-        return xdlops_info<MfmaInstr::mfma_f32_32x32x2bf16, 32, 64, 1, 1, c_vec32_1_t>{};
-    }
-
-    template <>
-    static constexpr auto GetMfma<ushort, 64, 16>()
-    {
-        return xdlops_info<MfmaInstr::mfma_f32_16x16x2bf16, 64, 16, 1, 1, c_vec16_1_t>{};
-    }
-
-    template <>
-    static constexpr auto GetMfma<ushort, 16, 64>()
-    {
-        return xdlops_info<MfmaInstr::mfma_f32_16x16x2bf16, 16, 64, 1, 1, c_vec16_1_t>{};
-    }
-
-    template <>
-    static constexpr auto GetMfma<ushort, 8, 64>()
-    {
-        return xdlops_info<MfmaInstr::mfma_f32_4x4x2bf16, 8, 64, 1, 1, c_vec4_2_t>{};
-    }
-
-    template <>
-    static constexpr auto GetMfma<ushort, 4, 64>()
-    {
-        return xdlops_info<MfmaInstr::mfma_f32_4x4x2bf16, 4, 64, 1, 1, c_vec4_1_t>{};
-    }
-
-    template <>
-    static constexpr auto GetMfma<ushort, 32, 32>()
-    {
-        return xdlops_info<MfmaInstr::mfma_f32_32x32x4bf16, 32, 32, 1, 1, c_vec16_1_t>{};
-    }
-
-    template <>
-    static constexpr auto GetMfma<ushort, 16, 16>()
-    {
-        return xdlops_info<MfmaInstr::mfma_f32_16x16x8bf16, 16, 16, 1, 1, c_vec4_1_t>{};
-    }
-#endif
+    template <>
+    static constexpr auto GetMfma<ushort, 32, 32>()
+    {
+#if defined(CK_AMD_GPU_GFX90A)
+        return MfmaInstr::mfma_f32_32x32x8bf16_1k;
+#else
+        return MfmaInstr::mfma_f32_32x32x4bf16;
+#endif
+    }
+
+    template <>
+    static constexpr auto GetMfma<ushort, 16, 16>()
+    {
+#if defined(CK_AMD_GPU_GFX90A)
+        return MfmaInstr::mfma_f32_16x16x16bf16_1k;
+#else
+        return MfmaInstr::mfma_f32_16x16x8bf16;
+#endif
+    }
+
+    template <>
+    static constexpr auto GetMfma<int8_t, 32, 32>()
+    {
+        return MfmaInstr::mfma_i32_32x32x8i8;
+    }
+
+    template <>
+    static constexpr auto GetMfma<int8_t, 16, 16>()
+    {
+        return MfmaInstr::mfma_i32_16x16x16i8;
+    }

     static constexpr auto selected_mfma = mfma_type<GetMfma<base_type, MPerXdlops, NPerXdlops>()>{};

@@ -686,8 +625,8 @@ struct XdlopsGemm
     __device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const
     {
-        static_assert(is_same<base_type, float>::value || is_same<base_type, half_t>::value ||
-                          is_same<base_type, ushort>::value,
-                      "base base_type must be float, half, ushort!");
+        static_assert(is_same<base_type, float>::value || is_same<base_type, half_t>::value ||
+                          is_same<base_type, ushort>::value || is_same<base_type, int8_t>::value,
+                      "base base_type must be float, half, ushort, and int8_t!");

         static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) {
             mfma_instr.template run<MPerXdlops, NPerXdlops>(p_a_wave[k], p_b_wave[k], p_c_thread);
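For orientation (not part of the diff): with the GetMfma specializations above, bf16 tile sizes now dispatch per target architecture, while int8 gets dedicated instructions. A hypothetical compile-time check, assuming only the templates shown in the hunks above:

// Illustrative sketch only. MfmaSelector and mfma_type are the templates from the diff above.
using bf16_mfma_32x32 = ck::MfmaSelector<ushort, 32, 32>;
using int8_mfma_32x32 = ck::MfmaSelector<int8_t, 32, 32>;

// On gfx90a builds (CK_AMD_GPU_GFX90A defined) the bf16 path selects mfma_f32_32x32x8bf16_1k,
// otherwise it falls back to mfma_f32_32x32x4bf16; both variants are k-reduction instructions.
static_assert(bf16_mfma_32x32::selected_mfma.is_k_reduction, "bf16 32x32 uses k reduction");
static_assert(int8_mfma_32x32::selected_mfma.is_k_reduction, "i8 32x32 uses k reduction");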
composable_kernel/include/utility/amd_buffer_addressing.hpp  (+100 -3)

@@ -50,11 +50,24 @@ llvm_amdgcn_raw_buffer_load_i8x4(int32x4_t srsrc,
                                  index_t soffset,
                                  index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i8");

-__device__ int16_t
+__device__ ushort
 llvm_amdgcn_raw_buffer_load_i16(int32x4_t srsrc,
                                 index_t voffset,
                                 index_t soffset,
-                                index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i32");
+                                index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i16");
+
+__device__ ushort2_t
+llvm_amdgcn_raw_buffer_load_i16x2(int32x4_t srsrc,
+                                  index_t voffset,
+                                  index_t soffset,
+                                  index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2i16");
+
+__device__ ushort4_t
+llvm_amdgcn_raw_buffer_load_i16x4(int32x4_t srsrc,
+                                  index_t voffset,
+                                  index_t soffset,
+                                  index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i16");

 __device__ int32_t
 llvm_amdgcn_raw_buffer_load_i32(int32x4_t srsrc,
                                 index_t voffset,

@@ -133,12 +146,26 @@ llvm_amdgcn_raw_buffer_store_i8x4(int8x4_t vdata,
                                   index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i8");

 __device__ void
-llvm_amdgcn_raw_buffer_store_i16(int16_t vdata,
+llvm_amdgcn_raw_buffer_store_i16(ushort vdata,
                                  int32x4_t rsrc,
                                  index_t voffset,
                                  index_t soffset,
                                  index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i16");
+
+__device__ void
+llvm_amdgcn_raw_buffer_store_i16x2(ushort2_t vdata,
+                                   int32x4_t rsrc,
+                                   index_t voffset,
+                                   index_t soffset,
+                                   index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i16");
+
+__device__ void
+llvm_amdgcn_raw_buffer_store_i16x4(ushort4_t vdata,
+                                   int32x4_t rsrc,
+                                   index_t voffset,
+                                   index_t soffset,
+                                   index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i16");

 __device__ void
 llvm_amdgcn_raw_buffer_store_i32(int32_t vdata,
                                  int32x4_t rsrc,

@@ -228,6 +255,7 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
         (is_same<T, double>::value && (N == 1 || N == 2 || N == 4)) ||
             (is_same<T, float>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
             (is_same<T, half_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
+            (is_same<T, ushort>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
             (is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
             (is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
         "wrong! not implemented");

@@ -326,6 +354,31 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
             return as_type<half8_t>(tmp);
         }
     }
+    else if constexpr(is_same<T, ushort>::value)
+    {
+        if constexpr(N == 1)
+        {
+            return llvm_amdgcn_raw_buffer_load_i16(
+                src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
+        }
+        else if constexpr(N == 2)
+        {
+            return llvm_amdgcn_raw_buffer_load_i16x2(
+                src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
+        }
+        else if constexpr(N == 4)
+        {
+            return llvm_amdgcn_raw_buffer_load_i16x4(
+                src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
+        }
+        else if constexpr(N == 8)
+        {
+            int32x4_t tmp = llvm_amdgcn_raw_buffer_load_i32x4(
+                src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
+
+            return as_type<ushort8_t>(tmp);
+        }
+    }
     else if constexpr(is_same<T, int32_t>::value)
     {
         if constexpr(N == 1)

@@ -458,6 +511,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
         (is_same<T, double>::value && (N == 1 || N == 2)) ||
             (is_same<T, float>::value && (N == 1 || N == 2 || N == 4)) ||
             (is_same<T, half_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
+            (is_same<T, ushort>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
             (is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4)) ||
             (is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
         "wrong! not implemented");

@@ -560,6 +614,49 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
 #endif
         }
     }
+    else if constexpr(is_same<T, ushort>::value)
+    {
+        if constexpr(N == 1)
+        {
+            llvm_amdgcn_raw_buffer_store_i16(src_thread_data,
+                                             dst_wave_buffer_resource,
+                                             dst_thread_addr_offset,
+                                             dst_wave_addr_offset,
+                                             0);
+        }
+        else if constexpr(N == 2)
+        {
+            llvm_amdgcn_raw_buffer_store_fp16x2(src_thread_data,
+                                                dst_wave_buffer_resource,
+                                                dst_thread_addr_offset,
+                                                dst_wave_addr_offset,
+                                                0);
+        }
+        else if constexpr(N == 4)
+        {
+            llvm_amdgcn_raw_buffer_store_fp16x4(src_thread_data,
+                                                dst_wave_buffer_resource,
+                                                dst_thread_addr_offset,
+                                                dst_wave_addr_offset,
+                                                0);
+        }
+        else if constexpr(N == 8)
+        {
+            vector_type<half_t, 8> tmp{src_thread_data

(The page is truncated here; the remainder of this hunk and the rest of the commit are not shown.)
};
llvm_amdgcn_raw_buffer_store_fp16x4
(
tmp
.
AsType
<
half4_t
>
()[
Number
<
0
>
{}],
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
0
);
llvm_amdgcn_raw_buffer_store_fp16x4
(
tmp
.
AsType
<
half4_t
>
()[
Number
<
1
>
{}],
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
+
4
*
sizeof
(
half_t
),
0
);
}
}
else
if
constexpr
(
is_same
<
T
,
int32_t
>::
value
)
else
if
constexpr
(
is_same
<
T
,
int32_t
>::
value
)
{
{
if
constexpr
(
N
==
1
)
if
constexpr
(
N
==
1
)
...
...
composable_kernel/include/utility/amd_xdlops.hpp
View file @
8c4e33f1
...
@@ -6,6 +6,7 @@
...
@@ -6,6 +6,7 @@
namespace
ck
{
namespace
ck
{
// A, B, C, cbsz, abid, blgp
// A, B, C, cbsz, abid, blgp
// fp32
extern
"C"
__device__
float32_t
llvm_intrin_amdgcn_mfma_f32_32x32x1f32
(
extern
"C"
__device__
float32_t
llvm_intrin_amdgcn_mfma_f32_32x32x1f32
(
float
,
float
,
float32_t
,
int
,
int
,
int
)
__asm
(
"llvm.amdgcn.mfma.f32.32x32x1f32"
);
float
,
float
,
float32_t
,
int
,
int
,
int
)
__asm
(
"llvm.amdgcn.mfma.f32.32x32x1f32"
);
...
@@ -21,6 +22,7 @@ extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_16x16x1f32(
...
@@ -21,6 +22,7 @@ extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_16x16x1f32(
extern
"C"
__device__
float4_t
llvm_intrin_amdgcn_mfma_f32_4x4x1f32
(
extern
"C"
__device__
float4_t
llvm_intrin_amdgcn_mfma_f32_4x4x1f32
(
float
,
float
,
float4_t
,
int
,
int
,
int
)
__asm
(
"llvm.amdgcn.mfma.f32.4x4x1f32"
);
float
,
float
,
float4_t
,
int
,
int
,
int
)
__asm
(
"llvm.amdgcn.mfma.f32.4x4x1f32"
);
// fp16
extern
"C"
__device__
float32_t
llvm_intrin_amdgcn_mfma_f32_32x32x4f16
(
extern
"C"
__device__
float32_t
llvm_intrin_amdgcn_mfma_f32_32x32x4f16
(
half4_t
,
half4_t
,
float32_t
,
int
,
int
,
int
)
__asm
(
"llvm.amdgcn.mfma.f32.32x32x4f16"
);
half4_t
,
half4_t
,
float32_t
,
int
,
int
,
int
)
__asm
(
"llvm.amdgcn.mfma.f32.32x32x4f16"
);
...
@@ -36,6 +38,13 @@ extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_16x16x4f16(
...
@@ -36,6 +38,13 @@ extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_16x16x4f16(
extern
"C"
__device__
float4_t
llvm_intrin_amdgcn_mfma_f32_4x4x4f16
(
extern
"C"
__device__
float4_t
llvm_intrin_amdgcn_mfma_f32_4x4x4f16
(
half4_t
,
half4_t
,
float4_t
,
int
,
int
,
int
)
__asm
(
"llvm.amdgcn.mfma.f32.4x4x4f16"
);
half4_t
,
half4_t
,
float4_t
,
int
,
int
,
int
)
__asm
(
"llvm.amdgcn.mfma.f32.4x4x4f16"
);
// bfp16
extern
"C"
__device__
float16_t
llvm_intrin_amdgcn_mfma_f32_32x32x8bf16_1k
(
ushort4_t
,
ushort4_t
,
float16_t
,
int
,
int
,
int
)
__asm
(
"llvm.amdgcn.mfma.f32.32x32x8bf16.1k"
);
extern
"C"
__device__
float4_t
llvm_intrin_amdgcn_mfma_f32_16x16x16bf16_1k
(
ushort4_t
,
ushort4_t
,
float4_t
,
int
,
int
,
int
)
__asm
(
"llvm.amdgcn.mfma.f32.16x16x16bf16.1k"
);
extern
"C"
__device__
float32_t
llvm_intrin_amdgcn_mfma_f32_32x32x2bf16
(
extern
"C"
__device__
float32_t
llvm_intrin_amdgcn_mfma_f32_32x32x2bf16
(
ushort2_t
,
ushort2_t
,
float32_t
,
int
,
int
,
int
)
__asm
(
"llvm.amdgcn.mfma.f32.32x32x2bf16"
);
ushort2_t
,
ushort2_t
,
float32_t
,
int
,
int
,
int
)
__asm
(
"llvm.amdgcn.mfma.f32.32x32x2bf16"
);
...
@@ -51,6 +60,23 @@ extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_16x16x2bf16(
...
@@ -51,6 +60,23 @@ extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_16x16x2bf16(
extern
"C"
__device__
float4_t
llvm_intrin_amdgcn_mfma_f32_4x4x2bf16
(
extern
"C"
__device__
float4_t
llvm_intrin_amdgcn_mfma_f32_4x4x2bf16
(
ushort2_t
,
ushort2_t
,
float4_t
,
int
,
int
,
int
)
__asm
(
"llvm.amdgcn.mfma.f32.4x4x2bf16"
);
ushort2_t
,
ushort2_t
,
float4_t
,
int
,
int
,
int
)
__asm
(
"llvm.amdgcn.mfma.f32.4x4x2bf16"
);
// int8
extern
"C"
__device__
int32x32_t
llvm_intrin_amdgcn_mfma_i32_32x32x4i8
(
int
,
int
,
int32x32_t
,
int
,
int
,
int
)
__asm
(
"llvm.amdgcn.mfma.i32.32x32x4i8"
);
extern
"C"
__device__
int32x16_t
llvm_intrin_amdgcn_mfma_i32_16x16x4i8
(
int
,
int
,
int32x16_t
,
int
,
int
,
int
)
__asm
(
"llvm.amdgcn.mfma.i32.16x16x4i8"
);
extern
"C"
__device__
int32x4_t
llvm_intrin_amdgcn_mfma_i32_4x4x4i8
(
int
,
int
,
int32x4_t
,
int
,
int
,
int
)
__asm
(
"llvm.amdgcn.mfma.i32.4x4x4i8"
);
extern
"C"
__device__
int32x16_t
llvm_intrin_amdgcn_mfma_i32_32x32x8i8
(
int
,
int
,
int32x16_t
,
int
,
int
,
int
)
__asm
(
"llvm.amdgcn.mfma.i32.32x32x8i8"
);
extern
"C"
__device__
int32x4_t
llvm_intrin_amdgcn_mfma_i32_16x16x16i8
(
int
,
int
,
int32x4_t
,
int
,
int
,
int
)
__asm
(
"llvm.amdgcn.mfma.i32.16x16x16i8"
);
// fp32
template
<
index_t
MPerWave
,
index_t
NPerWave
>
template
<
index_t
MPerWave
,
index_t
NPerWave
>
struct
intrin_mfma_f32_32x32x1f32
;
struct
intrin_mfma_f32_32x32x1f32
;
...
@@ -148,6 +174,7 @@ struct intrin_mfma_f32_4x4x1f32<8, 64>
...
@@ -148,6 +174,7 @@ struct intrin_mfma_f32_4x4x1f32<8, 64>
}
}
};
};
// fp16
template
<
index_t
MPerWave
,
index_t
NPerWave
>
template
<
index_t
MPerWave
,
index_t
NPerWave
>
struct
intrin_mfma_f32_32x32x4f16
;
struct
intrin_mfma_f32_32x32x4f16
;
...
@@ -244,147 +271,102 @@ struct intrin_mfma_f32_4x4x4f16<8, 64>
...
@@ -244,147 +271,102 @@ struct intrin_mfma_f32_4x4x4f16<8, 64>
}
}
};
};
#if 0
// bfp16
template <index_t MPerWave, index_t NPerWave
, index_t AStride, index_t BStride
>
template
<
index_t
MPerWave
,
index_t
NPerWave
>
struct intrin_mfma_f32_32x32x
2
bf16;
struct
intrin_mfma_f32_32x32x
8
bf16
_1k
;
template <
index_t AStride, index_t BStride
>
template
<
>
struct intrin_mfma_f32_32x32x
2
bf16
<128, 64, AStride, BStride
>
struct
intrin_mfma_f32_32x32x
8
bf16
_1k
<
32
,
32
>
{
{
__device__ static c_vec32_4_t::VecType
template
<
class
FloatC
>
r
un(const ushort
2
_t
*
reg_a, const ushort
2
_t
*
reg_b,
c_vec32_4_t::VecType
reg_c)
__device__
static
void
R
un
(
const
ushort
4
_t
&
reg_a
,
const
ushort
4
_t
&
reg_b
,
FloatC
&
reg_c
)
{
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0);
reg_c
.
template
AsType
<
float16_t
>()(
Number
<
0
>
{})
=
reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.y, 1, 1, 0);
llvm_intrin_amdgcn_mfma_f32_32x32x8bf16_1k
(
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float16_t
>()[
Number
<
0
>
{}],
0
,
0
,
0
);
reg_c.s.z =
llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[AStride], reg_b[0], reg_c.s.z, 1, 0, 0);
reg_c.s.w =
llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[AStride], reg_b[0], reg_c.s.w, 1, 1, 0);
return reg_c;
}
}
};
};
template <index_t AStride, index_t BStride>
template
<
index_t
MPerWave
,
index_t
NPerWave
>
struct intrin_mfma_f32_32x32x2bf16<64, 128, AStride, BStride>
struct
intrin_mfma_f32_16x16x16bf16_1k
;
{
__device__ static c_vec32_4_t::VecType
run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec32_4_t::VecType reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0);
reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.y, 1, 1, 0);
reg_c.s.z =
llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[BStride], reg_c.s.z, 1, 0, 0);
reg_c.s.w =
llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[BStride], reg_c.s.w, 1, 1, 0);
return reg_c;
}
};
template <
index_t AStride, index_t BStride
>
template
<
>
struct intrin_mfma_f32_
32x32x2bf16<64, 64, AStride, BStride
>
struct
intrin_mfma_f32_
16x16x16bf16_1k
<
16
,
16
>
{
{
__device__ static c_vec32_2_t::VecType
template
<
class
FloatC
>
r
un(const ushort
2
_t
*
reg_a, const ushort
2
_t
*
reg_b,
c_vec32_2_t::VecType
reg_c)
__device__
static
void
R
un
(
const
ushort
4
_t
&
reg_a
,
const
ushort
4
_t
&
reg_b
,
FloatC
&
reg_c
)
{
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0);
reg_c
.
template
AsType
<
float4_t
>()(
Number
<
0
>
{})
=
reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.y, 1, 1, 0);
llvm_intrin_amdgcn_mfma_f32_16x16x16bf16_1k
(
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float4_t
>()[
Number
<
0
>
{}],
0
,
0
,
0
);
return reg_c;
}
}
};
};
template <index_t AStride, index_t BStride>
template
<
index_t
MPerWave
,
index_t
NPerWave
>
struct intrin_mfma_f32_32x32x2bf16<64, 32, AStride, BStride>
struct
intrin_mfma_f32_32x32x4bf16
;
{
__device__ static c_vec32_1_t::VecType
run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec32_1_t::VecType reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 1);
return reg_c;
}
};
template <
index_t AStride, index_t BStride
>
template
<
>
struct intrin_mfma_f32_32x32x
2
bf16<32,
64, AStride, BStride
>
struct
intrin_mfma_f32_32x32x
4
bf16
<
32
,
32
>
{
{
__device__ static c_vec32_1_t::VecType
template
<
class
FloatC
>
r
un(const ushort2_t
*
reg_a, const ushort2_t
*
reg_b,
c_vec32_1_t::VecType
reg_c)
__device__
static
void
R
un
(
const
ushort2_t
&
reg_a
,
const
ushort2_t
&
reg_b
,
FloatC
&
reg_c
)
{
{
reg_c.
s.x
= llvm_intrin_amdgcn_mfma_f32_32x32x
2
bf16(
reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0);
reg_c
.
template
AsType
<
float16_t
>()(
Number
<
0
>
{})
=
llvm_intrin_amdgcn_mfma_f32_32x32x
4
bf16
(
return reg_c
;
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float16_t
>()[
Number
<
0
>
{}],
0
,
0
,
0
)
;
}
}
};
};
__device__ c_vec16_1_t::VecType intrin_mfma_f32_32x32x4bf16(const ushort2_t* reg_a,
const ushort2_t* reg_b,
c_vec16_1_t::VecType reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x4bf16(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 0);
return reg_c;
}
__device__ c_vec4_1_t::VecType intrin_mfma_f32_16x16x8bf16(const ushort2_t* reg_a,
const ushort2_t* reg_b,
c_vec4_1_t::VecType reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_16x16x8bf16(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 0);
return reg_c;
}
template
<
index_t
MPerWave
,
index_t
NPerWave
>
template
<
index_t
MPerWave
,
index_t
NPerWave
>
__device__ c_vec16_1_t::VecType intrin_mfma_f32_16x16x2bf16(const ushort2_t* reg_a,
struct
intrin_mfma_f32_16x16x8bf16
;
const ushort2_t* reg_b,
c_vec16_1_t::VecType reg_c);
template <>
__device__ c_vec16_1_t::VecType intrin_mfma_f32_16x16x2bf16<16, 64>(const ushort2_t* reg_a,
const ushort2_t* reg_b,
c_vec16_1_t::VecType reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_16x16x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 2, 0, 0);
return reg_c;
}
template
<
>
template
<
>
__device__ c_vec16_1_t::VecType intrin_mfma_f32_16x16x2bf16<64, 16>(const ushort2_t* reg_a,
struct
intrin_mfma_f32_16x16x8bf16
<
16
,
16
>
const ushort2_t* reg_b,
c_vec16_1_t::VecType reg_c)
{
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_16x16x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 4);
template
<
class
FloatC
>
return reg_c;
__device__
static
void
Run
(
const
ushort2_t
&
reg_a
,
const
ushort2_t
&
reg_b
,
FloatC
&
reg_c
)
}
{
reg_c
.
template
AsType
<
float4_t
>()(
Number
<
0
>
{})
=
llvm_intrin_amdgcn_mfma_f32_16x16x8bf16
(
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float4_t
>()[
Number
<
0
>
{}],
0
,
0
,
0
);
}
};
template
<
index_t
MPerWave
,
index_t
NPerWave
>
template
<
index_t
MPerWave
,
index_t
NPerWave
>
struct intrin_mfma_
f
32_
4x4x2bf16
;
struct
intrin_mfma_
i
32_
32x32x8i8
;
template
<
>
template
<
>
struct intrin_mfma_
f
32_
4x4x2bf16<4, 64
>
struct
intrin_mfma_
i
32_
32x32x8i8
<
32
,
32
>
{
{
__device__ static c_vec4_1_t::VecType
template
<
class
FloatC
>
r
un(const
ushort2
_t
*
reg_a, const
ushort2
_t
*
reg_b,
c_vec4_1_t::VecType
reg_c)
__device__
static
void
R
un
(
const
int8x4
_t
&
reg_a
,
const
int8x4
_t
&
reg_b
,
FloatC
&
reg_c
)
{
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_4x4x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 4, 0, 0);
reg_c
.
template
AsType
<
int32x16_t
>()(
Number
<
0
>
{})
=
return reg_c;
llvm_intrin_amdgcn_mfma_i32_32x32x8i8
(
as_type
<
int
>
(
reg_a
),
as_type
<
int
>
(
reg_b
),
reg_c
.
template
AsType
<
int32x16_t
>()[
Number
<
0
>
{}],
0
,
0
,
0
);
}
}
};
};
template
<
index_t
MPerWave
,
index_t
NPerWave
>
struct
intrin_mfma_i32_16x16x16i8
;
template
<
>
template
<
>
struct intrin_mfma_
f
32_
4x4x2bf16<8, 64
>
struct
intrin_mfma_
i
32_
16x16x16i8
<
16
,
16
>
{
{
__device__ static c_vec4_2_t::VecType
template
<
class
FloatC
>
r
un(const
ushort2
_t
*
reg_a, const
ushort2
_t
*
reg_b,
c_vec4_2_t::VecType
reg_c)
__device__
static
void
R
un
(
const
int8x4
_t
&
reg_a
,
const
int8x4
_t
&
reg_b
,
FloatC
&
reg_c
)
{
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_4x4x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 4, 0, 0);
reg_c
.
template
AsType
<
int32x4_t
>()(
Number
<
0
>
{})
=
reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_4x4x2bf16(reg_a[0], reg_b[0], reg_c.s.y, 4, 1, 0);
llvm_intrin_amdgcn_mfma_i32_16x16x16i8
(
as_type
<
int
>
(
reg_a
),
return reg_c;
as_type
<
int
>
(
reg_b
),
reg_c
.
template
AsType
<
int32x4_t
>()[
Number
<
0
>
{}],
0
,
0
,
0
);
}
}
};
};
#endif
}
// namespace ck
}
// namespace ck
#endif
#endif
composable_kernel/include/utility/common_header.hpp
View file @
8c4e33f1
...
@@ -30,7 +30,11 @@
...
@@ -30,7 +30,11 @@
#include "amd_address_space.hpp"
#include "amd_address_space.hpp"
#include "amd_buffer_addressing.hpp"
#include "amd_buffer_addressing.hpp"
#include "static_buffer.hpp"
#include "static_buffer.hpp"
// TODO remove this
#include "static_buffer_of_vector_type_v2.hpp"
#include "dynamic_buffer.hpp"
#include "dynamic_buffer.hpp"
#include "is_known_at_compile_time.hpp"
#include "transpose_vectors.hpp"
#include "inner_product.hpp"
#include "inner_product.hpp"
...
...
composable_kernel/include/utility/config.hpp
View file @
8c4e33f1
...
@@ -76,7 +76,7 @@
...
@@ -76,7 +76,7 @@
#define CK_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM 1
#define CK_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM 1
#endif
#endif
// experimental implementation
// experimental implementation
for buffer load/store/atomic
#ifndef CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
#ifndef CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1
#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1
#endif
#endif
...
@@ -89,6 +89,11 @@
...
@@ -89,6 +89,11 @@
#define CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK 1
#define CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK 1
#endif
#endif
// experimental implementation for in-regsiter sub-dword transpose
#ifndef CK_EXPERIMENTAL_USE_IN_REGISTER_SUB_DWORD_TRANSPOSE
#define CK_EXPERIMENTAL_USE_IN_REGISTER_SUB_DWORD_TRANSPOSE 1
#endif
// pass tensor descriptor by value or void*
// pass tensor descriptor by value or void*
#define CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE 0
#define CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE 0
#define CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER 0
#define CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER 0
...
...
composable_kernel/include/utility/container_helper.hpp
View file @
8c4e33f1
...
@@ -373,19 +373,6 @@ set_container_subset(Tuple<Ys...>& y, Sequence<Is...> picks, const Tuple<Xs...>&
...
@@ -373,19 +373,6 @@ set_container_subset(Tuple<Ys...>& y, Sequence<Is...> picks, const Tuple<Xs...>&
static_for
<
0
,
sizeof
...(
Is
),
1
>
{}([
&
](
auto
i
)
{
y
(
picks
[
i
])
=
x
[
i
];
});
static_for
<
0
,
sizeof
...(
Is
),
1
>
{}([
&
](
auto
i
)
{
y
(
picks
[
i
])
=
x
[
i
];
});
}
}
template
<
typename
Container
>
__host__
__device__
constexpr
auto
to_tuple_of_number
(
const
Container
&
)
{
static_assert
(
is_known_at_compile_time
<
Container
>::
value
,
"wrong!"
);
return
generate_tuple
(
[
&
](
auto
i
)
{
constexpr
index_t
tmp
=
Container
::
At
(
i
);
return
Number
<
tmp
>
{};
},
Container
::
Size
());
}
template
<
index_t
...
Is
>
template
<
index_t
...
Is
>
__host__
__device__
constexpr
auto
sequence_to_tuple_of_number
(
Sequence
<
Is
...
>
)
__host__
__device__
constexpr
auto
sequence_to_tuple_of_number
(
Sequence
<
Is
...
>
)
{
{
...
...
composable_kernel/include/utility/data_type.hpp
View file @
8c4e33f1
...
@@ -58,6 +58,18 @@ __host__ __device__ constexpr auto make_vector_type(Number<N>)
...
@@ -58,6 +58,18 @@ __host__ __device__ constexpr auto make_vector_type(Number<N>)
template
<
typename
TV
>
template
<
typename
TV
>
struct
scalar_type
;
struct
scalar_type
;
// is_scalar_type
template
<
typename
TV
>
struct
is_scalar_type
{
static
constexpr
bool
value
=
(
scalar_type
<
remove_cvref_t
<
TV
>>::
vector_size
==
1
);
};
// has_same_scalar_type
template
<
typename
X
,
typename
Y
>
using
has_same_scalar_type
=
is_same
<
typename
scalar_type
<
remove_cvref_t
<
X
>>::
type
,
typename
scalar_type
<
remove_cvref_t
<
Y
>>::
type
>
;
template
<
typename
T
,
index_t
N
>
template
<
typename
T
,
index_t
N
>
struct
scalar_type
<
T
__attribute__
((
ext_vector_type
(
N
)))
>
struct
scalar_type
<
T
__attribute__
((
ext_vector_type
(
N
)))
>
{
{
...
...
composable_kernel/include/utility/ignore.hpp
0 → 100644
View file @
8c4e33f1
#ifndef CK_IGNORE_HPP
#define CK_IGNORE_HPP
// https://en.cppreference.com/w/cpp/utility/tuple/ignore
namespace
ck
{
namespace
detail
{
struct
ignore_t
{
template
<
typename
T
>
constexpr
void
operator
=
(
T
&&
)
const
noexcept
{
}
};
}
// namespace detail
inline
constexpr
detail
::
ignore_t
ignore
;
}
// namespace ck
#endif
composable_kernel/include/utility/is_known_at_compile_time.hpp
0 → 100644
View file @
8c4e33f1
#ifndef IS_KNOWN_AT_COMPILE_TIME_HPP
#define IS_KNOWN_AT_COMPILE_TIME_HPP
#include "config.hpp"
#include "integral_constant.hpp"
#include "sequence.hpp"
#include "tuple.hpp"
namespace
ck
{
template
<
typename
T
>
struct
is_known_at_compile_time
;
template
<
>
struct
is_known_at_compile_time
<
index_t
>
{
static
constexpr
bool
value
=
false
;
};
template
<
typename
T
,
T
X
>
struct
is_known_at_compile_time
<
integral_constant
<
T
,
X
>>
{
static
constexpr
bool
value
=
true
;
};
template
<
index_t
...
Is
>
struct
is_known_at_compile_time
<
Sequence
<
Is
...
>>
{
static
constexpr
bool
value
=
true
;
};
template
<
typename
...
Ts
>
struct
is_known_at_compile_time
<
Tuple
<
Ts
...
>>
{
__host__
__device__
static
constexpr
bool
IsKnownAtCompileTime
()
{
return
container_reduce
(
Tuple
<
Ts
...
>
{},
[](
auto
x
,
bool
r
)
{
return
is_known_at_compile_time
<
remove_cvref_t
<
decltype
(
x
)
>>::
value
&
r
;
},
true
);
}
static
constexpr
bool
value
=
IsKnownAtCompileTime
();
};
}
// namespace ck
#endif
composable_kernel/include/utility/static_buffer.hpp
View file @
8c4e33f1
...
@@ -5,158 +5,156 @@
...
@@ -5,158 +5,156 @@
namespace
ck
{
namespace
ck
{
template
<
AddressSpaceEnum_t
BufferAddressSpace
,
// static buffer for scalar
template
<
AddressSpaceEnum_t
AddressSpace
,
typename
T
,
typename
T
,
index_t
N
,
index_t
N
,
bool
InvalidElementUseNumericalZeroValue
>
bool
InvalidElementUseNumericalZeroValue
>
// TODO remove this bool, no longer needed
struct
StaticBuffer
:
public
StaticallyIndexedArray
<
T
,
N
>
struct
StaticBuffer
:
public
StaticallyIndexedArray
<
T
,
N
>
{
{
using
type
=
T
;
using
type
=
T
;
using
base
=
StaticallyIndexedArray
<
T
,
N
>
;
using
base
=
StaticallyIndexedArray
<
T
,
N
>
;
T
invalid_element_value_
=
T
{
0
};
__host__
__device__
constexpr
StaticBuffer
()
:
base
{}
{}
__host__
__device__
constexpr
StaticBuffer
()
:
base
{}
{}
__host__
__device__
constexpr
StaticBuffer
(
T
invalid_element_value
)
:
base
{},
invalid_element_value_
{
invalid_element_value
}
{
}
__host__
__device__
static
constexpr
AddressSpaceEnum_t
GetAddressSpace
()
__host__
__device__
static
constexpr
AddressSpaceEnum_t
GetAddressSpace
()
{
{
return
Buffer
AddressSpace
;
return
AddressSpace
;
}
}
__host__
__device__
static
constexpr
bool
IsStaticBuffer
()
{
return
true
;
}
__host__
__device__
static
constexpr
bool
IsDynamicBuffer
()
{
return
false
;
}
// read access
template
<
index_t
I
>
template
<
index_t
I
>
__host__
__device__
constexpr
auto
Get
(
Number
<
I
>
i
,
bool
is_valid_element
)
const
__host__
__device__
constexpr
const
T
&
operator
[](
Number
<
I
>
i
)
const
{
if
constexpr
(
InvalidElementUseNumericalZeroValue
)
{
return
is_valid_element
?
At
(
i
)
:
T
{
0
};
}
else
{
{
return
is_valid_element
?
At
(
i
)
:
invalid_element_value_
;
return
base
::
operator
[](
i
);
}
}
}
// write access
template
<
index_t
I
>
template
<
index_t
I
>
__host__
__device__
void
Set
(
Number
<
I
>
i
,
bool
is_valid_element
,
const
T
&
x
)
__host__
__device__
constexpr
T
&
operator
()(
Number
<
I
>
i
)
{
if
(
is_valid_element
)
{
{
At
(
i
)
=
x
;
return
base
::
operator
()(
i
);
}
}
}
__host__
__device__
static
constexpr
bool
IsStaticBuffer
()
{
return
true
;
}
__host__
__device__
static
constexpr
bool
IsDynamicBuffer
()
{
return
false
;
}
};
};
template
<
AddressSpaceEnum_t
BufferAddressSpace
,
// static buffer for vector
typename
T
,
template
<
AddressSpaceEnum_t
AddressSpace
,
index_t
N
,
typename
S
,
bool
InvalidElementUseNumericalZeroValue
>
index_t
NumOfVector
,
struct
StaticBufferV2
:
public
StaticallyIndexedArray
<
T
,
N
>
index_t
ScalarPerVector
,
bool
InvalidElementUseNumericalZeroValue
,
// TODO remove this bool, no longer needed,
typename
enable_if
<
is_scalar_type
<
S
>
::
value
,
bool
>::
type
=
false
>
struct
StaticBufferTupleOfVector
:
public
StaticallyIndexedArray
<
vector_type
<
S
,
ScalarPerVector
>
,
NumOfVector
>
{
{
using
type
=
T
;
using
V
=
typename
vector_type
<
S
,
ScalarPerVector
>::
type
;
using
base
=
StaticallyIndexedArray
<
T
,
N
>
;
using
base
=
StaticallyIndexedArray
<
vector_type
<
S
,
ScalarPerVector
>
,
NumOfVector
>
;
static
constexpr
auto
s_per_v
=
Number
<
ScalarPerVector
>
{};
static
constexpr
auto
num_of_v_
=
Number
<
NumOfVector
>
{};
using
VecBaseType
=
typename
T
::
d1_t
;
__host__
__device__
constexpr
StaticBufferTupleOfVector
()
:
base
{}
{}
__host__
__device__
static
constexpr
index_t
GetVectorSiz
e
()
__host__
__device__
static
constexpr
AddressSpaceEnum_t
GetAddressSpac
e
()
{
{
return
sizeof
(
typename
T
::
type
)
/
sizeof
(
VecBaseType
)
;
return
AddressSpace
;
}
}
static
constexpr
index_t
vector_size
=
GetVectorSize
();
__host__
__device__
static
constexpr
bool
IsStaticBuffer
()
{
return
true
;
}
VecBaseType
invalid_element_value_
=
VecBaseType
{
0
};
T
invalid_vec_value_
=
T
{
0
};
__host__
__device__
constexpr
Stat
icBuffer
V2
()
:
base
{}
{
}
__host__
__device__
static
constexpr
bool
IsDynam
icBuffer
()
{
return
false
;
}
__host__
__device__
constexpr
StaticBufferV2
(
VecBaseType
invalid_element_value
)
// Get S
:
base
{},
// i is offset of S
invalid_vec_value_
{
invalid_element_value
},
template
<
index_t
I
>
invalid_element_value_
{
invalid_element_value
}
__host__
__device__
constexpr
const
S
&
operator
[](
Number
<
I
>
i
)
const
{
{
}
constexpr
auto
i_v
=
i
/
s_per_v
;
constexpr
auto
i_s
=
i
%
s_per_v
;
__host__
__device__
static
constexpr
AddressSpaceEnum_t
GetAddressSpace
()
return
base
::
operator
[](
i_v
).
template
AsType
<
S
>()[
i_s
];
{
return
BufferAddressSpace
;
}
}
// Set S
// i is offset of S
template
<
index_t
I
>
template
<
index_t
I
>
__host__
__device__
constexpr
auto
&
GetVec
tor
(
Number
<
I
>
vec_id
)
__host__
__device__
constexpr
S
&
opera
tor
(
)(
Number
<
I
>
i
)
{
{
return
this
->
At
(
vec_id
)
;
constexpr
auto
i_v
=
i
/
s_per_v
;
}
constexpr
auto
i_s
=
i
%
s_per_v
;
template
<
index_t
I
>
return
base
::
operator
()(
i_v
).
template
AsType
<
S
>()(
i_s
);
__host__
__device__
constexpr
const
auto
&
GetVector
(
Number
<
I
>
vec_id
)
const
{
return
this
->
At
(
vec_id
);
}
}
template
<
index_t
I
>
// Get X
__host__
__device__
constexpr
auto
&
GetElement
(
Number
<
I
>
i
,
bool
)
// i is offset of S, not X. i should be aligned to X
template
<
typename
X
,
index_t
I
,
typename
enable_if
<
has_same_scalar_type
<
S
,
X
>
::
value
,
bool
>::
type
=
false
>
__host__
__device__
constexpr
auto
GetAsType
(
Number
<
I
>
i
)
const
{
{
constexpr
auto
vec_id
=
Number
<
i
/
vector_size
>
{};
constexpr
auto
s_per_x
=
Number
<
scalar_type
<
remove_cvref_t
<
X
>>::
vector_size
>
{};
constexpr
auto
vec_off
=
Number
<
i
%
vector_size
>
{};
return
this
->
At
(
vec_id
).
template
AsType
<
VecBaseType
>()(
vec_off
);
static_assert
(
s_per_v
%
s_per_x
==
0
,
"wrong! V must one or multiple X"
);
}
static_assert
(
i
%
s_per_x
==
0
,
"wrong!"
);
template
<
index_t
I
>
constexpr
auto
i_v
=
i
/
s_per_v
;
__host__
__device__
constexpr
auto
GetElement
(
Number
<
I
>
i
,
bool
is_valid_element
)
const
constexpr
auto
i_x
=
(
i
%
s_per_v
)
/
s_per_x
;
{
constexpr
auto
vec_id
=
Number
<
i
/
vector_size
>
{};
constexpr
auto
vec_off
=
Number
<
i
%
vector_size
>
{};
if
constexpr
(
InvalidElementUseNumericalZeroValue
)
return
base
::
operator
[](
i_v
).
template
AsType
<
X
>()[
i_x
];
{
return
is_valid_element
?
this
->
At
(
vec_id
).
template
AsType
<
VecBaseType
>()[
vec_off
]
:
VecBaseType
{
0
};
}
}
else
// Set X
// i is offset of S, not X. i should be aligned to X
template
<
typename
X
,
index_t
I
,
typename
enable_if
<
has_same_scalar_type
<
S
,
X
>
::
value
,
bool
>::
type
=
false
>
__host__
__device__
constexpr
void
SetAsType
(
Number
<
I
>
i
,
X
x
)
{
{
return
is_valid_element
?
this
->
At
(
vec_id
).
template
AsType
<
VecBaseType
>()[
vec_off
]
constexpr
auto
s_per_x
=
Number
<
scalar_type
<
remove_cvref_t
<
X
>>::
vector_size
>
{};
:
invalid_element_value_
;
}
static_assert
(
s_per_v
%
s_per_x
==
0
,
"wrong! V must contain one or multiple X"
);
static_assert
(
i
%
s_per_x
==
0
,
"wrong!"
);
constexpr
auto
i_v
=
i
/
s_per_v
;
constexpr
auto
i_x
=
(
i
%
s_per_v
)
/
s_per_x
;
base
::
operator
()(
i_v
).
template
AsType
<
X
>()(
i_x
)
=
x
;
}
}
// Get read access to vector_type V
// i is offset of S, not V. i should be aligned to V
template
<
index_t
I
>
template
<
index_t
I
>
__host__
__device__
constexpr
auto
operator
[]
(
Number
<
I
>
i
)
const
__host__
__device__
constexpr
const
auto
&
GetVectorTypeReference
(
Number
<
I
>
i
)
const
{
{
return
GetElement
(
i
,
true
);
static_assert
(
i
%
s_per_v
==
0
,
"wrong!"
);
constexpr
auto
i_v
=
i
/
s_per_v
;
return
base
::
operator
[](
i_v
);
}
}
// Get write access to vector_type V
// i is offset of S, not V. i should be aligned to V
template
<
index_t
I
>
template
<
index_t
I
>
__host__
__device__
constexpr
auto
&
operator
()
(
Number
<
I
>
i
)
__host__
__device__
constexpr
auto
&
GetVectorTypeReference
(
Number
<
I
>
i
)
{
{
return
GetElement
(
i
,
true
);
static_assert
(
i
%
s_per_v
==
0
,
"wrong!"
);
}
__host__
__device__
static
constexpr
bool
IsStaticBuffer
()
{
return
true
;
}
constexpr
auto
i_v
=
i
/
s_per_v
;
__host__
__device__
static
constexpr
bool
IsDynamicBuffer
()
{
return
false
;
}
return
base
::
operator
()(
i_v
);
}
};
};
template
<
AddressSpaceEnum_t
Buffer
AddressSpace
,
typename
T
,
index_t
N
>
template
<
AddressSpaceEnum_t
AddressSpace
,
typename
T
,
index_t
N
>
__host__
__device__
constexpr
auto
make_static_buffer
(
Number
<
N
>
)
__host__
__device__
constexpr
auto
make_static_buffer
(
Number
<
N
>
)
{
{
return
StaticBuffer
<
BufferAddressSpace
,
T
,
N
,
true
>
{};
return
StaticBuffer
<
AddressSpace
,
T
,
N
,
true
>
{};
}
template
<
AddressSpaceEnum_t
BufferAddressSpace
,
typename
T
,
index_t
N
>
__host__
__device__
constexpr
auto
make_static_buffer
(
Number
<
N
>
,
T
invalid_element_value
)
{
return
StaticBuffer
<
BufferAddressSpace
,
T
,
N
,
false
>
{
invalid_element_value
};
}
}
}
// namespace ck
}
// namespace ck
...
...
composable_kernel/include/utility/static_buffer_of_vector_type_v2.hpp
0 → 100644
View file @
8c4e33f1
#ifndef CK_STATIC_BUFFER_OF_VECTOR_TYPE_V2_HPP
#define CK_STATIC_BUFFER_OF_VECTOR_TYPE_V2_HPP
#include "statically_indexed_array.hpp"
namespace
ck
{
template
<
AddressSpaceEnum_t
BufferAddressSpace
,
typename
T
,
index_t
N
,
bool
InvalidElementUseNumericalZeroValue
>
struct
StaticBufferOfVectorTypeV2
:
public
StaticallyIndexedArray
<
T
,
N
>
{
using
type
=
T
;
using
base
=
StaticallyIndexedArray
<
T
,
N
>
;
using
VecBaseType
=
typename
T
::
d1_t
;
__host__
__device__
static
constexpr
index_t
GetVectorSize
()
{
return
sizeof
(
typename
T
::
type
)
/
sizeof
(
VecBaseType
);
}
static
constexpr
index_t
vector_size
=
GetVectorSize
();
VecBaseType
invalid_element_value_
=
VecBaseType
{
0
};
T
invalid_vec_value_
=
T
{
0
};
__host__
__device__
constexpr
StaticBufferOfVectorTypeV2
()
:
base
{}
{}
__host__
__device__
constexpr
StaticBufferOfVectorTypeV2
(
VecBaseType
invalid_element_value
)
:
base
{},
invalid_vec_value_
{
invalid_element_value
},
invalid_element_value_
{
invalid_element_value
}
{
}
__host__
__device__
static
constexpr
AddressSpaceEnum_t
GetAddressSpace
()
{
return
BufferAddressSpace
;
}
template
<
index_t
I
>
__host__
__device__
constexpr
auto
&
GetVector
(
Number
<
I
>
vec_id
)
{
return
this
->
At
(
vec_id
);
}
template
<
index_t
I
>
__host__
__device__
constexpr
const
auto
&
GetVector
(
Number
<
I
>
vec_id
)
const
{
return
this
->
At
(
vec_id
);
}
template
<
index_t
I
>
__host__
__device__
constexpr
auto
&
GetElement
(
Number
<
I
>
i
,
bool
)
{
constexpr
auto
vec_id
=
Number
<
i
/
vector_size
>
{};
constexpr
auto
vec_off
=
Number
<
i
%
vector_size
>
{};
return
this
->
At
(
vec_id
).
template
AsType
<
VecBaseType
>()(
vec_off
);
}
template
<
index_t
I
>
__host__
__device__
constexpr
auto
GetElement
(
Number
<
I
>
i
,
bool
is_valid_element
)
const
{
constexpr
auto
vec_id
=
Number
<
i
/
vector_size
>
{};
constexpr
auto
vec_off
=
Number
<
i
%
vector_size
>
{};
if
constexpr
(
InvalidElementUseNumericalZeroValue
)
{
return
is_valid_element
?
this
->
At
(
vec_id
).
template
AsType
<
VecBaseType
>()[
vec_off
]
:
VecBaseType
{
0
};
}
else
{
return
is_valid_element
?
this
->
At
(
vec_id
).
template
AsType
<
VecBaseType
>()[
vec_off
]
:
invalid_element_value_
;
}
}
template
<
index_t
I
>
__host__
__device__
constexpr
auto
operator
[](
Number
<
I
>
i
)
const
{
return
GetElement
(
i
,
true
);
}
template
<
index_t
I
>
__host__
__device__
constexpr
auto
&
operator
()(
Number
<
I
>
i
)
{
return
GetElement
(
i
,
true
);
}
__host__
__device__
static
constexpr
bool
IsStaticBuffer
()
{
return
true
;
}
__host__
__device__
static
constexpr
bool
IsDynamicBuffer
()
{
return
false
;
}
};
}
// namespace ck
#endif
composable_kernel/include/utility/statically_indexed_array.hpp
View file @
8c4e33f1
...
@@ -8,20 +8,38 @@
...
@@ -8,20 +8,38 @@
namespace
ck
{
namespace
ck
{
namespace
detail
{
namespace
detail
{
template
<
typename
X
,
typename
Y
>
struct
tuple_concat
;
template
<
typename
T
,
index_t
NSize
>
template
<
typename
...
Xs
,
typename
...
Ys
>
__host__
__device__
constexpr
auto
generate_same_type_tuple
()
struct
tuple_concat
<
Tuple
<
Xs
...
>
,
Tuple
<
Ys
...
>>
{
{
return
generate_tuple
([](
auto
)
->
T
{
return
T
{};
},
Number
<
NSize
>
{})
;
using
type
=
Tuple
<
Xs
...,
Ys
...
>
;
}
}
;
template
<
typename
T
,
index_t
NSize
>
template
<
typename
T
,
index_t
N
>
using
same_type_tuple
=
decltype
(
generate_same_type_tuple
<
T
,
NSize
>
());
struct
StaticallyIndexedArrayImpl
{
using
type
=
typename
tuple_concat
<
typename
StaticallyIndexedArrayImpl
<
T
,
N
/
2
>::
type
,
typename
StaticallyIndexedArrayImpl
<
T
,
N
-
N
/
2
>::
type
>::
type
;
};
template
<
typename
T
>
struct
StaticallyIndexedArrayImpl
<
T
,
0
>
{
using
type
=
Tuple
<>
;
};
template
<
typename
T
>
struct
StaticallyIndexedArrayImpl
<
T
,
1
>
{
using
type
=
Tuple
<
T
>
;
};
}
// namespace detail
}
// namespace detail
template
<
typename
T
,
index_t
N
Size
>
template
<
typename
T
,
index_t
N
>
using
StaticallyIndexedArray
=
detail
::
same_type_tu
pl
e
<
T
,
N
Size
>
;
using
StaticallyIndexedArray
=
typename
detail
::
StaticallyIndexedArrayIm
pl
<
T
,
N
>::
type
;
template
<
typename
X
,
typename
...
Xs
>
template
<
typename
X
,
typename
...
Xs
>
__host__
__device__
constexpr
auto
make_statically_indexed_array
(
const
X
&
x
,
const
Xs
&
...
xs
)
__host__
__device__
constexpr
auto
make_statically_indexed_array
(
const
X
&
x
,
const
Xs
&
...
xs
)
...
...
composable_kernel/include/utility/transpose_vectors.hpp
0 → 100644
View file @
8c4e33f1
#ifndef CK_TRANSPOSE_VECTORS_AMD_HPP
#define CK_TRANSPOSE_VECTORS_AMD_HPP
#include "config.hpp"
#include "statically_indexed_array.hpp"
#include "data_type.hpp"
namespace
ck
{
template
<
typename
S
,
index_t
NX
,
index_t
NY
,
typename
enable_if
<
is_scalar_type
<
S
>
::
value
,
bool
>::
type
=
false
>
struct
transpose_vectors
;
// transpose fp16 2x2
__device__
void
transpose_fp16_2x2
(
const
half2_t
&
x0
,
const
half2_t
&
x1
,
half2_t
&
y0
,
half2_t
&
y1
)
{
#if 0
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
const vector_type<half_t, 2> vx0{x0}, vx1{x1};
vector_type<half_t, 2> vy0, vy1;
vy0.template AsType<half_t>()(I0) = vx0.template AsType<half_t>()[I0];
vy0.template AsType<half_t>()(I1) = vx1.template AsType<half_t>()[I0];
vy1.template AsType<half_t>()(I0) = vx0.template AsType<half_t>()[I1];
vy1.template AsType<half_t>()(I1) = vx1.template AsType<half_t>()[I1];
y0 = vy0.template AsType<half2_t>()[I0];
y1 = vy1.template AsType<half2_t>()[I0];
#else
asm
volatile
(
"
\n
\
v_pack_b32_f16 %0, %1, %2
\n
\
"
:
"=v"
(
y0
)
:
"v"
(
x0
),
"v"
(
x1
));
asm
volatile
(
"
\n
\
v_pack_b32_f16 %0, %1, %2, op_sel:[1, 1]
\n
\
"
:
"=v"
(
y1
)
:
"v"
(
x0
),
"v"
(
x1
));
#endif
}
template
<
index_t
NX
,
index_t
NY
>
struct
transpose_vectors
<
half_t
,
NX
,
NY
>
{
// we got [NY * NX] ammount of S data to be transposed
static
constexpr
index_t
s_per_x
=
NY
;
static
constexpr
index_t
s_per_y
=
NX
;
using
S
=
half_t
;
using
VX
=
vector_type
<
half_t
,
s_per_x
>
;
using
VY
=
vector_type
<
half_t
,
s_per_y
>
;
__device__
void
operator
()(
const
StaticallyIndexedArray
<
const
VX
&
,
NX
>&
vx_tuple
,
StaticallyIndexedArray
<
VY
&
,
NY
>&
vy_tuple
)
{
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static_assert
((
NX
%
2
==
0
&&
NY
%
2
==
0
),
"wrong!"
);
// loop over 2x2 tile and transpose data from vx_tuple into vy_tuple
static_for
<
0
,
NY
,
2
>
{}([
&
](
auto
iy
)
{
static_for
<
0
,
NX
,
2
>
{}([
&
](
auto
ix
)
{
// reference to 2 half2_t data from vx_tuple
const
auto
&
x_s2_0
=
vx_tuple
[
ix
].
template
AsType
<
half2_t
>()[
iy
/
I2
];
const
auto
&
x_s2_1
=
vx_tuple
[
ix
+
I1
].
template
AsType
<
half2_t
>()[
iy
/
I2
];
// reference to 2 half2_t data from vy_tuple
auto
&
y_s2_0
=
vy_tuple
(
iy
).
template
AsType
<
half2_t
>()(
ix
/
I2
);
auto
&
y_s2_1
=
vy_tuple
(
iy
+
I1
).
template
AsType
<
half2_t
>()(
ix
/
I2
);
// transpose
transpose_fp16_2x2
(
x_s2_0
,
x_s2_1
,
y_s2_0
,
y_s2_1
);
});
});
}
};
}
// namespace ck
#endif
composable_kernel/include/utility/tuple.hpp
View file @
8c4e33f1
...
@@ -117,6 +117,7 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X
...
@@ -117,6 +117,7 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X
__host__
__device__
static
constexpr
index_t
Size
()
{
return
sizeof
...(
Xs
);
}
__host__
__device__
static
constexpr
index_t
Size
()
{
return
sizeof
...(
Xs
);
}
// read access
template
<
index_t
I
>
template
<
index_t
I
>
__host__
__device__
constexpr
const
auto
&
At
(
Number
<
I
>
)
const
__host__
__device__
constexpr
const
auto
&
At
(
Number
<
I
>
)
const
{
{
...
@@ -124,6 +125,7 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X
...
@@ -124,6 +125,7 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X
return
base
::
GetElementByKey
(
detail
::
TupleElementKey
<
I
>
{});
return
base
::
GetElementByKey
(
detail
::
TupleElementKey
<
I
>
{});
}
}
// write access
template
<
index_t
I
>
template
<
index_t
I
>
__host__
__device__
constexpr
auto
&
At
(
Number
<
I
>
)
__host__
__device__
constexpr
auto
&
At
(
Number
<
I
>
)
{
{
...
@@ -131,12 +133,14 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X
...
@@ -131,12 +133,14 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X
return
base
::
GetElementByKey
(
detail
::
TupleElementKey
<
I
>
{});
return
base
::
GetElementByKey
(
detail
::
TupleElementKey
<
I
>
{});
}
}
// read access
template
<
index_t
I
>
template
<
index_t
I
>
__host__
__device__
constexpr
const
auto
&
operator
[](
Number
<
I
>
i
)
const
__host__
__device__
constexpr
const
auto
&
operator
[](
Number
<
I
>
i
)
const
{
{
return
At
(
i
);
return
At
(
i
);
}
}
// write access
template
<
index_t
I
>
template
<
index_t
I
>
__host__
__device__
constexpr
auto
&
operator
()(
Number
<
I
>
i
)
__host__
__device__
constexpr
auto
&
operator
()(
Number
<
I
>
i
)
{
{
...
@@ -162,5 +166,12 @@ __host__ __device__ constexpr auto make_tuple(Xs&&... xs)
...
@@ -162,5 +166,12 @@ __host__ __device__ constexpr auto make_tuple(Xs&&... xs)
return
Tuple
<
remove_cvref_t
<
Xs
>
...
>
(
std
::
forward
<
Xs
>
(
xs
)...);
return
Tuple
<
remove_cvref_t
<
Xs
>
...
>
(
std
::
forward
<
Xs
>
(
xs
)...);
}
}
// https://en.cppreference.com/w/cpp/utility/tuple/tie
template
<
typename
...
Args
>
constexpr
Tuple
<
Args
&
...
>
tie
(
Args
&
...
args
)
noexcept
{
return
{
args
...};
}
}
// namespace ck
}
// namespace ck
#endif
#endif
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment