Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
3c4fb1dd
Commit
3c4fb1dd
authored
Nov 23, 2023
by
Umang Yadav
Browse files
Merge remote-tracking branch 'origin/develop' into migx_merge
parents
57cdd70b
e8cddfdc
Changes
386
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2054 additions
and
1285 deletions
+2054
-1285
include/ck/tensor_operation/gpu/thread/threadwise_contraction_dl_dpp8.hpp
...r_operation/gpu/thread/threadwise_contraction_dl_dpp8.hpp
+0
-136
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp
...operation/gpu/thread/threadwise_tensor_slice_transfer.hpp
+4
-5
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp
...tion/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp
+54
-23
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1.hpp
...tion/gpu/thread/threadwise_tensor_slice_transfer_v6r1.hpp
+2
-2
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r2.hpp
...tion/gpu/thread/threadwise_tensor_slice_transfer_v7r2.hpp
+386
-0
include/ck/tensor_operation/gpu/warp/dpp_gemm.hpp
include/ck/tensor_operation/gpu/warp/dpp_gemm.hpp
+538
-0
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp
+196
-10
include/ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp
...operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp
+18
-15
include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp
...eration/operator_transform/transform_conv_fwd_to_gemm.hpp
+6
-339
include/ck/utility/amd_buffer_addressing.hpp
include/ck/utility/amd_buffer_addressing.hpp
+209
-576
include/ck/utility/amd_gemm_dpp.hpp
include/ck/utility/amd_gemm_dpp.hpp
+51
-5
include/ck/utility/amd_xdlops.hpp
include/ck/utility/amd_xdlops.hpp
+191
-5
include/ck/utility/data_type.hpp
include/ck/utility/data_type.hpp
+109
-11
include/ck/utility/dynamic_buffer.hpp
include/ck/utility/dynamic_buffer.hpp
+30
-4
include/ck/utility/f8_utils.hpp
include/ck/utility/f8_utils.hpp
+173
-130
include/ck/utility/inner_product.hpp
include/ck/utility/inner_product.hpp
+14
-0
include/ck/utility/inner_product_dpp8.hpp
include/ck/utility/inner_product_dpp8.hpp
+4
-0
include/ck/utility/is_detected.hpp
include/ck/utility/is_detected.hpp
+43
-0
include/ck/utility/loop_scheduler.hpp
include/ck/utility/loop_scheduler.hpp
+26
-0
include/ck/utility/math.hpp
include/ck/utility/math.hpp
+0
-24
No files found.
Too many changes to show.
To preserve performance only
386 of 386+
files are displayed.
Plain diff
Email patch
include/ck/tensor_operation/gpu/thread/threadwise_contraction_dl_dpp8.hpp
deleted
100644 → 0
View file @
57cdd70b
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/amd_gemm_dpp.hpp"
#include "ck/utility/common_header.hpp"
#include "ck/utility/inner_product_dpp8.hpp"
#include "ck/utility/math.hpp"
namespace
ck
{
/**
* Threadwise contraction using dot instructions with DPP8 modifier.
*
* Assumptions:
* 1. `AThreadDesc_TK0_TM0_TM1_TK1`, `BThreadDesc_TK0_TN0_TN1_TK1`, `CThreadDesc_TM0_TM1_TN0_TN1`
* are known at compile-time;
* 2. `AOriginIdx`, `BOriginIdx`, `COriginIdx` are known at compile-time;
* 3. `TM0` is equal to 1 and `TN0` is equal to 1;
* 4. When `ShareA` is set (unset, respectively), `TM1` (`TN1`, respectively) is divisible by
* the size of the lane group (`dpp8::lane_group_size`).
*/
// Threadwise contraction C += A * B built on dpp8::inner_product_dpp.
//
// Template parameters:
//   FloatA / FloatB / FloatC - element types of the A, B and C thread buffers.
//   *ThreadDesc_*            - compile-time tensor descriptors of the three buffers.
//   TKLengths / TMLengths / TNLengths - two-element length sequences: (TK0, TK1),
//                              (TM0, TM1) and (TN0, TN1).
//   ShareA                   - selects which operand is indexed modulo
//                              `shared_elems_per_lane` and therefore obtained through the
//                              DPP source lane (A when true, B when false).
//
// The last (unnamed) template parameter removes the struct from consideration when any of
// the descriptors is not known at compile-time.
template <typename FloatA,
          typename FloatB,
          typename FloatC,
          typename AThreadDesc_TK0_TM0_TM1_TK1,
          typename BThreadDesc_TK0_TN0_TN1_TK1,
          typename CThreadDesc_TM0_TM1_TN0_TN1,
          typename TKLengths,
          typename TMLengths,
          typename TNLengths,
          bool ShareA,
          typename enable_if<AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() &&
                                 BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
                                 CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(),
                             bool>::type = false>
struct ThreadwiseContractionDlDpp8_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};

    // Per-dimension lengths unpacked from the length sequences.
    static constexpr index_t TK0 = TKLengths{}[I0];
    static constexpr index_t TK1 = TKLengths{}[I1];
    static constexpr index_t TM0 = TMLengths{}[I0];
    static constexpr index_t TM1 = TMLengths{}[I1];
    static constexpr index_t TN0 = TNLengths{}[I0];
    static constexpr index_t TN1 = TNLengths{}[I1];

    // Run() always indexes the TM0 / TN0 dimensions with 0.
    static_assert(TM0 == 1 && TN0 == 1);

    // The shared dimension must split evenly across the lanes of a lane group.
    static_assert((ShareA && TM1 % dpp8::lane_group_size == 0) ||
                  (!ShareA && TN1 % dpp8::lane_group_size == 0));

    // Number of elements of the shared operand that each lane of a lane group holds.
    static constexpr index_t shared_elems_per_lane =
        ShareA ? TM1 / dpp8::lane_group_size : TN1 / dpp8::lane_group_size;

    __device__ constexpr ThreadwiseContractionDlDpp8_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1()
    {
        static_assert(AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() &&
                          BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
                          CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(),
                      "wrong! Desc should be known at compile-time");

        static_assert(TKLengths::Size() == 2 && TMLengths::Size() == 2 && TNLengths::Size() == 2,
                      "wrong!");
    }

    // Accumulates the contraction of `a_buf` and `b_buf` into `c_buf`.
    //
    // The three origin arguments are passed by type only (their values are unused); each
    // must be known at compile-time. The buffer element types must match FloatA / FloatB /
    // FloatC after removing cv/ref qualifiers.
    template <typename ABuffer,
              typename AOriginIdx,
              typename BBuffer,
              typename BOriginIdx,
              typename CBuffer,
              typename COriginIdx>
    __device__ static void Run(const ABuffer& a_buf,
                               AOriginIdx,
                               const BBuffer& b_buf,
                               BOriginIdx,
                               CBuffer& c_buf,
                               COriginIdx)
    {
        static_assert(is_known_at_compile_time<remove_cvref_t<AOriginIdx>>::value &&
                          is_known_at_compile_time<remove_cvref_t<BOriginIdx>>::value &&
                          is_known_at_compile_time<remove_cvref_t<COriginIdx>>::value,
                      "wrong! AOriginIdx, BOriginIdx, COriginIdx should be known at compile-time");

        // The message is the second argument of static_assert; it used to be joined to the
        // condition with `&&`, which left the assertion without a message.
        static_assert(
            is_same<remove_cvref_t<typename ABuffer::type>, remove_cvref_t<FloatA>>::value &&
                is_same<remove_cvref_t<typename BBuffer::type>, remove_cvref_t<FloatB>>::value &&
                is_same<remove_cvref_t<typename CBuffer::type>, remove_cvref_t<FloatC>>::value,
            "wrong! inconsistent type");

        constexpr auto a_origin_idx = to_multi_index(AOriginIdx{});
        constexpr auto b_origin_idx = to_multi_index(BOriginIdx{});
        constexpr auto c_origin_idx = to_multi_index(COriginIdx{});

        static_for<0, TK0, 1>{}([&](auto tk0) {
            static_for<0, TM1, 1>{}([&](auto tm1) {
                static_for<0, TN1, 1>{}([&](auto tn1) {
                    vector_type<FloatA, TK1> a_vec;
                    vector_type<FloatB, TK1> b_vec;

                    // Gather the TK1 elements of A and B that feed one inner product.
                    static_for<0, TK1, 1>{}([&](auto tk1) {
                        // The shared operand is indexed modulo shared_elems_per_lane: the
                        // local buffer only holds this lane's portion of it.
                        constexpr index_t local_tm1 = ShareA ? tm1 % shared_elems_per_lane : tm1;
                        constexpr index_t a_offset =
                            AThreadDesc_TK0_TM0_TM1_TK1{}.CalculateOffset(
                                a_origin_idx + make_multi_index(tk0, 0, local_tm1, tk1));

                        constexpr index_t local_tn1 = ShareA ? tn1 : tn1 % shared_elems_per_lane;
                        constexpr index_t b_offset =
                            BThreadDesc_TK0_TN0_TN1_TK1{}.CalculateOffset(
                                b_origin_idx + make_multi_index(tk0, 0, local_tn1, tk1));

                        a_vec.template AsType<FloatA>()(tk1) = a_buf[Number<a_offset>{}];
                        b_vec.template AsType<FloatB>()(tk1) = b_buf[Number<b_offset>{}];
                    });

                    using a_vector_t = typename vector_type<FloatA, TK1>::type;
                    using b_vector_t = typename vector_type<FloatB, TK1>::type;

                    constexpr index_t c_offset = CThreadDesc_TM0_TM1_TN0_TN1{}.CalculateOffset(
                        c_origin_idx + make_multi_index(0, tm1, 0, tn1));

                    // Lane of the lane group handed to inner_product_dpp as the source of
                    // the shared operand for this (tm1, tn1).
                    constexpr int src_lane =
                        ShareA ? (tm1 / shared_elems_per_lane) % dpp8::lane_group_size
                               : (tn1 / shared_elems_per_lane) % dpp8::lane_group_size;

                    dpp8::inner_product_dpp<a_vector_t, b_vector_t, FloatC, src_lane, ShareA>(
                        a_vec.template AsType<a_vector_t>()[I0],
                        b_vec.template AsType<b_vector_t>()[I0],
                        c_buf(Number<c_offset>{}));
                });
            });
        });
    }
};
}
// namespace ck
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp
View file @
3c4fb1dd
...
...
@@ -137,13 +137,12 @@ struct ThreadwiseTensorSliceTransfer_v1r3
constexpr
index_t
src_offset
=
src_desc
.
CalculateOffset
(
src_slice_origin_idx
+
idx_md
+
i
*
dst_scalar_step_in_vector
);
Src
Data
v
;
Dst
Data
v
;
// apply element-wise operation
element_op_
(
v
,
src_buf
[
Number
<
src_offset
>
{}]);
// apply type convert
dst_vector
.
template
AsType
<
DstData
>()(
i
)
=
type_convert
<
DstData
>
(
v
);
dst_vector
.
template
AsType
<
DstData
>()(
i
)
=
v
;
});
const
bool
is_dst_valid
=
...
...
@@ -1289,13 +1288,13 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic
constexpr
index_t
dst_offset
=
dst_desc
.
CalculateOffset
(
dst_slice_origin_idx
+
idx_md
+
i
*
dst_scalar_step_in_vector
);
Src
Data
v
;
Dst
Data
v
;
// apply element-wise operation
element_op_
(
v
,
src_buf
[
Number
<
src_offset
>
{}]);
// apply type convert
dst_buf
(
Number
<
dst_offset
>
{})
=
type_convert
<
DstData
>
(
v
)
;
dst_buf
(
Number
<
dst_offset
>
{})
=
v
;
});
});
}
...
...
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp
View file @
3c4fb1dd
...
...
@@ -9,6 +9,7 @@
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor/static_tensor.hpp"
#include "ck/utility/is_detected.hpp"
namespace
ck
{
...
...
@@ -211,10 +212,44 @@ struct ThreadwiseTensorSliceTransfer_v3r1
auto
src_vector_container
=
src_vector_type
{
src_buf
.
template
Get
<
src_vector_t
>(
src_coord_
.
GetOffset
(),
is_src_valid
)};
using
dst_vector_type
=
vector_type_maker_t
<
DstData
,
SrcScalarPerVector
>
;
using
dst_vector_t
=
typename
dst_vector_type
::
type
;
dst_vector_type
op_r_v
;
constexpr
auto
get_elem_op_vec_len
=
[]()
{
if
constexpr
(
is_detected
<
is_pack8_invocable_t
,
decltype
(
src_element_op_
)
>::
value
)
{
if
constexpr
(
decltype
(
src_element_op_
)
::
is_pack8_invocable
)
return
math
::
min
(
8
,
SrcScalarPerVector
);
}
if
constexpr
(
is_detected
<
is_pack4_invocable_t
,
decltype
(
src_element_op_
)
>::
value
)
{
if
constexpr
(
decltype
(
src_element_op_
)
::
is_pack4_invocable
)
return
math
::
min
(
4
,
SrcScalarPerVector
);
}
if
constexpr
(
is_detected
<
is_pack2_invocable_t
,
decltype
(
src_element_op_
)
>::
value
)
{
if
constexpr
(
decltype
(
src_element_op_
)
::
is_pack2_invocable
)
return
math
::
min
(
2
,
SrcScalarPerVector
);
}
return
1
;
};
constexpr
index_t
elem_op_vec_len
=
get_elem_op_vec_len
();
using
src_elem_op_vec_t
=
typename
vector_type
<
SrcData
,
elem_op_vec_len
>::
type
;
using
dst_elem_op_vec_t
=
typename
vector_type
<
DstData
,
elem_op_vec_len
>::
type
;
static_for
<
0
,
SrcScalarPerVector
/
elem_op_vec_len
,
1
>
{}([
&
](
auto
idx
)
{
// apply the src elementwise op and convert to DstData under the hood if needed
src_element_op_
(
op_r_v
.
template
AsType
<
dst_elem_op_vec_t
>()(
idx
),
src_vector_container
.
template
AsType
<
src_elem_op_vec_t
>()[
idx
]);
});
// copy data from src_vector_container into src_thread_scratch_
src_thread_scratch_tuple_
(
thread_scratch_id
)
.
template
SetAsType
<
src
_vector_t
>(
src_data_idx_seq
,
src_vector_container
.
template
AsType
<
src
_vector_t
>()[
I0
]);
.
template
SetAsType
<
dst
_vector_t
>(
src_data_idx_seq
,
op_r_v
.
template
AsType
<
dst
_vector_t
>()[
I0
]);
constexpr
auto
move_on_dim
=
[
&
]()
constexpr
{
...
...
@@ -267,19 +302,15 @@ struct ThreadwiseTensorSliceTransfer_v3r1
{
#if !CK_EXPERIMENTAL_USE_IN_REGISTER_SUB_DWORD_TRANSPOSE
static_ford
<
SliceLengths
>
{}([
&
](
auto
idx
)
{
// convert from SrcData to DstData here
dst_thread_scratch_
(
idx
)
=
type_convert
<
DstData
>
(
src_thread_scratch_tuple_
[
thread_scratch_id
][
idx
]);
dst_thread_scratch_
(
idx
)
=
src_thread_scratch_tuple_
[
thread_scratch_id
][
idx
];
});
#else
// sub-dword transpose between src_thread_scratch_ and dst_thread_scratch_
// TODO make this logic more generic for more sub-dword datatype
if
constexpr
(
SrcVectorDim
!=
DstVectorDim
&&
((
is_same
<
half_t
,
remove_cvref_t
<
SrcData
>>::
value
&&
is_same
<
half_t
,
remove_cvref_t
<
DstData
>>::
value
&&
((
is_same
<
half_t
,
remove_cvref_t
<
DstData
>>::
value
&&
SrcScalarPerVector
%
2
==
0
&&
DstScalarPerVector
%
2
==
0
)
||
(
is_same
<
int8_t
,
remove_cvref_t
<
SrcData
>>::
value
&&
is_same
<
int8_t
,
remove_cvref_t
<
DstData
>>::
value
&&
(
is_same
<
int8_t
,
remove_cvref_t
<
DstData
>>::
value
&&
SrcScalarPerVector
%
4
==
0
&&
DstScalarPerVector
%
4
==
0
)))
{
// each transpose does
...
...
@@ -313,7 +344,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
constexpr
auto
data_idx_seq
=
generate_sequence_v2
(
[
&
](
auto
i
)
{
return
Number
<
data_idx
[
i
]
>
{};
},
Number
<
nDim
>
{});
using
src_vector_t
=
vector_type_maker_t
<
Src
Data
,
SrcScalarPerVector
>
;
using
src_vector_t
=
vector_type_maker_t
<
Dst
Data
,
SrcScalarPerVector
>
;
using
dst_vector_t
=
vector_type_maker_t
<
DstData
,
DstScalarPerVector
>
;
// get DstScalarPerVector # of read-only references to src vectors from
...
...
@@ -336,17 +367,16 @@ struct ThreadwiseTensorSliceTransfer_v3r1
Number
<
num_dst_vector
>
{});
// do data transpose
transpose_vectors
<
Src
Data
,
DstScalarPerVector
,
SrcScalarPerVector
>
{}(
transpose_vectors
<
Dst
Data
,
DstScalarPerVector
,
SrcScalarPerVector
>
{}(
src_vector_refs
,
dst_vector_refs
);
});
}
else
{
static_ford
<
SliceLengths
>
{}([
&
](
auto
idx
)
{
// apply the src elementwise op and convert to DstData under the hood if needed
DstData
dst_v
;
src_element_op_
(
dst_v
,
src_thread_scratch_tuple_
[
thread_scratch_id
][
idx
]);
dst_thread_scratch_
(
idx
)
=
dst_v
;
dst_thread_scratch_
(
idx
)
=
src_thread_scratch_tuple_
[
thread_scratch_id
][
idx
];
});
}
#endif
}
...
...
@@ -761,8 +791,9 @@ struct ThreadwiseTensorSliceTransfer_v3r1
static
constexpr
auto
src_thread_scratch_desc_
=
decltype
(
GetSrcThreadScratchDescriptor
()){};
static
constexpr
auto
dst_thread_scratch_desc_
=
decltype
(
GetDstThreadScratchDescriptor
()){};
using
SrcThreadScratch
=
StaticTensorTupleOfVectorBuffer
<
AddressSpaceEnum
::
Vgpr
,
SrcData
,
using
SrcThreadScratch
=
StaticTensorTupleOfVectorBuffer
<
AddressSpaceEnum
::
Vgpr
,
DstData
,
// apply data_convert with SrcThreadScratch
SrcScalarPerVector
,
decltype
(
src_thread_scratch_desc_
),
true
>
;
...
...
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1.hpp
View file @
3c4fb1dd
...
...
@@ -104,13 +104,13 @@ struct ThreadwiseTensorSliceTransfer_v6r1
// apply pointwise operation
static_for
<
0
,
ScalarPerVector
,
1
>
{}([
&
](
auto
i
)
{
Src
Data
v
;
Dst
Data
v
;
// apply element-wise operation
element_op_
(
v
,
src_vector_container
.
template
AsType
<
SrcData
>()[
i
]);
// apply type convert
dst_vector_container
.
template
AsType
<
DstData
>()(
i
)
=
type_convert
<
DstData
>
(
v
)
;
dst_vector_container
.
template
AsType
<
DstData
>()(
i
)
=
v
;
});
const
bool
is_dst_valid
=
...
...
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r2.hpp
0 → 100644
View file @
3c4fb1dd
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/tensor_space_filling_curve.hpp"
#include "ck/utility/is_detected.hpp"
namespace
ck
{
// Thread-level multi-source, multi-destination tensor slice data movement
// Assume:
// 1. All sources and destinations are DynamicBuffer
// 2. Same VectorDim and ScalarPerVector for all sources and destinations
// 3. DstInMemOps are per destination tensor
// 4. ThreadTransferSrcResetCoordinateAfterRunFlags are per source tensor
// 5. ThreadTransferDstResetCoordinateAfterRunFlags are per destination tensor
// 6. Does not need to know src_descs and dst_descs at compile-time
// 7. Does not need to know src_slice_origins and dst_slice_origins at compile-time,
//
// Does following things to avoid scratch memory issue
// 1. Use StaticallyIndexedArray or vector_type instead of C array for thread buffer
// 2. Pass tensor descriptors by reference (or tuple of references)
// 3. Does not keep reference to tensor descriptor
// 4. Does not construct new tensor coordinate when call Run()
template
<
typename
SrcDatas
,
typename
DstDatas
,
typename
SrcDescs
,
typename
DstDescs
,
typename
ElementwiseOperation
,
typename
DstInMemOps
,
// Sequence<InMemoryDataOperationEnum ...>
typename
SliceLengths
,
typename
SrcDimAccessOrder
,
typename
DstDimAccessOrder
,
index_t
SrcVectorDim
,
index_t
DstVectorDim
,
index_t
SrcScalarPerVector
,
index_t
DstScalarPerVector
,
typename
SrcResetCoordinateAfterRunFlags
,
// Sequence<bool ...>
typename
DstResetCoordinateAfterRunFlags
>
// Sequence<bool ...>
struct
ThreadwiseTensorSliceTransfer_v7r2
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
index_t
nDim
=
SliceLengths
::
Size
();
static
constexpr
index_t
nSrc
=
SrcDescs
::
Size
();
static
constexpr
index_t
nDst
=
DstDescs
::
Size
();
using
Index
=
MultiIndex
<
nDim
>
;
// return a tuple of coordiantes for a tuple of tensor
template
<
typename
Descs
,
typename
Indices
,
enable_if_t
<
Descs
::
Size
()
==
Indices
::
Size
(),
bool
>
=
false
>
static
constexpr
auto
MakeCoordinates
(
const
Descs
&
descs
,
const
Indices
&
indices
)
{
return
generate_tuple
([
&
](
auto
i
)
{
return
make_tensor_coordinate
(
descs
[
i
],
indices
[
i
]);
},
Number
<
Descs
::
Size
()
>
{});
}
using
SrcCoords
=
decltype
(
MakeCoordinates
(
SrcDescs
{},
StaticallyIndexedArray
<
Index
,
nSrc
>
{}));
using
DstCoords
=
decltype
(
MakeCoordinates
(
DstDescs
{},
StaticallyIndexedArray
<
Index
,
nDst
>
{}));
// scalar per access on each dim
// FIXME: don't use lambda_scalar_per_access
static
constexpr
auto
src_scalar_per_access
=
generate_sequence
(
detail
::
lambda_scalar_per_access
<
SrcVectorDim
,
SrcScalarPerVector
>
{},
Number
<
nDim
>
{});
using
SrcSpaceFillingCurve
=
SpaceFillingCurve
<
SliceLengths
,
SrcDimAccessOrder
,
remove_cv_t
<
decltype
(
src_scalar_per_access
)
>>
;
static
constexpr
auto
dst_scalar_per_access
=
generate_sequence
(
detail
::
lambda_scalar_per_access
<
DstVectorDim
,
DstScalarPerVector
>
{},
Number
<
nDim
>
{});
using
DstSpaceFillingCurve
=
SpaceFillingCurve
<
SliceLengths
,
DstDimAccessOrder
,
remove_cv_t
<
decltype
(
dst_scalar_per_access
)
>>
;
__device__
constexpr
ThreadwiseTensorSliceTransfer_v7r2
(
const
SrcDescs
&
src_descs
,
const
StaticallyIndexedArray
<
Index
,
nSrc
>&
src_slice_origins
,
const
DstDescs
&
dst_descs
,
const
StaticallyIndexedArray
<
Index
,
nDst
>&
dst_slice_origins
,
const
ElementwiseOperation
&
element_op
)
:
src_coords_
(
MakeCoordinates
(
src_descs
,
src_slice_origins
)),
dst_coords_
(
MakeCoordinates
(
dst_descs
,
dst_slice_origins
)),
element_op_
(
element_op
)
{
static_assert
(
SliceLengths
::
At
(
Number
<
SrcVectorDim
>
{})
%
SrcScalarPerVector
==
0
,
"wrong! cannot evenly divide"
);
static_assert
(
SliceLengths
::
At
(
Number
<
DstVectorDim
>
{})
%
DstScalarPerVector
==
0
,
"wrong! cannot evenly divide"
);
}
template
<
typename
Indices
,
enable_if_t
<
SrcDescs
::
Size
()
==
Indices
::
Size
(),
bool
>
=
false
>
__device__
void
SetSrcSliceOrigins
(
const
SrcDescs
&
src_descs
,
const
Indices
&
src_slice_origin_idxs
)
{
static_for
<
0
,
nSrc
,
1
>
{}([
&
](
auto
i
)
{
src_coords_
(
i
)
=
make_tensor_coordinate
(
src_descs
[
i
],
src_slice_origin_idxs
[
i
]);
});
}
template
<
typename
Indices
,
enable_if_t
<
DstDescs
::
Size
()
==
Indices
::
Size
(),
bool
>
=
false
>
__device__
void
SetDstSliceOrigins
(
const
DstDescs
&
dst_descs
,
const
Indices
&
dst_slice_origin_idxs
)
{
static_for
<
0
,
nDst
,
1
>
{}([
&
](
auto
i
)
{
dst_coords_
(
i
)
=
make_tensor_coordinate
(
dst_descs
[
i
],
dst_slice_origin_idxs
[
i
]);
});
}
template
<
typename
DataTypes
,
index_t
ScalarPerVector
>
__device__
static
auto
generate_vectors
()
{
auto
data_types
=
DataTypes
{};
constexpr
index_t
num
=
data_types
.
Size
();
return
generate_tuple
(
[
&
](
auto
i
)
{
using
DataType
=
remove_cvref_t
<
decltype
(
data_types
[
i
])
>
;
return
vector_type_maker_t
<
DataType
,
ScalarPerVector
>
{};
},
Number
<
num
>
{});
}
// SrcDescs: Tuple<const SrcDesc0&, const SrcDesc1&, ...>
// SrcBuffers: Tuple<const SrcBuffer0&, const SrcBuffer1&, ...>
template
<
typename
SrcBuffers
,
enable_if_t
<
SrcDescs
::
Size
()
==
SrcBuffers
::
Size
(),
bool
>
=
false
>
__device__
void
RunRead
(
const
SrcDescs
&
src_descs
,
const
SrcBuffers
&
src_bufs
)
{
// loop over space-filling curve
static_for
<
0
,
num_access
,
1
>
{}([
&
](
auto
iAccess
)
{
auto
src_vectors
=
generate_vectors
<
SrcDatas
,
SrcScalarPerVector
>
();
auto
dst_vectors
=
generate_vectors
<
DstDatas
,
DstScalarPerVector
>
();
// copy data from src_bufs into src_vectors
static_for
<
0
,
nSrc
,
1
>
{}([
&
](
auto
i
)
{
using
src_vector_t
=
typename
remove_cvref_t
<
decltype
(
src_vectors
[
i
])
>::
type
;
const
bool
is_src_valid
=
coordinate_has_valid_offset_assuming_visible_index_is_valid
(
src_descs
[
i
],
src_coords_
[
i
]);
src_vectors
(
i
).
template
AsType
<
src_vector_t
>()(
I0
)
=
src_bufs
[
i
].
template
Get
<
src_vector_t
>(
src_coords_
[
i
].
GetOffset
(),
is_src_valid
);
});
constexpr
auto
get_elem_op_vec_len
=
[]()
{
if
constexpr
(
is_detected
<
is_pack8_invocable_t
,
decltype
(
element_op_
)
>::
value
)
{
if
constexpr
(
decltype
(
element_op_
)
::
is_pack8_invocable
)
return
math
::
min
(
8
,
SrcScalarPerVector
);
}
if
constexpr
(
is_detected
<
is_pack4_invocable_t
,
decltype
(
element_op_
)
>::
value
)
{
if
constexpr
(
decltype
(
element_op_
)
::
is_pack4_invocable
)
return
math
::
min
(
4
,
SrcScalarPerVector
);
}
if
constexpr
(
is_detected
<
is_pack2_invocable_t
,
decltype
(
element_op_
)
>::
value
)
{
if
constexpr
(
decltype
(
element_op_
)
::
is_pack2_invocable
)
return
math
::
min
(
2
,
SrcScalarPerVector
);
}
return
1
;
};
constexpr
index_t
elem_op_vec_len
=
get_elem_op_vec_len
();
// apply pointwise function
static_for
<
0
,
SrcScalarPerVector
/
elem_op_vec_len
,
1
>
{}([
&
](
auto
i
)
{
// get reference to src data
const
auto
src_data_refs
=
generate_tie
(
// return type should be lvalue
[
&
](
auto
iSrc
)
->
const
auto
&
{
using
SrcData
=
remove_cvref_t
<
tuple_element_t
<
iSrc
.
value
,
SrcDatas
>>
;
using
elem_op_vec_t
=
typename
vector_type
<
SrcData
,
elem_op_vec_len
>::
type
;
return
src_vectors
[
iSrc
].
template
AsType
<
elem_op_vec_t
>()[
i
];
},
Number
<
nSrc
>
{});
// get reference to dst data
auto
dst_data_refs
=
generate_tie
(
// return type should be lvalue
[
&
](
auto
iDst
)
->
auto
&
{
using
DstData
=
remove_cvref_t
<
tuple_element_t
<
iDst
.
value
,
DstDatas
>>
;
using
elem_op_vec_t
=
typename
vector_type
<
DstData
,
elem_op_vec_len
>::
type
;
return
dst_vectors
(
iDst
).
template
AsType
<
elem_op_vec_t
>()(
i
);
},
Number
<
nDst
>
{});
// apply pointwise function
// pointwise function signature:
// element_op_(dst_data_refs[I0],
// dst_data_refs[I1],
// ...,
// src_data_refs[I0],
// src_data_refs[I1],
// ...)
unpack2
(
element_op_
,
dst_data_refs
,
src_data_refs
);
});
dst_vectors_tuple_
(
iAccess
)
=
dst_vectors
;
// move coordinate
if
constexpr
(
iAccess
.
value
!=
num_access
-
1
)
{
constexpr
auto
forward_step
=
SrcSpaceFillingCurve
::
GetForwardStep
(
iAccess
);
static_for
<
0
,
nSrc
,
1
>
{}([
&
](
auto
i
)
{
move_tensor_coordinate
(
src_descs
[
i
],
src_coords_
(
i
),
make_tensor_coordinate_step
(
src_descs
[
i
],
forward_step
));
});
}
});
// move coordinate back to slice origin (or not)
static_for
<
0
,
nSrc
,
1
>
{}([
&
](
auto
i
)
{
if
constexpr
(
SrcResetCoordinateAfterRunFlags
::
At
(
i
))
{
const
auto
src_reset_step
=
make_tensor_coordinate_step
(
src_descs
[
i
],
GetSrcCoordinateResetStep
());
move_tensor_coordinate
(
src_descs
[
i
],
src_coords_
(
i
),
src_reset_step
);
}
});
}
// DstDescs: Tuple<const DstDesc0&, const DstDesc1&, ...>
// DstBuffers: Tuple<const DstBuffer0&, const DstBuffer1&, ...>
template
<
typename
DstBuffers
,
enable_if_t
<
DstDescs
::
Size
()
==
DstBuffers
::
Size
(),
bool
>
=
false
>
__device__
void
RunWrite
(
const
DstDescs
&
dst_descs
,
DstBuffers
dst_bufs
)
{
// loop over space-filling curve
static_for
<
0
,
num_access
,
1
>
{}([
&
](
auto
iAccess
)
{
auto
dst_vectors
=
dst_vectors_tuple_
[
iAccess
];
// copy data from buf_vectors into dst_bufs
static_for
<
0
,
nDst
,
1
>
{}([
&
](
auto
i
)
{
using
dst_vector_t
=
typename
remove_cvref_t
<
decltype
(
dst_vectors
[
i
])
>::
type
;
const
bool
is_dst_valid
=
coordinate_has_valid_offset_assuming_visible_index_is_valid
(
dst_descs
[
i
],
dst_coords_
[
i
]);
constexpr
InMemoryDataOperationEnum
DstInMemOp
=
static_cast
<
InMemoryDataOperationEnum
>
(
DstInMemOps
::
At
(
i
.
value
));
dst_bufs
(
i
).
template
Update
<
DstInMemOp
,
dst_vector_t
>(
dst_coords_
[
i
].
GetOffset
(),
is_dst_valid
,
dst_vectors
[
i
].
template
AsType
<
dst_vector_t
>()[
I0
]);
});
// move coordinate
if
constexpr
(
iAccess
.
value
!=
num_access
-
1
)
{
constexpr
auto
forward_step
=
DstSpaceFillingCurve
::
GetForwardStep
(
iAccess
);
static_for
<
0
,
nDst
,
1
>
{}([
&
](
auto
i
)
{
move_tensor_coordinate
(
dst_descs
[
i
],
dst_coords_
(
i
),
make_tensor_coordinate_step
(
dst_descs
[
i
],
forward_step
));
});
}
});
static_for
<
0
,
nDst
,
1
>
{}([
&
](
auto
i
)
{
if
constexpr
(
DstResetCoordinateAfterRunFlags
::
At
(
i
))
{
const
auto
dst_reset_step
=
make_tensor_coordinate_step
(
dst_descs
[
i
],
GetDstCoordinateResetStep
());
move_tensor_coordinate
(
dst_descs
[
i
],
dst_coords_
(
i
),
dst_reset_step
);
}
});
}
// SrcDescs: Tuple<const SrcDesc0&, const SrcDesc1&, ...>
// SrcBuffers: Tuple<const SrcBuffer0&, const SrcBuffer1&, ...>
// DstDescs: Tuple<const DstDesc0&, const DstDesc1&, ...>
// DstBuffers: Tuple<const DstBuffer0&, const DstBuffer1&, ...>
template
<
typename
SrcBuffers
,
typename
DstBuffers
,
enable_if_t
<
SrcDescs
::
Size
()
==
SrcBuffers
::
Size
()
&&
DstDescs
::
Size
()
==
DstBuffers
::
Size
(),
bool
>
=
false
>
__device__
void
Run
(
const
SrcDescs
&
src_descs
,
const
SrcBuffers
&
src_bufs
,
const
DstDescs
&
dst_descs
,
DstBuffers
dst_bufs
)
{
RunRead
(
src_descs
,
src_bufs
);
RunWrite
(
dst_descs
,
dst_bufs
);
}
__device__
static
constexpr
auto
GetSrcCoordinateResetStep
()
{
if
constexpr
(
num_access
==
0
)
{
return
typename
SrcSpaceFillingCurve
::
Index
{};
}
else
{
return
SrcSpaceFillingCurve
::
GetStepBetween
(
Number
<
num_access
-
1
>
{},
Number
<
0
>
{});
}
}
__device__
static
constexpr
auto
GetDstCoordinateResetStep
()
{
if
constexpr
(
num_access
==
0
)
{
return
typename
DstSpaceFillingCurve
::
Index
{};
}
else
{
return
DstSpaceFillingCurve
::
GetStepBetween
(
Number
<
num_access
-
1
>
{},
Number
<
0
>
{});
}
}
// src_slice_origin_step_idx need to be known at compile-time, for performance reason
template
<
index_t
ISrc
>
__device__
void
MoveSrcSliceWindow
(
const
SrcDescs
&
src_descs
,
Number
<
ISrc
>
iSrc
,
const
Index
&
src_slice_origin_step_idx
)
{
// if src coord was not reset by RunRead(), then need to adjust the step here
const
auto
adjusted_step_idx
=
SrcResetCoordinateAfterRunFlags
::
At
(
iSrc
)
?
src_slice_origin_step_idx
:
src_slice_origin_step_idx
+
GetSrcCoordinateResetStep
();
// is it OK to construct a new step every time?
const
auto
adjusted_step
=
make_tensor_coordinate_step
(
src_descs
[
iSrc
],
adjusted_step_idx
);
move_tensor_coordinate
(
src_descs
[
iSrc
],
src_coords_
(
iSrc
),
adjusted_step
);
}
// dst_slice_origin_step_idx need to be known at compile-time, for performance reason
template
<
index_t
IDst
>
__device__
void
MoveDstSliceWindow
(
const
DstDescs
&
dst_descs
,
Number
<
IDst
>
iDst
,
const
Index
&
dst_slice_origin_step_idx
)
{
// if dst coord was not reset by Run(), then need to adjust the step here
const
auto
adjusted_step_idx
=
DstResetCoordinateAfterRunFlags
::
At
(
iDst
)
?
dst_slice_origin_step_idx
:
dst_slice_origin_step_idx
+
GetDstCoordinateResetStep
();
// is it OK to construct a new step every time?
const
auto
adjusted_step
=
make_tensor_coordinate_step
(
dst_descs
[
iDst
],
adjusted_step_idx
);
move_tensor_coordinate
(
dst_descs
[
iDst
],
dst_coords_
(
iDst
),
adjusted_step
);
}
private:
using
SrcVectorsType
=
decltype
(
generate_vectors
<
SrcDatas
,
SrcScalarPerVector
>
());
using
DstVectorsType
=
decltype
(
generate_vectors
<
DstDatas
,
DstScalarPerVector
>
());
static
constexpr
auto
num_access
=
SrcSpaceFillingCurve
::
GetNumOfAccess
();
StaticallyIndexedArray
<
DstVectorsType
,
num_access
>
dst_vectors_tuple_
;
SrcCoords
src_coords_
;
DstCoords
dst_coords_
;
const
ElementwiseOperation
element_op_
;
};
}
// namespace ck
include/ck/tensor_operation/gpu/warp/dpp_gemm.hpp
0 → 100644
View file @
3c4fb1dd
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/amd_gemm_dpp.hpp"
#include "ck/utility/common_header.hpp"
#include "ck/utility/math.hpp"
namespace
ck
{
// Identifiers of the DPP8 f16 GEMM variants, named dpp8_f16_<M>x<N>x<K>.
enum class DppInstr
{
    dpp8_f16_1x32x2  = 0,
    dpp8_f16_2x16x2  = 1,
    dpp8_f16_2x32x2  = 2,
    dpp8_f16_4x16x2  = 3,
    dpp8_f16_4x32x2  = 4,
    dpp8_f16_8x16x2  = 5,
    dpp8_f16_8x32x2  = 6,
    dpp8_f16_16x16x2 = 7,
    dpp8_f16_32x8x2  = 8
};
/**
* Structure representing DPP GEMM executed by a single wavefront.
*
* Each structure instantiation must contain the following fields:
* - wave_size - number of threads that execute a single DPP GEMM operation, usually equal to the
* number of threads in a wavefront;
* - lanegroup_size - number of threads (lanes) that share data using DPP instruction modifier,
* it's 8 in case of DPP8;
* - m_per_wave - size along M dimension of matrix C that is processed in a single DPP GEMM
* operation;
* - n_per_wave - size along N dimension of matrix C that is processed in a single DPP GEMM
* operation;
* - m_per_lanegroup - size along M dimension that is processed by a single lanegroup;
* - n_per_lanegroup - size along N dimension that is processed by a single lanegroup;
* - m_per_thread - size along M dimension of the tile calculated by a single thread;
* - n_per_thread - size along N dimension of the tile calculated by a single thread;
* - k_per_dpp - size along K dimension that is reduced in a single DPP GEMM operation;
* - share_a - indicates whether we share matrix A or matrix B between lanes using DPP modifiers.
*
* Not all the combinations are supported now; for current restrictions see the static asserts
* in the DppSelector's constructor.
*/
template
<
DppInstr
instr
>
struct
dpp_type
;
// DPP8 f16 GEMM description: a 32x8 tile of C per wave, K reduced by 2 per operation.
template <>
struct dpp_type<DppInstr::dpp8_f16_32x8x2>
{
    // threads per DPP GEMM operation / lanes sharing data through the DPP modifier
    static constexpr index_t wave_size      = 32;
    static constexpr index_t lanegroup_size = 8;

    // tile of C handled per wave, per lanegroup and per thread
    static constexpr index_t m_per_wave      = 32;
    static constexpr index_t n_per_wave      = 8;
    static constexpr index_t m_per_lanegroup = 8;
    static constexpr index_t n_per_lanegroup = 8;
    static constexpr index_t m_per_thread    = 8;
    static constexpr index_t n_per_thread    = 1;

    // K extent reduced by a single DPP GEMM operation
    static constexpr index_t k_per_dpp = 2;

    // matrix A is the operand shared between lanes
    static constexpr bool share_a = true;

    using BaseType = half_t;

    // Forwards to the lanegroup GEMM configured by the constants above. MPerDpp and
    // NPerDpp are accepted but not used in the body.
    template <index_t MPerDpp, index_t NPerDpp, class ADataType, class BDataType, class CDataType>
    __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const
    {
        using LanegroupGemm = dpp8::DppLanegroupGemm<m_per_thread,
                                                     n_per_thread,
                                                     k_per_dpp,
                                                     BaseType,
                                                     ADataType,
                                                     BDataType,
                                                     CDataType,
                                                     share_a>;

        LanegroupGemm{}.Run(a, b, reg_c);
    }
};
// DPP8 f16 GEMM description: an 8x32 tile of C per wave, K reduced by 2 per operation.
template <>
struct dpp_type<DppInstr::dpp8_f16_8x32x2>
{
    // threads per DPP GEMM operation / lanes sharing data through the DPP modifier
    static constexpr index_t wave_size      = 32;
    static constexpr index_t lanegroup_size = 8;

    // tile of C handled per wave, per lanegroup and per thread
    static constexpr index_t m_per_wave      = 8;
    static constexpr index_t n_per_wave      = 32;
    static constexpr index_t m_per_lanegroup = 8;
    static constexpr index_t n_per_lanegroup = 8;
    static constexpr index_t m_per_thread    = 8;
    static constexpr index_t n_per_thread    = 1;

    // K extent reduced by a single DPP GEMM operation
    static constexpr index_t k_per_dpp = 2;

    // matrix A is the operand shared between lanes
    static constexpr bool share_a = true;

    using BaseType = half_t;

    // Forwards to the lanegroup GEMM configured by the constants above. MPerDpp and
    // NPerDpp are accepted but not used in the body.
    template <index_t MPerDpp, index_t NPerDpp, class ADataType, class BDataType, class CDataType>
    __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const
    {
        using LanegroupGemm = dpp8::DppLanegroupGemm<m_per_thread,
                                                     n_per_thread,
                                                     k_per_dpp,
                                                     BaseType,
                                                     ADataType,
                                                     BDataType,
                                                     CDataType,
                                                     share_a>;

        LanegroupGemm{}.Run(a, b, reg_c);
    }
};
// Parameters for DppInstr::dpp8_f16_8x16x2:
// 8x16 (M x N) per wave, 4x8 per lanegroup, 4x1 per thread, K reduced by 2 per DPP operation.
template <>
struct dpp_type<DppInstr::dpp8_f16_8x16x2>
{
    static constexpr index_t wave_size       = 32;
    static constexpr index_t lanegroup_size  = 8;
    static constexpr index_t m_per_wave      = 8;
    static constexpr index_t n_per_wave      = 16;
    static constexpr index_t m_per_lanegroup = 4;
    static constexpr index_t n_per_lanegroup = 8;
    static constexpr index_t m_per_thread    = 4;
    static constexpr index_t n_per_thread    = 1;
    static constexpr index_t k_per_dpp       = 2;
    static constexpr bool share_a            = true;
    using BaseType                           = half_t;

    // Forwards a, b and reg_c to dpp8::DppLanegroupGemm configured with this descriptor's
    // per-thread tile, k_per_dpp, BaseType and share_a. MPerDpp / NPerDpp are not used by
    // this body.
    template <index_t MPerDpp,
              index_t NPerDpp,
              typename ADataType,
              typename BDataType,
              typename CDataType>
    __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const
    {
        using LanegroupGemm = dpp8::DppLanegroupGemm<m_per_thread,
                                                     n_per_thread,
                                                     k_per_dpp,
                                                     BaseType,
                                                     ADataType,
                                                     BDataType,
                                                     CDataType,
                                                     share_a>;
        LanegroupGemm{}.Run(a, b, reg_c);
    }
};
// Parameters for DppInstr::dpp8_f16_16x16x2:
// 16x16 (M x N) per wave, 8x8 per lanegroup, 8x1 per thread, K reduced by 2 per DPP operation.
template <>
struct dpp_type<DppInstr::dpp8_f16_16x16x2>
{
    static constexpr index_t wave_size       = 32;
    static constexpr index_t lanegroup_size  = 8;
    static constexpr index_t m_per_wave      = 16;
    static constexpr index_t n_per_wave      = 16;
    static constexpr index_t m_per_lanegroup = 8;
    static constexpr index_t n_per_lanegroup = 8;
    static constexpr index_t m_per_thread    = 8;
    static constexpr index_t n_per_thread    = 1;
    static constexpr index_t k_per_dpp       = 2;
    static constexpr bool share_a            = true;
    using BaseType                           = half_t;

    // Forwards a, b and reg_c to dpp8::DppLanegroupGemm configured with this descriptor's
    // per-thread tile, k_per_dpp, BaseType and share_a. MPerDpp / NPerDpp are not used by
    // this body.
    template <index_t MPerDpp,
              index_t NPerDpp,
              typename ADataType,
              typename BDataType,
              typename CDataType>
    __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const
    {
        using LanegroupGemm = dpp8::DppLanegroupGemm<m_per_thread,
                                                     n_per_thread,
                                                     k_per_dpp,
                                                     BaseType,
                                                     ADataType,
                                                     BDataType,
                                                     CDataType,
                                                     share_a>;
        LanegroupGemm{}.Run(a, b, reg_c);
    }
};
// Parameters for DppInstr::dpp8_f16_4x32x2:
// 4x32 (M x N) per wave, 4x8 per lanegroup, 4x1 per thread, K reduced by 2 per DPP operation.
template <>
struct dpp_type<DppInstr::dpp8_f16_4x32x2>
{
    static constexpr index_t wave_size       = 32;
    static constexpr index_t lanegroup_size  = 8;
    static constexpr index_t m_per_wave      = 4;
    static constexpr index_t n_per_wave      = 32;
    static constexpr index_t m_per_lanegroup = 4;
    static constexpr index_t n_per_lanegroup = 8;
    static constexpr index_t m_per_thread    = 4;
    static constexpr index_t n_per_thread    = 1;
    static constexpr index_t k_per_dpp       = 2;
    static constexpr bool share_a            = true;
    using BaseType                           = half_t;

    // Forwards a, b and reg_c to dpp8::DppLanegroupGemm configured with this descriptor's
    // per-thread tile, k_per_dpp, BaseType and share_a. MPerDpp / NPerDpp are not used by
    // this body.
    template <index_t MPerDpp,
              index_t NPerDpp,
              typename ADataType,
              typename BDataType,
              typename CDataType>
    __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const
    {
        using LanegroupGemm = dpp8::DppLanegroupGemm<m_per_thread,
                                                     n_per_thread,
                                                     k_per_dpp,
                                                     BaseType,
                                                     ADataType,
                                                     BDataType,
                                                     CDataType,
                                                     share_a>;
        LanegroupGemm{}.Run(a, b, reg_c);
    }
};
// Parameters for DppInstr::dpp8_f16_4x16x2:
// 4x16 (M x N) per wave, 2x8 per lanegroup, 2x1 per thread, K reduced by 2 per DPP operation.
template <>
struct dpp_type<DppInstr::dpp8_f16_4x16x2>
{
    static constexpr index_t wave_size       = 32;
    static constexpr index_t lanegroup_size  = 8;
    static constexpr index_t m_per_wave      = 4;
    static constexpr index_t n_per_wave      = 16;
    static constexpr index_t m_per_lanegroup = 2;
    static constexpr index_t n_per_lanegroup = 8;
    static constexpr index_t m_per_thread    = 2;
    static constexpr index_t n_per_thread    = 1;
    static constexpr index_t k_per_dpp       = 2;
    static constexpr bool share_a            = true;
    using BaseType                           = half_t;

    // Forwards a, b and reg_c to dpp8::DppLanegroupGemm configured with this descriptor's
    // per-thread tile, k_per_dpp, BaseType and share_a. MPerDpp / NPerDpp are not used by
    // this body.
    template <index_t MPerDpp,
              index_t NPerDpp,
              typename ADataType,
              typename BDataType,
              typename CDataType>
    __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const
    {
        using LanegroupGemm = dpp8::DppLanegroupGemm<m_per_thread,
                                                     n_per_thread,
                                                     k_per_dpp,
                                                     BaseType,
                                                     ADataType,
                                                     BDataType,
                                                     CDataType,
                                                     share_a>;
        LanegroupGemm{}.Run(a, b, reg_c);
    }
};
// Parameters for DppInstr::dpp8_f16_1x32x2:
// 1x32 (M x N) per wave, 1x8 per lanegroup, 1x1 per thread, K reduced by 2 per DPP operation.
template <>
struct dpp_type<DppInstr::dpp8_f16_1x32x2>
{
    static constexpr index_t wave_size       = 32;
    static constexpr index_t lanegroup_size  = 8;
    static constexpr index_t m_per_wave      = 1;
    static constexpr index_t n_per_wave      = 32;
    static constexpr index_t m_per_lanegroup = 1;
    static constexpr index_t n_per_lanegroup = 8;
    static constexpr index_t m_per_thread    = 1;
    static constexpr index_t n_per_thread    = 1;
    static constexpr index_t k_per_dpp       = 2;
    static constexpr bool share_a            = true;
    using BaseType                           = half_t;

    // Forwards a, b and reg_c to dpp8::DppLanegroupGemm configured with this descriptor's
    // per-thread tile, k_per_dpp, BaseType and share_a. MPerDpp / NPerDpp are not used by
    // this body.
    template <index_t MPerDpp,
              index_t NPerDpp,
              typename ADataType,
              typename BDataType,
              typename CDataType>
    __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const
    {
        using LanegroupGemm = dpp8::DppLanegroupGemm<m_per_thread,
                                                     n_per_thread,
                                                     k_per_dpp,
                                                     BaseType,
                                                     ADataType,
                                                     BDataType,
                                                     CDataType,
                                                     share_a>;
        LanegroupGemm{}.Run(a, b, reg_c);
    }
};
// Parameters for DppInstr::dpp8_f16_2x32x2:
// 2x32 (M x N) per wave, 2x8 per lanegroup, 2x1 per thread, K reduced by 2 per DPP operation.
template <>
struct dpp_type<DppInstr::dpp8_f16_2x32x2>
{
    static constexpr index_t wave_size       = 32;
    static constexpr index_t lanegroup_size  = 8;
    static constexpr index_t m_per_wave      = 2;
    static constexpr index_t n_per_wave      = 32;
    static constexpr index_t m_per_lanegroup = 2;
    static constexpr index_t n_per_lanegroup = 8;
    static constexpr index_t m_per_thread    = 2;
    static constexpr index_t n_per_thread    = 1;
    static constexpr index_t k_per_dpp       = 2;
    static constexpr bool share_a            = true;
    using BaseType                           = half_t;

    // Forwards a, b and reg_c to dpp8::DppLanegroupGemm configured with this descriptor's
    // per-thread tile, k_per_dpp, BaseType and share_a. MPerDpp / NPerDpp are not used by
    // this body.
    template <index_t MPerDpp,
              index_t NPerDpp,
              typename ADataType,
              typename BDataType,
              typename CDataType>
    __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const
    {
        using LanegroupGemm = dpp8::DppLanegroupGemm<m_per_thread,
                                                     n_per_thread,
                                                     k_per_dpp,
                                                     BaseType,
                                                     ADataType,
                                                     BDataType,
                                                     CDataType,
                                                     share_a>;
        LanegroupGemm{}.Run(a, b, reg_c);
    }
};
// Parameters for DppInstr::dpp8_f16_2x16x2:
// 2x16 (M x N) per wave, 1x8 per lanegroup, 1x1 per thread, K reduced by 2 per DPP operation.
template <>
struct dpp_type<DppInstr::dpp8_f16_2x16x2>
{
    static constexpr index_t wave_size       = 32;
    static constexpr index_t lanegroup_size  = 8;
    static constexpr index_t m_per_wave      = 2;
    static constexpr index_t n_per_wave      = 16;
    static constexpr index_t m_per_lanegroup = 1;
    static constexpr index_t n_per_lanegroup = 8;
    static constexpr index_t m_per_thread    = 1;
    static constexpr index_t n_per_thread    = 1;
    static constexpr index_t k_per_dpp       = 2;
    static constexpr bool share_a            = true;
    using BaseType                           = half_t;

    // Forwards a, b and reg_c to dpp8::DppLanegroupGemm configured with this descriptor's
    // per-thread tile, k_per_dpp, BaseType and share_a. MPerDpp / NPerDpp are not used by
    // this body.
    template <index_t MPerDpp,
              index_t NPerDpp,
              typename ADataType,
              typename BDataType,
              typename CDataType>
    __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const
    {
        using LanegroupGemm = dpp8::DppLanegroupGemm<m_per_thread,
                                                     n_per_thread,
                                                     k_per_dpp,
                                                     BaseType,
                                                     ADataType,
                                                     BDataType,
                                                     CDataType,
                                                     share_a>;
        LanegroupGemm{}.Run(a, b, reg_c);
    }
};
template
<
typename
BaseType
,
index_t
MPerDpp
,
index_t
NPerDpp
>
struct
DppSelector
{
template
<
typename
BaseType_
,
index_t
MPerDpp_
,
index_t
NPerDpp_
>
static
constexpr
auto
GetDpp
();
template
<
>
static
constexpr
auto
GetDpp
<
half_t
,
8
,
32
>
()
{
return
DppInstr
::
dpp8_f16_8x32x2
;
}
template
<
>
static
constexpr
auto
GetDpp
<
half_t
,
8
,
16
>
()
{
return
DppInstr
::
dpp8_f16_8x16x2
;
}
template
<
>
static
constexpr
auto
GetDpp
<
half_t
,
16
,
16
>
()
{
return
DppInstr
::
dpp8_f16_16x16x2
;
}
template
<
>
static
constexpr
auto
GetDpp
<
half_t
,
32
,
8
>
()
{
return
DppInstr
::
dpp8_f16_32x8x2
;
}
template
<
>
static
constexpr
auto
GetDpp
<
half_t
,
1
,
32
>
()
{
return
DppInstr
::
dpp8_f16_1x32x2
;
}
template
<
>
static
constexpr
auto
GetDpp
<
half_t
,
2
,
32
>
()
{
return
DppInstr
::
dpp8_f16_2x32x2
;
}
template
<
>
static
constexpr
auto
GetDpp
<
half_t
,
2
,
16
>
()
{
return
DppInstr
::
dpp8_f16_2x16x2
;
}
template
<
>
static
constexpr
auto
GetDpp
<
half_t
,
4
,
16
>
()
{
return
DppInstr
::
dpp8_f16_4x16x2
;
}
template
<
>
static
constexpr
auto
GetDpp
<
half_t
,
4
,
32
>
()
{
return
DppInstr
::
dpp8_f16_4x32x2
;
}
static
constexpr
auto
selected_dpp
=
dpp_type
<
GetDpp
<
BaseType
,
MPerDpp
,
NPerDpp
>
()
>
{};
__host__
__device__
constexpr
DppSelector
()
{
static_assert
(
selected_dpp
.
m_per_wave
%
selected_dpp
.
m_per_lanegroup
==
0
);
static_assert
(
selected_dpp
.
n_per_wave
%
selected_dpp
.
n_per_lanegroup
==
0
);
static_assert
(
selected_dpp
.
k_per_dpp
%
2
==
0
);
static_assert
(
selected_dpp
.
wave_size
%
selected_dpp
.
lanegroup_size
==
0
);
constexpr
index_t
num_dpp_per_wave
=
selected_dpp
.
wave_size
/
selected_dpp
.
lanegroup_size
;
constexpr
index_t
num_wave_c_elems
=
selected_dpp
.
m_per_wave
*
selected_dpp
.
n_per_wave
;
constexpr
index_t
num_dpp_c_elems
=
selected_dpp
.
m_per_lanegroup
*
selected_dpp
.
n_per_lanegroup
;
static_assert
(
num_wave_c_elems
%
num_dpp_c_elems
==
0
);
static_assert
(
num_dpp_per_wave
==
num_wave_c_elems
/
num_dpp_c_elems
);
if
constexpr
(
selected_dpp
.
share_a
)
{
static_assert
(
selected_dpp
.
m_per_lanegroup
==
selected_dpp
.
m_per_thread
);
static_assert
(
selected_dpp
.
n_per_lanegroup
%
selected_dpp
.
n_per_thread
==
0
);
static_assert
(
selected_dpp
.
n_per_lanegroup
/
selected_dpp
.
n_per_thread
==
selected_dpp
.
lanegroup_size
);
}
else
{
static_assert
(
selected_dpp
.
m_per_lanegroup
%
selected_dpp
.
n_per_thread
==
0
);
static_assert
(
selected_dpp
.
m_per_lanegroup
/
selected_dpp
.
n_per_thread
==
selected_dpp
.
lanegroup_size
);
static_assert
(
selected_dpp
.
n_per_lanegroup
==
selected_dpp
.
n_per_thread
);
}
// Below checks come from the restrictions of the current implementation, could be removed
// in the future when the implementation is more generalized.
static_assert
(
selected_dpp
.
share_a
);
static_assert
(
selected_dpp
.
n_per_thread
==
1
);
static_assert
(
selected_dpp
.
m_per_lanegroup
==
selected_dpp
.
m_per_thread
);
static_assert
(
selected_dpp
.
n_per_lanegroup
==
selected_dpp
.
n_per_thread
*
selected_dpp
.
lanegroup_size
);
}
static
constexpr
index_t
GetK1PerDpp
()
{
return
selected_dpp
.
k_per_dpp
;
}
};
/**
 * Wave-level GEMM wrapper around the DPP instruction selected by
 * DppSelector<BaseType, MPerDpp, NPerDpp>.
 *
 * Provides the Run entry point, which issues KPack / k_per_dpp instruction steps, plus helpers
 * that derive lane / lanegroup ids and the per-thread origin indices of A, B and C from
 * get_thread_local_1d_id().
 */
template <typename BaseType, index_t MPerDpp, index_t NPerDpp, index_t KPack>
struct DppGemm
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};
    static constexpr auto I4 = Number<4>{};
    static constexpr auto I5 = Number<5>{};

    using CIndex   = MultiIndex<2>;
    using CIndex4D = MultiIndex<4>;

    // Only checks that KPack splits evenly into DPP steps.
    __host__ __device__ constexpr DppGemm()
    {
        static_assert(KPack % dpp_instr.k_per_dpp == 0, "KPack must be divisible by k_per_dpp.");
    }

    // Returns MPerDpp * NPerDpp / wave_size.
    __device__ static constexpr index_t GetRegSizePerDpp()
    {
        return MPerDpp * NPerDpp / dpp_instr.wave_size;
    }

    // Calls dpp_instr.run once for every k in [0, KPack / k_per_dpp), passing p_a_wave[k],
    // p_b_wave[k] and the same p_c_thread each time.
    template <class ADataType, class BDataType, class CDataType>
    __device__ void
    Run(const ADataType& p_a_wave, const BDataType& p_b_wave, CDataType& p_c_thread) const
    {
        // The message now lists f8_t as well, matching the types the condition accepts.
        static_assert(is_same<BaseType, double>::value || is_same<BaseType, float>::value ||
                          is_same<BaseType, half_t>::value || is_same<BaseType, bhalf_t>::value ||
                          is_same<BaseType, int8_t>::value || is_same<BaseType, f8_t>::value,
                      "base BaseType must be double, float, half, bfloat16, int8_t or f8_t!");

        static_for<0, KPack / dpp_instr.k_per_dpp, 1>{}([&](auto k) {
            dpp_instr.template run<MPerDpp, NPerDpp>(p_a_wave[k], p_b_wave[k], p_c_thread);
        });
    }

    // Thread id modulo wave_size.
    __device__ static auto GetLaneIdInWave()
    {
        return get_thread_local_1d_id() % dpp_instr.wave_size;
    }

    // Thread id divided by wave_size.
    __device__ static auto GetWaveId() { return get_thread_local_1d_id() / dpp_instr.wave_size; }

    // Thread id modulo lanegroup_size.
    __device__ static auto GetLaneIdInLaneGroup()
    {
        return get_thread_local_1d_id() % dpp_instr.lanegroup_size;
    }

    // Lane id within the wave divided by lanegroup_size.
    __device__ static auto GetLaneGroupIdInWave()
    {
        return GetLaneIdInWave() / dpp_instr.lanegroup_size;
    }

    // Splits the 1-D lanegroup id into an (m, n) pair through a merge transform whose lengths
    // are (m_per_wave / m_per_lanegroup, n_per_wave / n_per_lanegroup).
    __device__ static auto GetDppOpIdx()
    {
        const auto lanegroupId = GetLaneGroupIdInWave();

        constexpr auto lanegroup_idx_1d_to_dpp_idx_2d_adaptor = make_single_stage_tensor_adaptor(
            make_tuple(
                make_merge_transform(make_tuple(dpp_instr.m_per_wave / dpp_instr.m_per_lanegroup,
                                                dpp_instr.n_per_wave / dpp_instr.n_per_lanegroup))),
            make_tuple(Sequence<0, 1>{}),
            make_tuple(Sequence<0>{}));

        const auto dpp_idx = lanegroup_idx_1d_to_dpp_idx_2d_adaptor.CalculateBottomIndex(
            make_multi_index(lanegroupId));

        const auto m_dpp_idx = dpp_idx[I0];
        const auto n_dpp_idx = dpp_idx[I1];

        return make_tuple(m_dpp_idx, n_dpp_idx);
    }

    // Returns (0, m) with
    // m = (m_per_thread * (thread_id / n_per_wave) + lane_in_lanegroup) % m_per_wave.
    __host__ __device__ static auto CalculateAThreadOriginDataIndex_K_M()
    {
        const auto laneId   = get_thread_local_1d_id();
        const auto wave_row = laneId / dpp_instr.n_per_wave;
        auto m_idx          = dpp_instr.m_per_thread * wave_row + GetLaneIdInLaneGroup();
        return make_tuple(0, m_idx % dpp_instr.m_per_wave);
    }

    // Returns (0, thread_id % n_per_wave).
    __host__ __device__ static auto CalculateBThreadOriginDataIndex_K_N()
    {
        const auto laneId = get_thread_local_1d_id();
        return make_tuple(0, laneId % dpp_instr.n_per_wave);
    }

    // Returns {m_offset, n_offset}: the lanegroup's (m, n) index from GetDppOpIdx scaled by
    // the lanegroup tile size, with the lane's id inside the lanegroup added to n only.
    __device__ static CIndex GetBeginOfThreadBlk()
    {
        const auto dpp_op_idx = GetDppOpIdx();

        const auto m_dpp_op_idx = dpp_op_idx[I0];
        const auto n_dpp_op_idx = dpp_op_idx[I1];

        index_t n_offset = n_dpp_op_idx * dpp_instr.n_per_lanegroup + GetLaneIdInLaneGroup();
        index_t m_offset = m_dpp_op_idx * dpp_instr.m_per_lanegroup;

        return CIndex{m_offset, n_offset};
    }

    // Selector instance and the descriptor it chose. Member-function bodies above may refer to
    // these because they are compiled in complete-class context.
    static constexpr auto dpp = DppSelector<BaseType, MPerDpp, NPerDpp>{};

    static constexpr auto dpp_instr = dpp.selected_dpp;

    static constexpr auto K0PerDpp = 1;
    static constexpr auto K1PerDpp = dpp.GetK1PerDpp();

    // Returns (Number<m_per_thread>, Number<n_per_thread>).
    __host__ __device__ static constexpr auto GetCMNThreadBlkLengths()
    {
        return make_tuple(Number<dpp_instr.m_per_thread>{}, Number<dpp_instr.n_per_thread>{});
    }
};
}
// namespace ck
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp
View file @
3c4fb1dd
...
...
@@ -31,7 +31,13 @@ enum struct MfmaInstr
mfma_i32_16x16x32i8
,
mfma_f64_16x16x4f64
,
mfma_f32_32x32x16f8f8
,
mfma_f32_16x16x32f8f8
mfma_f32_16x16x32f8f8
,
mfma_f32_32x32x16bf8bf8
,
mfma_f32_16x16x32bf8bf8
,
mfma_f32_32x32x16f8bf8
,
mfma_f32_16x16x32f8bf8
,
mfma_f32_32x32x16bf8f8
,
mfma_f32_16x16x32bf8f8
};
template
<
MfmaInstr
instr
>
...
...
@@ -500,10 +506,148 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x32f8f8>
}
};
template
<
typename
base_type
,
index_t
MPerXdlops
,
index_t
NPerXdlops
>
// Parameters for MfmaInstr::mfma_f32_32x32x16bf8bf8; run() dispatches to the matching intrinsic
// wrapper.
template <>
struct mfma_type<MfmaInstr::mfma_f32_32x32x16bf8bf8>
{
    static constexpr index_t group_size          = 4;
    static constexpr index_t num_groups_per_blk  = 4;
    static constexpr index_t num_regs_per_blk    = 16;
    static constexpr index_t num_threads_per_blk = 32;
    static constexpr index_t wave_size           = 64;
    static constexpr index_t num_input_blks      = 2;
    static constexpr index_t num_output_blks     = 1;
    static constexpr index_t m_per_blk           = 32;
    static constexpr index_t n_per_blk           = 32;
    static constexpr index_t k_per_blk           = 8;
    static constexpr bool is_k_reduction         = true;

    // Forwards a, b and reg_c to intrin_mfma_f32_32x32x16bf8bf8<MPerXdlops, NPerXdlops>::Run.
    template <index_t MPerXdlops,
              index_t NPerXdlops,
              typename FloatA,
              typename FloatB,
              typename FloatC>
    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
    {
        using Intrinsic = intrin_mfma_f32_32x32x16bf8bf8<MPerXdlops, NPerXdlops>;
        Intrinsic::Run(a, b, reg_c);
    }
};
// Parameters for MfmaInstr::mfma_f32_16x16x32bf8bf8; run() dispatches to the matching intrinsic
// wrapper.
template <>
struct mfma_type<MfmaInstr::mfma_f32_16x16x32bf8bf8>
{
    static constexpr index_t group_size          = 4;
    static constexpr index_t num_groups_per_blk  = 1;
    static constexpr index_t num_regs_per_blk    = 4;
    static constexpr index_t num_threads_per_blk = 16;
    static constexpr index_t wave_size           = 64;
    static constexpr index_t num_input_blks      = 4;
    static constexpr index_t num_output_blks     = 1;
    static constexpr index_t m_per_blk           = 16;
    static constexpr index_t n_per_blk           = 16;
    static constexpr index_t k_per_blk           = 8;
    static constexpr bool is_k_reduction         = true;

    // Forwards a, b and reg_c to intrin_mfma_f32_16x16x32bf8bf8<MPerXdlops, NPerXdlops>::Run.
    template <index_t MPerXdlops,
              index_t NPerXdlops,
              typename FloatA,
              typename FloatB,
              typename FloatC>
    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
    {
        using Intrinsic = intrin_mfma_f32_16x16x32bf8bf8<MPerXdlops, NPerXdlops>;
        Intrinsic::Run(a, b, reg_c);
    }
};
// Parameters for MfmaInstr::mfma_f32_32x32x16f8bf8; run() dispatches to the matching intrinsic
// wrapper.
template <>
struct mfma_type<MfmaInstr::mfma_f32_32x32x16f8bf8>
{
    static constexpr index_t group_size          = 4;
    static constexpr index_t num_groups_per_blk  = 4;
    static constexpr index_t num_regs_per_blk    = 16;
    static constexpr index_t num_threads_per_blk = 32;
    static constexpr index_t wave_size           = 64;
    static constexpr index_t num_input_blks      = 2;
    static constexpr index_t num_output_blks     = 1;
    static constexpr index_t m_per_blk           = 32;
    static constexpr index_t n_per_blk           = 32;
    static constexpr index_t k_per_blk           = 8;
    static constexpr bool is_k_reduction         = true;

    // Forwards a, b and reg_c to intrin_mfma_f32_32x32x16f8bf8<MPerXdlops, NPerXdlops>::Run.
    template <index_t MPerXdlops,
              index_t NPerXdlops,
              typename FloatA,
              typename FloatB,
              typename FloatC>
    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
    {
        using Intrinsic = intrin_mfma_f32_32x32x16f8bf8<MPerXdlops, NPerXdlops>;
        Intrinsic::Run(a, b, reg_c);
    }
};
// Parameters for MfmaInstr::mfma_f32_16x16x32f8bf8; run() dispatches to the matching intrinsic
// wrapper.
template <>
struct mfma_type<MfmaInstr::mfma_f32_16x16x32f8bf8>
{
    static constexpr index_t group_size          = 4;
    static constexpr index_t num_groups_per_blk  = 1;
    static constexpr index_t num_regs_per_blk    = 4;
    static constexpr index_t num_threads_per_blk = 16;
    static constexpr index_t wave_size           = 64;
    static constexpr index_t num_input_blks      = 4;
    static constexpr index_t num_output_blks     = 1;
    static constexpr index_t m_per_blk           = 16;
    static constexpr index_t n_per_blk           = 16;
    static constexpr index_t k_per_blk           = 8;
    static constexpr bool is_k_reduction         = true;

    // Forwards a, b and reg_c to intrin_mfma_f32_16x16x32f8bf8<MPerXdlops, NPerXdlops>::Run.
    template <index_t MPerXdlops,
              index_t NPerXdlops,
              typename FloatA,
              typename FloatB,
              typename FloatC>
    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
    {
        using Intrinsic = intrin_mfma_f32_16x16x32f8bf8<MPerXdlops, NPerXdlops>;
        Intrinsic::Run(a, b, reg_c);
    }
};
// Parameters for MfmaInstr::mfma_f32_32x32x16bf8f8; run() dispatches to the matching intrinsic
// wrapper.
template <>
struct mfma_type<MfmaInstr::mfma_f32_32x32x16bf8f8>
{
    static constexpr index_t group_size          = 4;
    static constexpr index_t num_groups_per_blk  = 4;
    static constexpr index_t num_regs_per_blk    = 16;
    static constexpr index_t num_threads_per_blk = 32;
    static constexpr index_t wave_size           = 64;
    static constexpr index_t num_input_blks      = 2;
    static constexpr index_t num_output_blks     = 1;
    static constexpr index_t m_per_blk           = 32;
    static constexpr index_t n_per_blk           = 32;
    static constexpr index_t k_per_blk           = 8;
    static constexpr bool is_k_reduction         = true;

    // Forwards a, b and reg_c to intrin_mfma_f32_32x32x16bf8f8<MPerXdlops, NPerXdlops>::Run.
    template <index_t MPerXdlops,
              index_t NPerXdlops,
              typename FloatA,
              typename FloatB,
              typename FloatC>
    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
    {
        using Intrinsic = intrin_mfma_f32_32x32x16bf8f8<MPerXdlops, NPerXdlops>;
        Intrinsic::Run(a, b, reg_c);
    }
};
// Parameters for MfmaInstr::mfma_f32_16x16x32bf8f8; run() dispatches to the matching intrinsic
// wrapper.
template <>
struct mfma_type<MfmaInstr::mfma_f32_16x16x32bf8f8>
{
    static constexpr index_t group_size          = 4;
    static constexpr index_t num_groups_per_blk  = 1;
    static constexpr index_t num_regs_per_blk    = 4;
    static constexpr index_t num_threads_per_blk = 16;
    static constexpr index_t wave_size           = 64;
    static constexpr index_t num_input_blks      = 4;
    static constexpr index_t num_output_blks     = 1;
    static constexpr index_t m_per_blk           = 16;
    static constexpr index_t n_per_blk           = 16;
    static constexpr index_t k_per_blk           = 8;
    static constexpr bool is_k_reduction         = true;

    // Forwards a, b and reg_c to intrin_mfma_f32_16x16x32bf8f8<MPerXdlops, NPerXdlops>::Run.
    template <index_t MPerXdlops,
              index_t NPerXdlops,
              typename FloatA,
              typename FloatB,
              typename FloatC>
    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
    {
        using Intrinsic = intrin_mfma_f32_16x16x32bf8f8<MPerXdlops, NPerXdlops>;
        Intrinsic::Run(a, b, reg_c);
    }
};
template
<
typename
base_type
,
index_t
MPerXdlops
,
index_t
NPerXdlops
,
typename
additional_type
=
base_type
>
struct
MfmaSelector
{
template
<
typename
base_type_
,
index_t
MPerXdlops_
,
index_t
NPerXdlops_
>
template
<
typename
base_type_
,
index_t
MPerXdlops_
,
index_t
NPerXdlops_
,
typename
additional_type_
=
base_type_
>
static
constexpr
auto
GetMfma
();
template
<
>
...
...
@@ -652,7 +796,44 @@ struct MfmaSelector
return
MfmaInstr
::
mfma_f32_16x16x32f8f8
;
}
static
constexpr
auto
selected_mfma
=
mfma_type
<
GetMfma
<
base_type
,
MPerXdlops
,
NPerXdlops
>
()
>
{};
// (bf8_t, 32, 32) with the trailing type argument left unspecified.
template <>
static constexpr auto GetMfma<bf8_t, 32, 32>()
{
    return MfmaInstr::mfma_f32_32x32x16bf8bf8;
}
// (bf8_t, 16, 16) with the trailing type argument left unspecified.
template <>
static constexpr auto GetMfma<bf8_t, 16, 16>()
{
    return MfmaInstr::mfma_f32_16x16x32bf8bf8;
}
// (f8_t, 32, 32) paired with bf8_t as the additional type.
template <>
static constexpr auto GetMfma<f8_t, 32, 32, bf8_t>()
{
    return MfmaInstr::mfma_f32_32x32x16f8bf8;
}
// (f8_t, 16, 16) paired with bf8_t as the additional type.
template <>
static constexpr auto GetMfma<f8_t, 16, 16, bf8_t>()
{
    return MfmaInstr::mfma_f32_16x16x32f8bf8;
}
// (bf8_t, 32, 32) paired with f8_t as the additional type.
template <>
static constexpr auto GetMfma<bf8_t, 32, 32, f8_t>()
{
    return MfmaInstr::mfma_f32_32x32x16bf8f8;
}
// (bf8_t, 16, 16) paired with f8_t as the additional type.
template <>
static constexpr auto GetMfma<bf8_t, 16, 16, f8_t>()
{
    return MfmaInstr::mfma_f32_16x16x32bf8f8;
}
static
constexpr
auto
selected_mfma
=
mfma_type
<
GetMfma
<
base_type
,
MPerXdlops
,
NPerXdlops
,
additional_type
>
()
>
{};
__host__
__device__
constexpr
MfmaSelector
()
{
...
...
@@ -699,6 +880,7 @@ template <typename base_type,
index_t
MPerXdlops
,
index_t
NPerXdlops
,
index_t
KPack
,
typename
additional_type
=
base_type
,
bool
TransposeC
=
false
>
struct
XdlopsGemm
{
...
...
@@ -850,10 +1032,14 @@ struct XdlopsGemm
template
<
class
FloatA
,
class
FloatB
,
class
FloatC
>
__device__
void
Run
(
const
FloatA
&
p_a_wave
,
const
FloatB
&
p_b_wave
,
FloatC
&
p_c_thread
)
const
{
static_assert
(
is_same
<
base_type
,
double
>::
value
||
is_same
<
base_type
,
float
>::
value
||
static_assert
(
is_same
<
base_type
,
double
>::
value
||
is_same
<
base_type
,
float
>::
value
||
is_same
<
base_type
,
half_t
>::
value
||
is_same
<
base_type
,
bhalf_t
>::
value
||
is_same
<
base_type
,
int8_t
>::
value
||
is_same
<
base_type
,
f8_t
>::
value
,
"base base_type must be double, float, half, bfloat16, and int8_t!"
);
is_same
<
base_type
,
int8_t
>::
value
||
is_same
<
base_type
,
f8_t
>::
value
||
is_same
<
base_type
,
bf8_t
>::
value
||
(
is_same
<
base_type
,
f8_t
>::
value
&&
is_same
<
additional_type
,
bf8_t
>::
value
)
||
(
is_same
<
base_type
,
bf8_t
>::
value
&&
is_same
<
additional_type
,
f8_t
>::
value
),
"base base_type must be double, float, half, bfloat16, int8_t, f8_t or bf8_t!"
);
static_for
<
0
,
KPack
/
mfma_instr
.
k_per_blk
,
1
>
{}([
&
](
auto
k
)
{
if
constexpr
(
!
TransposeC
)
...
...
@@ -949,7 +1135,7 @@ struct XdlopsGemm
return
TransposeC
?
CIndex4D
{
blk_td
,
I0
,
blk_id
,
I0
}
:
CIndex4D
{
I0
,
blk_id
,
I0
,
blk_td
};
}
static
constexpr
auto
mfma
=
MfmaSelector
<
base_type
,
MPerXdlops
,
NPerXdlops
>
{};
static
constexpr
auto
mfma
=
MfmaSelector
<
base_type
,
MPerXdlops
,
NPerXdlops
,
additional_type
>
{};
static
constexpr
auto
mfma_instr
=
mfma
.
selected_mfma
;
...
...
include/ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp
View file @
3c4fb1dd
...
...
@@ -164,6 +164,7 @@ template <
index_t
BK1
,
index_t
GemmMPerBlock
,
index_t
GemmNPerBlock
,
index_t
GemmKPerBlock
,
bool
DoPadGemmM
,
bool
DoPadGemmN
>
struct
TransformConvBwdDataToGemm_v1
...
...
@@ -308,9 +309,6 @@ struct TransformConvBwdDataToGemm_v1
const
auto
YDotSlice
=
math
::
integer_divide_ceil
(
Y
-
i_ytilde
,
YTilde
);
const
auto
XDotSlice
=
math
::
integer_divide_ceil
(
X
-
i_xtilde
,
XTilde
);
const
index_t
AK0
=
math
::
integer_divide_ceil
(
ZDotSlice
*
YDotSlice
*
XDotSlice
*
K
,
AK1
);
if
constexpr
(
NDimSpatial
==
2
)
{
// A: output tensor
...
...
@@ -367,9 +365,11 @@ struct TransformConvBwdDataToGemm_v1
const
auto
out_gemmk_gemmm_padded_grid_desc
=
ck
::
tensor_operation
::
device
::
PadTensorDescriptor
(
out_gemmk_gemmmraw_grid_desc
,
make_tuple
(
AK1
,
GemmMPerBlock
),
make_tuple
(
GemmKPerBlock
,
GemmMPerBlock
),
Sequence
<
true
,
DoPadGemmM
>
{});
const
index_t
AK0
=
out_gemmk_gemmm_padded_grid_desc
.
GetLength
(
I0
)
/
AK1
;
const
auto
out_gemmak0_gemmm_gemmak1_grid_desc
=
transform_tensor_descriptor
(
out_gemmk_gemmm_padded_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
AK0
,
AK1
)),
...
...
@@ -460,9 +460,11 @@ struct TransformConvBwdDataToGemm_v1
const
auto
out_gemmk_gemmm_padded_grid_desc
=
ck
::
tensor_operation
::
device
::
PadTensorDescriptor
(
out_gemmk_gemmmraw_grid_desc
,
make_tuple
(
AK1
,
GemmMPerBlock
),
make_tuple
(
GemmKPerBlock
,
GemmMPerBlock
),
Sequence
<
true
,
DoPadGemmM
>
{});
const
index_t
AK0
=
out_gemmk_gemmm_padded_grid_desc
.
GetLength
(
I0
)
/
AK1
;
const
auto
out_gemmak0_gemmm_gemmak1_grid_desc
=
transform_tensor_descriptor
(
out_gemmk_gemmm_padded_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
AK0
,
AK1
)),
...
...
@@ -568,9 +570,6 @@ struct TransformConvBwdDataToGemm_v1
const
auto
YDotSlice
=
math
::
integer_divide_ceil
(
Y
-
i_ytilde
,
YTilde
);
const
auto
XDotSlice
=
math
::
integer_divide_ceil
(
X
-
i_xtilde
,
XTilde
);
const
index_t
BK0
=
math
::
integer_divide_ceil
(
ZDotSlice
*
YDotSlice
*
XDotSlice
*
K
,
BK1
);
// B weight tensor
if
constexpr
(
NDimSpatial
==
2
)
{
...
...
@@ -617,9 +616,11 @@ struct TransformConvBwdDataToGemm_v1
const
auto
wei_gemmk_gemmn_padded_grid_desc
=
ck
::
tensor_operation
::
device
::
PadTensorDescriptor
(
wei_gemmk_gemmnraw_grid_desc
,
make_tuple
(
BK1
,
GemmNPerBlock
),
make_tuple
(
GemmKPerBlock
,
GemmNPerBlock
),
Sequence
<
true
,
DoPadGemmN
>
{});
const
index_t
BK0
=
wei_gemmk_gemmn_padded_grid_desc
.
GetLength
(
I0
)
/
BK1
;
const
auto
wei_gemmbk0_gemmn_gemmbk1_grid_desc
=
transform_tensor_descriptor
(
wei_gemmk_gemmn_padded_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
BK0
,
BK1
)),
...
...
@@ -690,17 +691,19 @@ struct TransformConvBwdDataToGemm_v1
make_tuple
(
Sequence
<
1
,
2
,
3
,
0
>
{},
Sequence
<
4
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
wei_gemmk_gemm_padded_grid_desc
=
const
auto
wei_gemmk_gemm
n
_padded_grid_desc
=
ck
::
tensor_operation
::
device
::
PadTensorDescriptor
(
wei_gemmk_gemmnraw_grid_desc
,
make_tuple
(
BK1
,
GemmNPerBlock
),
make_tuple
(
GemmKPerBlock
,
GemmNPerBlock
),
Sequence
<
true
,
DoPadGemmN
>
{});
const
index_t
BK0
=
wei_gemmk_gemmn_padded_grid_desc
.
GetLength
(
I0
)
/
BK1
;
const
auto
wei_gemmbk0_gemm_gemmbk1_grid_desc
=
transform_tensor_descriptor
(
wei_gemmk_gemm_padded_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
BK0
,
BK1
)),
make_pass_through_transform
(
wei_gemmk_gemm_padded_grid_desc
.
GetLength
(
I1
))),
wei_gemmk_gemm
n
_padded_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
BK0
,
BK1
)),
make_pass_through_transform
(
wei_gemmk_gemm
n
_padded_grid_desc
.
GetLength
(
I1
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
...
...
include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp
View file @
3c4fb1dd
...
...
@@ -20,348 +20,13 @@ struct TransformConvFwdToGemm
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
template
<
typename
ALayout
,
typename
std
::
enable_if
<
NDimSpatial
==
1
&&
is_same_v
<
ALayout
,
tensor_layout
::
convolution
::
GNWC
>,
bool
>::
type
=
false
>
static
auto
MakeADescriptor_M_K
(
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
/* a_g_n_c_wis_strides */
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
/* b_g_k_c_xs_strides */
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
c_g_n_k_wos_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
/* c_g_n_k_wos_strides */
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_right_pads
)
{
const
index_t
N
=
a_g_n_c_wis_lengths
[
1
];
const
index_t
C
=
a_g_n_c_wis_lengths
[
2
];
const
index_t
Wi
=
a_g_n_c_wis_lengths
[
3
];
const
index_t
Wo
=
c_g_n_k_wos_lengths
[
3
];
const
index_t
ConvStrideW
=
conv_filter_strides
[
0
];
if
constexpr
(
ConvForwardSpecialization
==
device
::
ConvolutionForwardSpecialization
::
Filter1x1Stride1Pad0
)
{
const
index_t
NWo
=
N
*
ck
::
accumulate_n
<
index_t
>
(
c_g_n_k_wos_lengths
.
begin
()
+
3
,
NDimSpatial
,
1
,
std
::
multiplies
<>
());
const
auto
in_gemmm_gemmk_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
NWo
,
C
));
return
in_gemmm_gemmk_desc
;
}
else
if
constexpr
(
ConvForwardSpecialization
==
device
::
ConvolutionForwardSpecialization
::
Filter1x1Pad0
)
{
const
auto
in_n_wi_c_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
,
Wi
,
C
));
const
auto
in_n_wo_c_desc
=
transform_tensor_descriptor
(
in_n_wi_c_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_embed_transform
(
make_tuple
(
Wo
),
make_tuple
(
ConvStrideW
)),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
const
auto
in_gemmm_gemmk_desc
=
transform_tensor_descriptor
(
in_n_wo_c_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
N
,
Wo
)),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
,
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
return
in_gemmm_gemmk_desc
;
}
else
{
const
index_t
X
=
b_g_k_c_xs_lengths
[
3
];
const
index_t
ConvDilationW
=
conv_filter_dilations
[
0
];
const
index_t
InLeftPadW
=
input_left_pads
[
0
];
const
index_t
InRightPadW
=
input_right_pads
[
0
];
const
auto
in_n_wi_c_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
,
Wi
,
C
));
const
auto
in_n_wip_c_desc
=
transform_tensor_descriptor
(
in_n_wi_c_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_pad_transform
(
Wi
,
InLeftPadW
,
InRightPadW
),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
const
auto
in_n_x_wo_c_desc
=
transform_tensor_descriptor
(
in_n_wip_c_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_embed_transform
(
make_tuple
(
X
,
Wo
),
make_tuple
(
ConvDilationW
,
ConvStrideW
)),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
>
{}));
const
auto
in_gemmm_gemmk_desc
=
transform_tensor_descriptor
(
in_n_x_wo_c_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
N
,
Wo
)),
make_merge_transform
(
make_tuple
(
X
,
C
))),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
,
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
return
in_gemmm_gemmk_desc
;
}
}
// Builds the two-dimensional [GemmM, GemmK] descriptor of the convolution input (the "A"
// operand) for a 2D problem whose input layout tag is tensor_layout::convolution::GNHWC.
//
// The input is always described as a packed (N, Hi, Wi, C) tensor; the three tensor-stride
// arguments are accepted for interface compatibility but never read.
//
// Values read from the arguments:
//   a_g_n_c_wis_lengths   : [1] = N, [2] = C, [3] = Hi, [4] = Wi
//   c_g_n_k_wos_lengths   : [3] = Ho, [4] = Wo
//   conv_filter_strides   : [0] = H stride, [1] = W stride
//   b_g_k_c_xs_lengths    : [3] = Y, [4] = X                     (generic branch only)
//   conv_filter_dilations : [0] = H dilation, [1] = W dilation   (generic branch only)
//   input_left_pads / input_right_pads : [0] = H, [1] = W        (generic branch only)
//
// The branch is selected at compile time from ConvForwardSpecialization:
//   Filter1x1Stride1Pad0 : packed (N * accumulate_n(output spatial lengths), C) descriptor.
//   Filter1x1Pad0        : H and W are re-embedded with the conv strides, then (N, Ho, Wo)
//                          is merged into dim 0 and C is passed through as dim 1.
//   otherwise            : pad H/W, split each into (filter, output) with
//                          (dilation, stride) coefficients, then merge (N, Ho, Wo) into
//                          dim 0 and (Y, X, C) into dim 1.
template <typename ALayout,
          typename std::enable_if<NDimSpatial == 2 &&
                                      is_same_v<ALayout, tensor_layout::convolution::GNHWC>,
                                  bool>::type = false>
static auto
MakeADescriptor_M_K(const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
                    const std::array<index_t, NDimSpatial + 3>& /* a_g_n_c_wis_strides */,
                    const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
                    const std::array<index_t, NDimSpatial + 3>& /* b_g_k_c_xs_strides */,
                    const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
                    const std::array<index_t, NDimSpatial + 3>& /* c_g_n_k_wos_strides */,
                    const std::array<index_t, NDimSpatial>& conv_filter_strides,
                    const std::array<index_t, NDimSpatial>& conv_filter_dilations,
                    const std::array<index_t, NDimSpatial>& input_left_pads,
                    const std::array<index_t, NDimSpatial>& input_right_pads)
{
    using Spec = device::ConvolutionForwardSpecialization;

    // Extents shared by every branch.
    const index_t N = a_g_n_c_wis_lengths[1];
    const index_t C = a_g_n_c_wis_lengths[2];

    // Height-related values.
    const index_t Hi          = a_g_n_c_wis_lengths[3];
    const index_t Ho          = c_g_n_k_wos_lengths[3];
    const index_t ConvStrideH = conv_filter_strides[0];

    // Width-related values.
    const index_t Wi          = a_g_n_c_wis_lengths[4];
    const index_t Wo          = c_g_n_k_wos_lengths[4];
    const index_t ConvStrideW = conv_filter_strides[1];

    if constexpr(ConvForwardSpecialization == Spec::Filter1x1Stride1Pad0)
    {
        // No padding, striding or filter window to model: the result is a plain packed
        // 2-D descriptor.
        const index_t gemm_m =
            N * ck::accumulate_n<index_t>(
                    c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());

        return make_naive_tensor_descriptor_packed(make_tuple(gemm_m, C));
    }
    else if constexpr(ConvForwardSpecialization == Spec::Filter1x1Pad0)
    {
        const auto packed_in_desc =
            make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));

        // (N, Hi, Wi, C) -> (N, Ho, Wo, C): each spatial dim is embedded with its conv stride.
        const auto strided_in_desc = transform_tensor_descriptor(
            packed_in_desc,
            make_tuple(make_pass_through_transform(N),
                       make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)),
                       make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
                       make_pass_through_transform(C)),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

        // (N, Ho, Wo, C) -> (N * Ho * Wo, C)
        return transform_tensor_descriptor(
            strided_in_desc,
            make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)),
                       make_pass_through_transform(C)),
            make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));
    }
    else
    {
        // Filter window, dilation and padding are only needed on this generic path.
        const index_t Y             = b_g_k_c_xs_lengths[3];
        const index_t ConvDilationH = conv_filter_dilations[0];
        const index_t InLeftPadH    = input_left_pads[0];
        const index_t InRightPadH   = input_right_pads[0];

        const index_t X             = b_g_k_c_xs_lengths[4];
        const index_t ConvDilationW = conv_filter_dilations[1];
        const index_t InLeftPadW    = input_left_pads[1];
        const index_t InRightPadW   = input_right_pads[1];

        const auto packed_in_desc =
            make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));

        // (N, Hi, Wi, C) -> (N, Hi + pads, Wi + pads, C)
        const auto padded_in_desc = transform_tensor_descriptor(
            packed_in_desc,
            make_tuple(make_pass_through_transform(N),
                       make_pad_transform(Hi, InLeftPadH, InRightPadH),
                       make_pad_transform(Wi, InLeftPadW, InRightPadW),
                       make_pass_through_transform(C)),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

        // Padded H -> (Y, Ho) and padded W -> (X, Wo), giving (N, Y, Ho, X, Wo, C).
        const auto windowed_in_desc = transform_tensor_descriptor(
            padded_in_desc,
            make_tuple(
                make_pass_through_transform(N),
                make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
                make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
                make_pass_through_transform(C)),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));

        // dim 0 <- merge(N, Ho, Wo) from dims {0, 2, 4}; dim 1 <- merge(Y, X, C) from {1, 3, 5}.
        return transform_tensor_descriptor(
            windowed_in_desc,
            make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)),
                       make_merge_transform(make_tuple(Y, X, C))),
            make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));
    }
}
// Builds the two-dimensional [GemmM, GemmK] descriptor of the convolution input (the "A"
// operand) for a 3D problem whose input layout tag is tensor_layout::convolution::GNDHWC.
//
// The input is always described as a packed (N, Di, Hi, Wi, C) tensor; the three
// tensor-stride arguments are accepted for interface compatibility but never read.
//
// Values read from the arguments:
//   a_g_n_c_wis_lengths   : [1] = N, [2] = C, [3] = Di, [4] = Hi, [5] = Wi
//   c_g_n_k_wos_lengths   : [3] = Do, [4] = Ho, [5] = Wo
//   conv_filter_strides   : [0] = D, [1] = H, [2] = W
//   b_g_k_c_xs_lengths    : [3] = Z, [4] = Y, [5] = X             (generic branch only)
//   conv_filter_dilations : [0] = D, [1] = H, [2] = W             (generic branch only)
//   input_left_pads / input_right_pads : [0] = D, [1] = H, [2] = W (generic branch only)
//
// The branch is selected at compile time from ConvForwardSpecialization:
//   Filter1x1Stride1Pad0 : packed (N * accumulate_n(output spatial lengths), C) descriptor.
//   Filter1x1Pad0        : D, H and W are re-embedded with the conv strides, then
//                          (N, Do, Ho, Wo) is merged into dim 0 and C is passed through.
//   otherwise            : pad D/H/W, split each into (filter, output) with
//                          (dilation, stride) coefficients, then merge (N, Do, Ho, Wo) into
//                          dim 0 and (Z, Y, X, C) into dim 1.
template <typename ALayout,
          typename std::enable_if<NDimSpatial == 3 &&
                                      is_same_v<ALayout, tensor_layout::convolution::GNDHWC>,
                                  bool>::type = false>
static auto
MakeADescriptor_M_K(const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
                    const std::array<index_t, NDimSpatial + 3>& /* a_g_n_c_wis_strides */,
                    const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
                    const std::array<index_t, NDimSpatial + 3>& /* b_g_k_c_xs_strides */,
                    const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
                    const std::array<index_t, NDimSpatial + 3>& /* c_g_n_k_wos_strides */,
                    const std::array<index_t, NDimSpatial>& conv_filter_strides,
                    const std::array<index_t, NDimSpatial>& conv_filter_dilations,
                    const std::array<index_t, NDimSpatial>& input_left_pads,
                    const std::array<index_t, NDimSpatial>& input_right_pads)
{
    using Spec = device::ConvolutionForwardSpecialization;

    // Extents shared by every branch.
    const index_t N = a_g_n_c_wis_lengths[1];
    const index_t C = a_g_n_c_wis_lengths[2];

    // Depth-related values.
    const index_t Di          = a_g_n_c_wis_lengths[3];
    const index_t Do          = c_g_n_k_wos_lengths[3];
    const index_t ConvStrideD = conv_filter_strides[0];

    // Height-related values.
    const index_t Hi          = a_g_n_c_wis_lengths[4];
    const index_t Ho          = c_g_n_k_wos_lengths[4];
    const index_t ConvStrideH = conv_filter_strides[1];

    // Width-related values.
    const index_t Wi          = a_g_n_c_wis_lengths[5];
    const index_t Wo          = c_g_n_k_wos_lengths[5];
    const index_t ConvStrideW = conv_filter_strides[2];

    if constexpr(ConvForwardSpecialization == Spec::Filter1x1Stride1Pad0)
    {
        // No padding, striding or filter window to model: the result is a plain packed
        // 2-D descriptor.
        const index_t gemm_m =
            N * ck::accumulate_n<index_t>(
                    c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());

        return make_naive_tensor_descriptor_packed(make_tuple(gemm_m, C));
    }
    else if constexpr(ConvForwardSpecialization == Spec::Filter1x1Pad0)
    {
        const auto packed_in_desc =
            make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C));

        // (N, Di, Hi, Wi, C) -> (N, Do, Ho, Wo, C): each spatial dim is embedded with its
        // conv stride.
        const auto strided_in_desc = transform_tensor_descriptor(
            packed_in_desc,
            make_tuple(make_pass_through_transform(N),
                       make_embed_transform(make_tuple(Do), make_tuple(ConvStrideD)),
                       make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)),
                       make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
                       make_pass_through_transform(C)),
            make_tuple(
                Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
            make_tuple(
                Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));

        // (N, Do, Ho, Wo, C) -> (N * Do * Ho * Wo, C)
        return transform_tensor_descriptor(
            strided_in_desc,
            make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)),
                       make_pass_through_transform(C)),
            make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));
    }
    else
    {
        // Filter window, dilation and padding are only needed on this generic path.
        const index_t Z             = b_g_k_c_xs_lengths[3];
        const index_t ConvDilationD = conv_filter_dilations[0];
        const index_t InLeftPadD    = input_left_pads[0];
        const index_t InRightPadD   = input_right_pads[0];

        const index_t Y             = b_g_k_c_xs_lengths[4];
        const index_t ConvDilationH = conv_filter_dilations[1];
        const index_t InLeftPadH    = input_left_pads[1];
        const index_t InRightPadH   = input_right_pads[1];

        const index_t X             = b_g_k_c_xs_lengths[5];
        const index_t ConvDilationW = conv_filter_dilations[2];
        const index_t InLeftPadW    = input_left_pads[2];
        const index_t InRightPadW   = input_right_pads[2];

        const auto packed_in_desc =
            make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C));

        // (N, Di, Hi, Wi, C) -> same rank with D, H and W padded on both sides.
        const auto padded_in_desc = transform_tensor_descriptor(
            packed_in_desc,
            make_tuple(make_pass_through_transform(N),
                       make_pad_transform(Di, InLeftPadD, InRightPadD),
                       make_pad_transform(Hi, InLeftPadH, InRightPadH),
                       make_pad_transform(Wi, InLeftPadW, InRightPadW),
                       make_pass_through_transform(C)),
            make_tuple(
                Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
            make_tuple(
                Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));

        // Padded D -> (Z, Do), H -> (Y, Ho), W -> (X, Wo), giving
        // (N, Z, Do, Y, Ho, X, Wo, C).
        const auto windowed_in_desc = transform_tensor_descriptor(
            padded_in_desc,
            make_tuple(
                make_pass_through_transform(N),
                make_embed_transform(make_tuple(Z, Do), make_tuple(ConvDilationD, ConvStrideD)),
                make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
                make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
                make_pass_through_transform(C)),
            make_tuple(
                Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
            make_tuple(Sequence<0>{},
                       Sequence<1, 2>{},
                       Sequence<3, 4>{},
                       Sequence<5, 6>{},
                       Sequence<7>{}));

        // dim 0 <- merge(N, Do, Ho, Wo) from dims {0, 2, 4, 6};
        // dim 1 <- merge(Z, Y, X, C)    from dims {1, 3, 5, 7}.
        return transform_tensor_descriptor(
            windowed_in_desc,
            make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)),
                       make_merge_transform(make_tuple(Z, Y, X, C))),
            make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));
    }
}
// TODO: implement ck::tensor_layout::convolution that describes packed/strided dimension as
// properties
template
<
typename
ALayout
,
typename
std
::
enable_if
<
NDimSpatial
==
1
&&
(
is_same_v
<
ALayout
,
tensor_layout
::
convolution
::
G_NW_C
>
||
is_same_v
<
ALayout
,
tensor_layout
::
convolution
::
NWGC
>
),
is_same_v
<
ALayout
,
tensor_layout
::
convolution
::
NWGC
>
||
is_same_v
<
ALayout
,
tensor_layout
::
convolution
::
GNWC
>
),
bool
>::
type
=
false
>
static
auto
MakeADescriptor_M_K
(
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_lengths
,
...
...
@@ -473,7 +138,8 @@ struct TransformConvFwdToGemm
template
<
typename
ALayout
,
typename
std
::
enable_if
<
NDimSpatial
==
2
&&
(
is_same_v
<
ALayout
,
tensor_layout
::
convolution
::
G_NHW_C
>
||
is_same_v
<
ALayout
,
tensor_layout
::
convolution
::
NHWGC
>
),
is_same_v
<
ALayout
,
tensor_layout
::
convolution
::
NHWGC
>
||
is_same_v
<
ALayout
,
tensor_layout
::
convolution
::
GNHWC
>
),
bool
>::
type
=
false
>
static
auto
MakeADescriptor_M_K
(
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_lengths
,
...
...
@@ -601,7 +267,8 @@ struct TransformConvFwdToGemm
template
<
typename
ALayout
,
typename
std
::
enable_if
<
NDimSpatial
==
3
&&
(
is_same_v
<
ALayout
,
tensor_layout
::
convolution
::
G_NDHW_C
>
||
is_same_v
<
ALayout
,
tensor_layout
::
convolution
::
NDHWGC
>
),
is_same_v
<
ALayout
,
tensor_layout
::
convolution
::
NDHWGC
>
||
is_same_v
<
ALayout
,
tensor_layout
::
convolution
::
GNDHWC
>
),
bool
>::
type
=
false
>
static
auto
MakeADescriptor_M_K
(
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_lengths
,
...
...
include/ck/utility/amd_buffer_addressing.hpp
View file @
3c4fb1dd
...
...
@@ -299,368 +299,146 @@ enum struct AmdBufferCoherenceEnum
GLC_SLC
=
3
,
};
template
<
typename
T
,
index_t
N
,
AmdBufferCoherenceEnum
coherence
=
AmdBufferCoherenceEnum
::
DefaultCoherence
>
__device__
typename
vector_type
<
T
,
N
>::
type
amd_buffer_load_impl
(
int32x4_t
src_wave_buffer_resource
,
template
<
index_t
N
,
AmdBufferCoherenceEnum
coherence
=
AmdBufferCoherenceEnum
::
DefaultCoherence
>
__device__
typename
vector_type
<
int8_t
,
N
>::
type
amd_buffer_load_impl_raw
(
int32x4_t
src_wave_buffer_resource
,
index_t
src_thread_addr_offset
,
index_t
src_wave_addr_offset
)
{
static_assert
(
(
is_same
<
T
,
double
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
))
||
(
is_same
<
T
,
float
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
))
||
(
is_same
<
T
,
half_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
))
||
(
is_same
<
T
,
bhalf_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
))
||
(
is_same
<
T
,
int32_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
))
||
(
is_same
<
T
,
int8_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
)),
static_assert
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
||
N
==
32
||
N
==
64
,
"wrong! not implemented"
);
if
constexpr
(
is_same
<
T
,
double
>::
value
)
{
// use fp32 load to mimic fp64 load
if
constexpr
(
N
==
1
)
{
const
float2_t
tmp
=
llvm_amdgcn_raw_buffer_load_fp32x2
(
src_wave_buffer_resource
,
return
llvm_amdgcn_raw_buffer_load_i8
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
return
bit_cast
<
double
>
(
tmp
);
}
else
if
constexpr
(
N
==
2
)
{
const
float4_t
tmp
=
llvm_amdgcn_raw_buffer_load_fp32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
return
bit_cast
<
double2_t
>
(
tmp
);
}
else
if
constexpr
(
N
==
4
)
{
const
float4_t
f32_0
=
llvm_amdgcn_raw_buffer_load_fp32x4
(
src_wave_buffer_resource
,
int16_t
tmp
=
llvm_amdgcn_raw_buffer_load_i16
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
const
float4_t
f32_1
=
llvm_amdgcn_raw_buffer_load_fp32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
+
4
*
sizeof
(
float
),
static_cast
<
index_t
>
(
coherence
));
vector_type
<
double
,
4
>
tmp
;
tmp
.
AsType
<
double2_t
>
()(
Number
<
0
>
{})
=
bit_cast
<
double2_t
>
(
f32_0
);
tmp
.
AsType
<
double2_t
>
()(
Number
<
1
>
{})
=
bit_cast
<
double2_t
>
(
f32_1
);
return
tmp
.
AsType
<
double4_t
>
()(
Number
<
0
>
{});
}
}
else
if
constexpr
(
is_same
<
T
,
float
>::
value
)
{
if
constexpr
(
N
==
1
)
{
return
llvm_amdgcn_raw_buffer_load_fp32
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
2
)
{
return
llvm_amdgcn_raw_buffer_load_fp32x2
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
return
bit_cast
<
int8x2_t
>
(
tmp
);
}
else
if
constexpr
(
N
==
4
)
{
return
llvm_amdgcn_raw_buffer_load_fp32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
8
)
{
vector_type
<
float
,
8
>
tmp
;
tmp
.
AsType
<
float4_t
>
()(
Number
<
0
>
{})
=
llvm_amdgcn_raw_buffer_load_fp32x4
(
src_wave_buffer_resource
,
int32_t
tmp
=
llvm_amdgcn_raw_buffer_load_i32
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
tmp
.
AsType
<
float4_t
>
()(
Number
<
1
>
{})
=
llvm_amdgcn_raw_buffer_load_fp32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
+
4
*
sizeof
(
float
),
static_cast
<
index_t
>
(
coherence
));
return
tmp
.
AsType
<
float8_t
>
()(
Number
<
0
>
{});
}
}
else
if
constexpr
(
is_same
<
T
,
half_t
>::
value
)
{
if
constexpr
(
N
==
1
)
{
return
llvm_amdgcn_raw_buffer_load_fp16
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
2
)
{
return
llvm_amdgcn_raw_buffer_load_fp16x2
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
4
)
{
return
llvm_amdgcn_raw_buffer_load_fp16x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
return
bit_cast
<
int8x4_t
>
(
tmp
);
}
else
if
constexpr
(
N
==
8
)
{
// use fp32 load to mimic fp16 load
float4_t
tmp
=
llvm_amdgcn_raw_buffer_load_fp32x4
(
src_wave_buffer_resource
,
int32x2_t
tmp
=
llvm_amdgcn_raw_buffer_load_i32x2
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
return
bit_cast
<
half8_t
>
(
tmp
);
}
}
else
if
constexpr
(
is_same
<
T
,
bhalf_t
>::
value
)
{
if
constexpr
(
N
==
1
)
{
return
llvm_amdgcn_raw_buffer_load_i16
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
2
)
{
return
llvm_amdgcn_raw_buffer_load_i16x2
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
4
)
{
return
llvm_amdgcn_raw_buffer_load_i16x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
return
bit_cast
<
int8x8_t
>
(
tmp
);
}
else
if
constexpr
(
N
==
8
)
else
if
constexpr
(
N
==
16
)
{
int32x4_t
tmp
=
llvm_amdgcn_raw_buffer_load_i32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
return
bit_cast
<
bhalf8_t
>
(
tmp
);
}
}
else
if
constexpr
(
is_same
<
T
,
int32_t
>::
value
)
{
if
constexpr
(
N
==
1
)
{
return
llvm_amdgcn_raw_buffer_load_i32
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
2
)
{
return
llvm_amdgcn_raw_buffer_load_i32x2
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
4
)
{
return
llvm_amdgcn_raw_buffer_load_i32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
return
bit_cast
<
int8x16_t
>
(
tmp
);
}
else
if
constexpr
(
N
==
8
)
else
if
constexpr
(
N
==
32
)
{
vector_type
<
int32_t
,
8
>
tmp
;
tmp
.
AsType
<
int32x4_t
>
()(
Number
<
0
>
{})
=
llvm_amdgcn_raw_buffer_load_i32x4
(
src_wave_buffer_resource
,
int32x4_t
tmp0
=
llvm_amdgcn_raw_buffer_load_i32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
tmp
.
AsType
<
int32x4_t
>
()(
Number
<
1
>
{})
=
int32x4_t
tmp1
=
llvm_amdgcn_raw_buffer_load_i32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
+
4
*
sizeof
(
int32_t
),
static_cast
<
index_t
>
(
coherence
));
return
tmp
.
AsType
<
int32x8_t
>
()(
Number
<
0
>
{});
}
}
else
if
constexpr
(
is_same
<
T
,
int8_t
>::
value
)
{
if
constexpr
(
N
==
1
)
{
return
llvm_amdgcn_raw_buffer_load_i8
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
2
)
{
#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
return
llvm_amdgcn_raw_buffer_load_i8x2
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
#else
int16_t
tmp
=
llvm_amdgcn_raw_buffer_load_i16
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
vector_type
<
int32_t
,
8
>
tmp
;
return
bit_cast
<
int8x2_t
>
(
tmp
);
#endif
}
else
if
constexpr
(
N
==
4
)
{
#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
return
llvm_amdgcn_raw_buffer_load_i8x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
#else
int32_t
tmp
=
llvm_amdgcn_raw_buffer_load_i32
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
tmp
.
AsType
<
int32x4_t
>
()(
Number
<
0
>
{})
=
tmp0
;
tmp
.
AsType
<
int32x4_t
>
()(
Number
<
1
>
{})
=
tmp1
;
return
bit_cast
<
int8x4_t
>
(
tmp
);
#endif
return
bit_cast
<
int8x32_t
>
(
tmp
);
}
else
if
constexpr
(
N
==
8
)
else
if
constexpr
(
N
==
64
)
{
#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
vector_type
<
int8_t
,
8
>
tmp
;
tmp
.
AsType
<
int8x4_t
>
()(
Number
<
0
>
{})
=
llvm_amdgcn_raw_buffer_load_i8x4
(
src_wave_buffer_resource
,
int32x4_t
tmp0
=
llvm_amdgcn_raw_buffer_load_i32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
tmp
.
AsType
<
int8x4_t
>
()(
Number
<
1
>
{})
=
llvm_amdgcn_raw_buffer_load_i8x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
+
4
*
sizeof
(
int8_t
),
static_cast
<
index_t
>
(
coherence
));
return
tmp
.
AsType
<
int8x8_t
>
()(
Number
<
0
>
{});
#else
int32x2_t
tmp
=
llvm_amdgcn_raw_buffer_load_i32x2
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
return
bit_cast
<
int8x8_t
>
(
tmp
);
#endif
}
else
if
constexpr
(
N
==
16
)
{
#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
vector_type
<
int8_t
,
16
>
tmp
;
tmp
.
AsType
<
int8x4_t
>
()(
Number
<
0
>
{})
=
llvm_amdgcn_raw_buffer_load_i8x4
(
src_wave_buffer_resource
,
int32x4_t
tmp1
=
llvm_amdgcn_raw_buffer_load_i32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
src_wave_addr_offset
+
4
*
sizeof
(
int32_t
)
,
static_cast
<
index_t
>
(
coherence
));
tmp
.
AsType
<
int8x4_t
>
()(
Number
<
1
>
{})
=
llvm_amdgcn_raw_buffer_load_i8x4
(
src_wave_buffer_resource
,
int32x4_t
tmp2
=
llvm_amdgcn_raw_buffer_load_i32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
+
4
*
sizeof
(
int
8
_t
),
src_wave_addr_offset
+
8
*
sizeof
(
int
32
_t
),
static_cast
<
index_t
>
(
coherence
));
tmp
.
AsType
<
int8x4_t
>
()(
Number
<
2
>
{})
=
llvm_amdgcn_raw_buffer_load_i8x4
(
src_wave_buffer_resource
,
int32x4_t
tmp3
=
llvm_amdgcn_raw_buffer_load_i32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
+
8
*
sizeof
(
int
8
_t
),
src_wave_addr_offset
+
12
*
sizeof
(
int
32
_t
),
static_cast
<
index_t
>
(
coherence
));
tmp
.
AsType
<
int8x4_t
>
()(
Number
<
3
>
{})
=
llvm_amdgcn_raw_buffer_load_i8x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
+
12
*
sizeof
(
int8_t
),
static_cast
<
index_t
>
(
coherence
));
vector_type
<
int32_t
,
16
>
tmp
;
return
tmp
.
AsType
<
int8x16_t
>
()(
Number
<
0
>
{});
#else
int32x4_t
tmp
=
llvm_amdgcn_raw_buffer_load_i32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
tmp
.
AsType
<
int32x4_t
>
()(
Number
<
0
>
{})
=
tmp0
;
tmp
.
AsType
<
int32x4_t
>
()(
Number
<
1
>
{})
=
tmp1
;
tmp
.
AsType
<
int32x4_t
>
()(
Number
<
2
>
{})
=
tmp2
;
tmp
.
AsType
<
int32x4_t
>
()(
Number
<
3
>
{})
=
tmp3
;
return
bit_cast
<
int8x16_t
>
(
tmp
);
#endif
}
return
bit_cast
<
int8x64_t
>
(
tmp
);
}
}
template
<
typename
T
,
index_t
N
,
AmdBufferCoherenceEnum
coherence
=
AmdBufferCoherenceEnum
::
DefaultCoherence
>
__device__
void
amd_buffer_store_impl
(
const
typename
vector_type
<
T
,
N
>::
type
src_thread_data
,
__device__
typename
vector_type
<
T
,
N
>::
type
amd_buffer_load_impl
(
int32x4_t
src_wave_buffer_resource
,
index_t
src_thread_addr_offset
,
index_t
src_wave_addr_offset
)
{
static_assert
(
(
is_same
<
T
,
double
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
))
||
(
is_same
<
T
,
float
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
is_same
<
T
,
half_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
is_same
<
T
,
bhalf_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
is_same
<
T
,
int32_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
is_same
<
T
,
f8_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
is_same
<
T
,
bf8_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
is_same
<
T
,
int8_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
)),
"wrong! not implemented"
);
using
r_t
=
typename
vector_type
<
T
,
N
>::
type
;
auto
raw_data
=
amd_buffer_load_impl_raw
<
sizeof
(
T
)
*
N
,
coherence
>
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
);
return
bit_cast
<
r_t
>
(
raw_data
);
}
template
<
index_t
N
,
AmdBufferCoherenceEnum
coherence
=
AmdBufferCoherenceEnum
::
DefaultCoherence
>
__device__
void
amd_buffer_store_impl_raw
(
const
typename
vector_type
<
int8_t
,
N
>::
type
src_thread_data
,
int32x4_t
dst_wave_buffer_resource
,
index_t
dst_thread_addr_offset
,
index_t
dst_wave_addr_offset
)
{
static_assert
(
(
is_same
<
T
,
double
>::
value
&&
(
N
==
1
||
N
==
2
))
||
(
is_same
<
T
,
float
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
))
||
(
is_same
<
T
,
half_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
))
||
(
is_same
<
T
,
bhalf_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
))
||
(
is_same
<
T
,
int32_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
))
||
(
is_same
<
T
,
int8_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
)),
static_assert
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
||
N
==
32
||
N
==
64
,
"wrong! not implemented"
);
if
constexpr
(
is_same
<
T
,
double
>::
value
)
{
// use fp32 store to mimic fp64 store
if
constexpr
(
N
==
1
)
{
llvm_amdgcn_raw_buffer_store_fp32x2
(
bit_cast
<
float2_t
>
(
src_thread_data
),
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
2
)
{
llvm_amdgcn_raw_buffer_store_fp32x4
(
bit_cast
<
float4_t
>
(
src_thread_data
),
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
}
else
if
constexpr
(
is_same
<
T
,
float
>::
value
)
{
if
constexpr
(
N
==
1
)
{
llvm_amdgcn_raw_buffer_store_
fp32
(
src_thread_data
,
llvm_amdgcn_raw_buffer_store_
i8
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
...
...
@@ -668,7 +446,8 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
}
else
if
constexpr
(
N
==
2
)
{
llvm_amdgcn_raw_buffer_store_fp32x2
(
src_thread_data
,
llvm_amdgcn_raw_buffer_store_i16
(
bit_cast
<
int16_t
>
(
src_thread_data
),
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
...
...
@@ -676,7 +455,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
}
else
if
constexpr
(
N
==
4
)
{
llvm_amdgcn_raw_buffer_store_
fp32x4
(
src_thread_data
,
llvm_amdgcn_raw_buffer_store_
i32
(
bit_cast
<
int32_t
>
(
src_thread_data
)
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
...
...
@@ -684,199 +463,91 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
}
else
if
constexpr
(
N
==
8
)
{
vector_type
<
float
,
8
>
tmp
{
src_thread_data
};
llvm_amdgcn_raw_buffer_store_fp32x4
(
tmp
.
AsType
<
float4_t
>
()[
Number
<
0
>
{}],
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
llvm_amdgcn_raw_buffer_store_fp32x4
(
tmp
.
AsType
<
float4_t
>
()[
Number
<
1
>
{}],
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
+
4
*
sizeof
(
float
),
static_cast
<
index_t
>
(
coherence
));
}
}
else
if
constexpr
(
is_same
<
T
,
half_t
>::
value
)
{
if
constexpr
(
N
==
1
)
{
llvm_amdgcn_raw_buffer_store_fp16
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
2
)
{
llvm_amdgcn_raw_buffer_store_fp16x2
(
src_thread_data
,
llvm_amdgcn_raw_buffer_store_i32x2
(
bit_cast
<
int32x2_t
>
(
src_thread_data
),
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
4
)
else
if
constexpr
(
N
==
16
)
{
llvm_amdgcn_raw_buffer_store_
fp16x4
(
src_thread_data
,
llvm_amdgcn_raw_buffer_store_
i32x4
(
bit_cast
<
int32x4_t
>
(
src_thread_data
)
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
8
)
else
if
constexpr
(
N
==
32
)
{
#if 0
vector_type<half_t, 8> tmp{src_thread_data};
vector_type
<
int32_t
,
8
>
tmp
{
bit_cast
<
int32x8_t
>
(
src_thread_data
)};
llvm_amdgcn_raw_buffer_store_
fp16
x4(tmp.AsType<
half
4_t>()[Number<0>{}],
llvm_amdgcn_raw_buffer_store_
i32
x4
(
tmp
.
template
AsType
<
int32x
4_t
>()[
Number
<
0
>
{}],
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
llvm_amdgcn_raw_buffer_store_
fp16
x4(tmp.AsType<
half
4_t>()[Number<1>{}],
llvm_amdgcn_raw_buffer_store_
i32
x4
(
tmp
.
template
AsType
<
int32x
4_t
>()[
Number
<
1
>
{}],
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset +
4 *
sizeof(
half_t)
,
dst_wave_addr_offset
+
sizeof
(
int32_t
)
*
4
,
static_cast
<
index_t
>
(
coherence
));
#else
llvm_amdgcn_raw_buffer_store_fp32x4
(
bit_cast
<
float4_t
>
(
src_thread_data
),
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
#endif
}
}
else
if
constexpr
(
is_same
<
T
,
bhalf_t
>::
value
)
{
if
constexpr
(
N
==
1
)
{
llvm_amdgcn_raw_buffer_store_i16
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
2
)
else
if
constexpr
(
N
==
64
)
{
llvm_amdgcn_raw_buffer_store_i16x2
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
4
)
{
llvm_amdgcn_raw_buffer_store_i16x4
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
8
)
{
vector_type
<
bhalf_t
,
8
>
tmp
{
src_thread_data
};
vector_type
<
int32_t
,
16
>
tmp
{
bit_cast
<
int32x16_t
>
(
src_thread_data
)};
llvm_amdgcn_raw_buffer_store_i
16
x4
(
tmp
.
AsType
<
bhalf
4_t
>
()[
Number
<
0
>
{}],
llvm_amdgcn_raw_buffer_store_i
32
x4
(
tmp
.
template
AsType
<
int32x
4_t
>()[
Number
<
0
>
{}],
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
llvm_amdgcn_raw_buffer_store_i
16
x4
(
tmp
.
AsType
<
bhalf
4_t
>
()[
Number
<
1
>
{}],
llvm_amdgcn_raw_buffer_store_i
32
x4
(
tmp
.
template
AsType
<
int32x
4_t
>()[
Number
<
1
>
{}],
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
+
4
*
sizeof
(
bhalf_t
)
,
dst_wave_addr_offset
+
sizeof
(
int32_t
)
*
4
,
static_cast
<
index_t
>
(
coherence
));
}
}
else
if
constexpr
(
is_same
<
T
,
int32_t
>::
value
)
{
if
constexpr
(
N
==
1
)
{
llvm_amdgcn_raw_buffer_store_i32
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
2
)
{
llvm_amdgcn_raw_buffer_store_i32x2
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
4
)
{
llvm_amdgcn_raw_buffer_store_i32x4
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
}
else
if
constexpr
(
is_same
<
T
,
int8_t
>::
value
)
{
if
constexpr
(
N
==
1
)
{
llvm_amdgcn_raw_buffer_store_i8
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
2
)
{
#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
llvm_amdgcn_raw_buffer_store_i8x2
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
#else
llvm_amdgcn_raw_buffer_store_i16
(
bit_cast
<
int16_t
>
(
src_thread_data
),
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
#endif
}
else
if
constexpr
(
N
==
4
)
{
#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
llvm_amdgcn_raw_buffer_store_i8x4
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
#else
llvm_amdgcn_raw_buffer_store_i32
(
bit_cast
<
int32_t
>
(
src_thread_data
),
llvm_amdgcn_raw_buffer_store_i32x4
(
tmp
.
template
AsType
<
int32x4_t
>()[
Number
<
2
>
{}],
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
dst_wave_addr_offset
+
sizeof
(
int32_t
)
*
8
,
static_cast
<
index_t
>
(
coherence
));
#endif
}
else
if
constexpr
(
N
==
8
)
{
llvm_amdgcn_raw_buffer_store_i32x2
(
bit_cast
<
int32x2_t
>
(
src_thread_data
),
llvm_amdgcn_raw_buffer_store_i32x4
(
tmp
.
template
AsType
<
int32x4_t
>()[
Number
<
3
>
{}],
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
dst_wave_addr_offset
+
sizeof
(
int32_t
)
*
12
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
16
)
{
llvm_amdgcn_raw_buffer_store_i32x4
(
bit_cast
<
int32x4_t
>
(
src_thread_data
),
}
template
<
typename
T
,
index_t
N
,
AmdBufferCoherenceEnum
coherence
=
AmdBufferCoherenceEnum
::
DefaultCoherence
>
__device__
void
amd_buffer_store_impl
(
const
typename
vector_type
<
T
,
N
>::
type
src_thread_data
,
int32x4_t
dst_wave_buffer_resource
,
index_t
dst_thread_addr_offset
,
index_t
dst_wave_addr_offset
)
{
static_assert
(
(
is_same
<
T
,
double
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
))
||
(
is_same
<
T
,
float
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
is_same
<
T
,
half_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
is_same
<
T
,
bhalf_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
is_same
<
T
,
int32_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
is_same
<
T
,
f8_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
is_same
<
T
,
bf8_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
is_same
<
T
,
int8_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
)),
"wrong! not implemented"
);
using
r_t
=
typename
vector_type
<
int8_t
,
sizeof
(
T
)
*
N
>::
type
;
amd_buffer_store_impl_raw
<
sizeof
(
T
)
*
N
,
coherence
>
(
bit_cast
<
r_t
>
(
src_thread_data
),
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
}
dst_wave_addr_offset
);
}
template
<
typename
T
,
index_t
N
>
...
...
@@ -1127,31 +798,14 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
uint32_t
src_addr_shift
=
src_thread_element_valid
?
0
:
0x80000000
;
if
constexpr
(
is_same
<
scalar_t
,
f8_t
>::
value
)
{
auto
tmp
=
amd_buffer_load_impl
<
int8_t
,
vector_size
,
coherence
>
(
src_wave_buffer_resource
,
src_addr_shift
+
src_thread_addr_offset
,
0
);
return
bit_cast
<
vector_t
>
(
tmp
);
}
else
{
return
amd_buffer_load_impl
<
scalar_t
,
vector_size
,
coherence
>
(
src_wave_buffer_resource
,
src_addr_shift
+
src_thread_addr_offset
,
0
);
}
#else
if
constexpr
(
is_same
<
scalar_t
,
f8_t
>::
value
)
{
auto
tmp
=
amd_buffer_load_impl
<
int8_t
,
vector_size
,
coherence
>
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
0
);
return
src_thread_element_valid
?
bit_cast
<
vector_t
>
(
tmp
)
:
vector_t
(
0
);
}
else
{
vector_t
tmp
=
amd_buffer_load_impl
<
scalar_t
,
vector_size
,
coherence
>
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
0
);
return
src_thread_element_valid
?
tmp
:
vector_t
(
0
);
}
#endif
}
...
...
@@ -1209,35 +863,14 @@ __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::t
#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
uint32_t
dst_addr_shift
=
dst_thread_element_valid
?
0
:
0x80000000
;
if
constexpr
(
is_same
<
scalar_t
,
f8_t
>::
value
)
{
auto
tmp
=
bit_cast
<
typename
vector_type_maker
<
int8_t
,
vector_size
>::
type
::
type
>
(
src_thread_data
);
amd_buffer_store_impl
<
int8_t
,
vector_size
,
coherence
>
(
tmp
,
dst_wave_buffer_resource
,
dst_addr_shift
+
dst_thread_addr_offset
,
0
);
}
else
{
amd_buffer_store_impl
<
scalar_t
,
vector_size
,
coherence
>
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_addr_shift
+
dst_thread_addr_offset
,
0
);
}
#else
if
(
dst_thread_element_valid
)
{
if
constexpr
(
is_same
<
scalar_t
,
f8_t
>::
value
)
{
auto
tmp
=
bit_cast
<
typename
vector_type_maker
<
int8_t
,
vector_size
>::
type
::
type
>
(
src_thread_data
);
amd_buffer_store_impl
<
int8_t
,
vector_size
,
coherence
>
(
tmp
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
0
);
}
else
{
amd_buffer_store_impl
<
scalar_t
,
vector_size
,
coherence
>
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
0
);
}
}
#endif
}
...
...
include/ck/utility/amd_gemm_dpp.hpp
View file @
3c4fb1dd
...
...
@@ -5,17 +5,63 @@
#include "ck/utility/common_header.hpp"
#include "ck/utility/math.hpp"
#include "ck/utility/
amd_gemm
_dpp.hpp"
#include "ck/utility/
inner_product
_dpp
8
.hpp"
namespace
ck
{
namespace
dpp8
{
/// Number of lanes that can share data using DPP8 modifiers.
constexpr
index_t
lane_group_size
=
8
;
template
<
class
ABDataType
>
struct
dpp_datatypes
;
/// Returns `threadIdx.x / lane_group_size`, i.e. the index of the lane group that contains
/// the calling thread.
///
/// Declared `inline` because the definition lives in a header: without it every translation
/// unit that includes the header emits its own external definition of the function.
__device__ inline index_t get_lane_group_local_idx()
{
    return threadIdx.x / lane_group_size;
}
/// Returns `threadIdx.x % lane_group_size`, i.e. the position of the calling thread inside
/// its lane group.
///
/// Declared `inline` because the definition lives in a header: without it every translation
/// unit that includes the header emits its own external definition of the function.
__device__ inline index_t get_thread_idx_in_lane_group()
{
    return threadIdx.x % lane_group_size;
}
/// DPP data-type configuration for `half_t` inputs.
///
/// Each instruction computes a dot product of two `half2_t` operands and accumulates into a
/// `float`, so two elements along K are reduced per instruction.
template <>
struct dpp_datatypes<half_t>
{
    /// Number of K elements consumed by a single instruction.
    static constexpr index_t k_per_instr = 2;

    using a_dtype = half_t; ///< scalar type of the A operand
    using b_dtype = half_t; ///< scalar type of the B operand
    using c_dtype = float;  ///< scalar type of the accumulator
};
/// Lane-group GEMM step built on `ck::dpp8::inner_product_dpp`.
///
/// The scalar A/B/C types are taken from `dpp_datatypes<BaseInputType>`. Each thread updates
/// `MPerThread` accumulators when `ShareA` is true and `NPerThread` accumulators otherwise.
template <index_t MPerThread,
          index_t NPerThread,
          index_t KPerThread,
          class BaseInputType,
          class AVecDataType,
          class BVecDataType,
          class CVecDataType,
          bool ShareA>
struct DppLanegroupGemm
{
    using datatypes_conf = dpp_datatypes<BaseInputType>;
    using ADataType      = typename datatypes_conf::a_dtype;
    using BDataType      = typename datatypes_conf::b_dtype;
    using CDataType      = typename datatypes_conf::c_dtype;

    /// Accumulates the products of `a_vec` and `b_vec` into `c_vec`.
    ///
    /// @param a_vec  A operand of this thread (KPerThread scalars of ADataType).
    /// @param b_vec  B operand of this thread (KPerThread scalars of BDataType).
    /// @param c_vec  Accumulators; read, updated and written back element by element.
    __device__ void Run(const AVecDataType& a_vec, const BVecDataType& b_vec, CVecDataType& c_vec)
    {
        constexpr index_t num_c_elems_per_thread = ShareA ? MPerThread : NPerThread;

        const vector_type<ADataType, KPerThread> a_vector{a_vec};
        const vector_type<BDataType, KPerThread> b_vector{b_vec};

        static_for<0, num_c_elems_per_thread, 1>{}([&](auto c_idx) {
            // Use the configured accumulator type instead of a hard-coded `float`: the value
            // is loaded from, handed to inner_product_dpp as, and stored back as CDataType.
            CDataType c = c_vec.template AsType<CDataType>()(c_idx);

            // Next `c_idx` implies that we need to pull data from the next lane.
            constexpr index_t source_lane = c_idx;

            // Walk K in chunks of k_per_instr elements, one inner-product call per chunk.
            static_for<0, KPerThread / datatypes_conf::k_per_instr, 1>{}([&](auto k_chunk) {
                const auto a_k_vec = a_vector.template AsType<AVecDataType>()[k_chunk];
                const auto b_k_vec = b_vector.template AsType<BVecDataType>()[k_chunk];
                ck::dpp8::
                    inner_product_dpp<AVecDataType, BVecDataType, CDataType, source_lane, ShareA>(
                        a_k_vec, b_k_vec, c);
            });

            c_vec.template AsType<CDataType>()(c_idx) = c;
        });
    }
};
}
// namespace dpp8
...
...
include/ck/utility/amd_xdlops.hpp
View file @
3c4fb1dd
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_AMD_XDLOPS_HPP
#define CK_AMD_XDLOPS_HPP
#include "data_type.hpp"
#pragma once
namespace
ck
{
...
...
@@ -417,5 +414,194 @@ struct intrin_mfma_f32_16x16x32f8f8<16, 16>
#endif
}
};
}
// namespace ck
/// Primary template; only the <32, 32> wave tile is specialized below.
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x16bf8bf8;

/// bf8 x bf8 MFMA wrapper for a 32x32 wave tile.
template <>
struct intrin_mfma_f32_32x32x16bf8bf8<32, 32>
{
    /// @param reg_a  eight packed bf8 values of operand A
    /// @param reg_b  eight packed bf8 values of operand B
    /// @param reg_c  accumulator; its `float16_t` slot 0 is read and overwritten on the
    ///               native path, and it is forwarded to the f32 wrapper on the fallback path
    template <class FloatC>
    __device__ static void Run(const bf8x8_t& reg_a, const bf8x8_t& reg_b, FloatC& reg_c)
    {
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
        // Native path: both operands are reinterpreted as one 64-bit word each.
        constexpr auto acc_slot = Number<0>{};
        const auto packed_a     = bit_cast<long>(reg_a);
        const auto packed_b     = bit_cast<long>(reg_b);

        reg_c.template AsType<float16_t>()(acc_slot) = __builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8(
            packed_a, packed_b, reg_c.template AsType<float16_t>()[acc_slot], 0, 0, 0);
#else
        // Fallback path: convert each of the eight elements to float and feed them, one pair
        // at a time, to the f32 MFMA wrapper.
        vector_type<bf8_t, 8> a_elems(reg_a);
        vector_type<bf8_t, 8> b_elems(reg_b);

        static_for<0, 8, 1>{}([&](auto elem) {
            float a_f32 = type_convert<float>(a_elems.template AsType<bf8_t>()[Number<elem>{}]);
            float b_f32 = type_convert<float>(b_elems.template AsType<bf8_t>()[Number<elem>{}]);

            intrin_mfma_f32_32x32x2f32<32, 32>::Run(a_f32, b_f32, reg_c);
        });
#endif
    }
};
/// Primary template; only the <16, 16> wave tile is specialized below.
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_16x16x32bf8bf8;

/// bf8 x bf8 MFMA wrapper for a 16x16 wave tile.
template <>
struct intrin_mfma_f32_16x16x32bf8bf8<16, 16>
{
    /// @param reg_a  eight packed bf8 values of operand A
    /// @param reg_b  eight packed bf8 values of operand B
    /// @param reg_c  accumulator; its `float4_t` slot 0 is read and overwritten on the
    ///               native path, and it is forwarded to the f32 wrapper on the fallback path
    template <class FloatC>
    __device__ static void Run(const bf8x8_t& reg_a, const bf8x8_t& reg_b, FloatC& reg_c)
    {
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
        // Native path: both operands are reinterpreted as one 64-bit word each.
        constexpr auto acc_slot = Number<0>{};
        const auto packed_a     = bit_cast<long>(reg_a);
        const auto packed_b     = bit_cast<long>(reg_b);

        reg_c.template AsType<float4_t>()(acc_slot) = __builtin_amdgcn_mfma_f32_16x16x32_bf8_bf8(
            packed_a, packed_b, reg_c.template AsType<float4_t>()[acc_slot], 0, 0, 0);
#else
        // Fallback path: convert each of the eight elements to float and feed them, one pair
        // at a time, to the f32 MFMA wrapper.
        vector_type<bf8_t, 8> a_elems(reg_a);
        vector_type<bf8_t, 8> b_elems(reg_b);

        static_for<0, 8, 1>{}([&](auto elem) {
            float a_f32 = type_convert<float>(a_elems.template AsType<bf8_t>()[Number<elem>{}]);
            float b_f32 = type_convert<float>(b_elems.template AsType<bf8_t>()[Number<elem>{}]);

            intrin_mfma_f32_16x16x4f32<16, 16>::Run(a_f32, b_f32, reg_c);
        });
#endif
    }
};
/// Primary template; only the <32, 32> wave tile is specialized below.
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x16f8bf8;

/// f8 x bf8 MFMA wrapper for a 32x32 wave tile.
template <>
struct intrin_mfma_f32_32x32x16f8bf8<32, 32>
{
    /// @param reg_a  eight packed f8 values of operand A
    /// @param reg_b  eight packed bf8 values of operand B
    /// @param reg_c  accumulator; its `float16_t` slot 0 is read and overwritten on the
    ///               native path, and it is forwarded to the f32 wrapper on the fallback path
    template <class FloatC>
    __device__ static void Run(const f8x8_t& reg_a, const bf8x8_t& reg_b, FloatC& reg_c)
    {
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
        // Native path: both operands are reinterpreted as one 64-bit word each.
        constexpr auto acc_slot = Number<0>{};
        const auto packed_a     = bit_cast<long>(reg_a);
        const auto packed_b     = bit_cast<long>(reg_b);

        reg_c.template AsType<float16_t>()(acc_slot) = __builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8(
            packed_a, packed_b, reg_c.template AsType<float16_t>()[acc_slot], 0, 0, 0);
#else
        // Fallback path: convert each of the eight elements to float and feed them, one pair
        // at a time, to the f32 MFMA wrapper.
        vector_type<f8_t, 8> a_elems(reg_a);
        vector_type<bf8_t, 8> b_elems(reg_b);

        static_for<0, 8, 1>{}([&](auto elem) {
            float a_f32 = type_convert<float>(a_elems.template AsType<f8_t>()[Number<elem>{}]);
            float b_f32 = type_convert<float>(b_elems.template AsType<bf8_t>()[Number<elem>{}]);

            intrin_mfma_f32_32x32x2f32<32, 32>::Run(a_f32, b_f32, reg_c);
        });
#endif
    }
};
/// Primary template; only the <16, 16> wave tile is specialized below.
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_16x16x32f8bf8;

/// f8 x bf8 MFMA wrapper for a 16x16 wave tile.
template <>
struct intrin_mfma_f32_16x16x32f8bf8<16, 16>
{
    /// @param reg_a  eight packed f8 values of operand A
    /// @param reg_b  eight packed bf8 values of operand B
    /// @param reg_c  accumulator; its `float4_t` slot 0 is read and overwritten on the
    ///               native path, and it is forwarded to the f32 wrapper on the fallback path
    template <class FloatC>
    __device__ static void Run(const f8x8_t& reg_a, const bf8x8_t& reg_b, FloatC& reg_c)
    {
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
        // Native path: both operands are reinterpreted as one 64-bit word each.
        constexpr auto acc_slot = Number<0>{};
        const auto packed_a     = bit_cast<long>(reg_a);
        const auto packed_b     = bit_cast<long>(reg_b);

        reg_c.template AsType<float4_t>()(acc_slot) = __builtin_amdgcn_mfma_f32_16x16x32_fp8_bf8(
            packed_a, packed_b, reg_c.template AsType<float4_t>()[acc_slot], 0, 0, 0);
#else
        // Fallback path: convert each of the eight elements to float and feed them, one pair
        // at a time, to the f32 MFMA wrapper.
        vector_type<f8_t, 8> a_elems(reg_a);
        vector_type<bf8_t, 8> b_elems(reg_b);

        static_for<0, 8, 1>{}([&](auto elem) {
            float a_f32 = type_convert<float>(a_elems.template AsType<f8_t>()[Number<elem>{}]);
            float b_f32 = type_convert<float>(b_elems.template AsType<bf8_t>()[Number<elem>{}]);

            intrin_mfma_f32_16x16x4f32<16, 16>::Run(a_f32, b_f32, reg_c);
        });
#endif
    }
};
/// Primary template; only the <32, 32> wave tile is specialized below.
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x16bf8f8;

/// bf8 x f8 MFMA wrapper for a 32x32 wave tile.
template <>
struct intrin_mfma_f32_32x32x16bf8f8<32, 32>
{
    /// @param reg_a  eight packed bf8 values of operand A
    /// @param reg_b  eight packed f8 values of operand B
    /// @param reg_c  accumulator; its `float16_t` slot 0 is read and overwritten on the
    ///               native path, and it is forwarded to the f32 wrapper on the fallback path
    template <class FloatC>
    __device__ static void Run(const bf8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c)
    {
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
        // Native path: both operands are reinterpreted as one 64-bit word each.
        constexpr auto acc_slot = Number<0>{};
        const auto packed_a     = bit_cast<long>(reg_a);
        const auto packed_b     = bit_cast<long>(reg_b);

        reg_c.template AsType<float16_t>()(acc_slot) = __builtin_amdgcn_mfma_f32_32x32x16_bf8_fp8(
            packed_a, packed_b, reg_c.template AsType<float16_t>()[acc_slot], 0, 0, 0);
#else
        // Fallback path: convert each of the eight elements to float and feed them, one pair
        // at a time, to the f32 MFMA wrapper.
        vector_type<bf8_t, 8> a_elems(reg_a);
        vector_type<f8_t, 8> b_elems(reg_b);

        static_for<0, 8, 1>{}([&](auto elem) {
            float a_f32 = type_convert<float>(a_elems.template AsType<bf8_t>()[Number<elem>{}]);
            float b_f32 = type_convert<float>(b_elems.template AsType<f8_t>()[Number<elem>{}]);

            intrin_mfma_f32_32x32x2f32<32, 32>::Run(a_f32, b_f32, reg_c);
        });
#endif
    }
};
/// Primary template; only the <16, 16> wave tile is specialized below.
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_16x16x32bf8f8;

/// bf8 x f8 MFMA wrapper for a 16x16 wave tile.
template <>
struct intrin_mfma_f32_16x16x32bf8f8<16, 16>
{
    /// @param reg_a  eight packed bf8 values of operand A
    /// @param reg_b  eight packed f8 values of operand B
    /// @param reg_c  accumulator; its `float4_t` slot 0 is read and overwritten on the
    ///               native path, and it is forwarded to the f32 wrapper on the fallback path
    template <class FloatC>
    __device__ static void Run(const bf8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c)
    {
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
        // Native path: both operands are reinterpreted as one 64-bit word each.
        constexpr auto acc_slot = Number<0>{};
        const auto packed_a     = bit_cast<long>(reg_a);
        const auto packed_b     = bit_cast<long>(reg_b);

        reg_c.template AsType<float4_t>()(acc_slot) = __builtin_amdgcn_mfma_f32_16x16x32_bf8_fp8(
            packed_a, packed_b, reg_c.template AsType<float4_t>()[acc_slot], 0, 0, 0);
#else
        // Fallback path: convert each of the eight elements to float and feed them, one pair
        // at a time, to the f32 MFMA wrapper.
        vector_type<bf8_t, 8> a_elems(reg_a);
        vector_type<f8_t, 8> b_elems(reg_b);

        static_for<0, 8, 1>{}([&](auto elem) {
            float a_f32 = type_convert<float>(a_elems.template AsType<bf8_t>()[Number<elem>{}]);
            float b_f32 = type_convert<float>(b_elems.template AsType<f8_t>()[Number<elem>{}]);

            intrin_mfma_f32_16x16x4f32<16, 16>::Run(a_f32, b_f32, reg_c);
        });
#endif
    }
};
}
// namespace ck
include/ck/utility/data_type.hpp
View file @
3c4fb1dd
...
...
@@ -24,10 +24,9 @@ using std::byte;
using
bhalf_t
=
ushort
;
using
half_t
=
_Float16
;
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
using
int4_t
=
_BitInt
(
4
);
#endif
using
f8_t
=
u
int8_t
;
using
f8_t
=
_BitInt
(
8
);
using
b
f8_t
=
u
nsigned
_BitInt
(
8
)
;
// vector_type
template
<
typename
T
,
index_t
N
>
...
...
@@ -165,7 +164,13 @@ struct scalar_type<f8_t>
static
constexpr
index_t
vector_size
=
1
;
};
//
/// `scalar_type` specialization for `bf8_t`: it is its own scalar type and counts as a
/// single element.
template <>
struct scalar_type<bf8_t>
{
    static constexpr index_t vector_size = 1;

    using type = bf8_t;
};
template
<
typename
T
>
struct
vector_type
<
T
,
1
>
{
...
...
@@ -975,6 +980,14 @@ using f8x16_t = typename vector_type<f8_t, 16>::type;
using
f8x32_t
=
typename
vector_type
<
f8_t
,
32
>::
type
;
using
f8x64_t
=
typename
vector_type
<
f8_t
,
64
>::
type
;
// bf8: vector aliases of bf8_t with 2, 4, 8, 16, 32 and 64 elements, each resolved through
// vector_type<bf8_t, N>::type.
using
bf8x2_t
=
typename
vector_type
<
bf8_t
,
2
>::
type
;
using
bf8x4_t
=
typename
vector_type
<
bf8_t
,
4
>::
type
;
using
bf8x8_t
=
typename
vector_type
<
bf8_t
,
8
>::
type
;
using
bf8x16_t
=
typename
vector_type
<
bf8_t
,
16
>::
type
;
using
bf8x32_t
=
typename
vector_type
<
bf8_t
,
32
>::
type
;
using
bf8x64_t
=
typename
vector_type
<
bf8_t
,
64
>::
type
;
template
<
typename
T
>
struct
NumericLimits
;
...
...
@@ -1100,18 +1113,103 @@ struct NumericLimits<int4_t>
template
<
>
struct
NumericLimits
<
f8_t
>
{
// negative zero nan mode with exp bias = 8
static
constexpr
uint8_t
binary_min
=
0x08
;
// 0b00001000
static
constexpr
uint8_t
binary_max
=
0x7
7
;
// 0b0111
0
111
static
constexpr
uint8_t
binary_lowest
=
0xF
7
;
// 0b1111
0
111
static
constexpr
uint8_t
binary_max
=
0x7
F
;
// 0b0111
1
111
static
constexpr
uint8_t
binary_lowest
=
0xF
F
;
// 0b1111
1
111
static
constexpr
uint8_t
binary_qnan
=
0x80
;
// 0b10000000
// ieee mode with exp bias = 7
// static constexpr uint8_t binary_min = 0x08; // 0b00001000
// static constexpr uint8_t binary_max = 0x77; // 0b01110111
// static constexpr uint8_t binary_lowest = 0xF7; // 0b11110111
// static constexpr uint8_t binary_qnan = 0x79; // any sign, exp=1111, mant!=0
__host__
__device__
static
constexpr
f8_t
Min
()
{
return
bit_cast
<
f8_t
>
(
binary_min
);
}
__host__
__device__
static
constexpr
f8_t
Min
()
{
return
f8_t
(
binary_min
);
}
__host__
__device__
static
constexpr
f8_t
Max
()
{
return
bit_cast
<
f8_t
>
(
binary_max
);
}
__host__
__device__
static
constexpr
f8_t
Max
()
{
return
f8_t
(
binary_max
);
}
__host__
__device__
static
constexpr
f8_t
Lowest
()
{
return
bit_cast
<
f8_t
>
(
binary_lowest
);
}
__host__
__device__
static
constexpr
f8_t
Lowest
()
{
return
f8_t
(
binary_lowest
);
}
__host__
__device__
static
constexpr
f8_t
QuietNaN
()
{
return
bit_cast
<
f8_t
>
(
binary_qnan
);
}
__host__
__device__
static
constexpr
f8_t
QuietNaN
()
{
return
f8_t
(
binary_qnan
);
}
};
/// Limit and special-value bit patterns for `bf8_t`.
///
/// The active constants are the "negative zero nan" encoding with exponent bias 16. Every
/// accessor converts the matching `uint8_t` bit pattern to `bf8_t`.
template <>
struct NumericLimits<bf8_t>
{
    // negative zero nan mode with exp bias = 16
    static constexpr uint8_t binary_min    = 0x04; // 0b00000100
    static constexpr uint8_t binary_max    = 0x7F; // 0b01111111
    static constexpr uint8_t binary_lowest = 0xFF; // 0b11111111
    static constexpr uint8_t binary_qnan   = 0x80; // 0b10000000

    // ieee mode with exp bias = 15 (kept for reference, not active)
    // static constexpr uint8_t binary_min    = 0x04; // 0b00000100
    // static constexpr uint8_t binary_max    = 0x7B; // 0b01111011
    // static constexpr uint8_t binary_lowest = 0xFB; // 0b11111011
    // static constexpr uint8_t binary_qnan   = 0x79; // any sign, exp=1111, mant!=0
    // NOTE(review): the qnan line above looks copied from the f8_t table; with a 5-bit
    // exponent 0x79 does not have an all-ones exponent - confirm before enabling.

    /// Bit pattern `binary_min` as a `bf8_t`.
    __host__ __device__ static constexpr bf8_t Min() { return static_cast<bf8_t>(binary_min); }

    /// Bit pattern `binary_max` as a `bf8_t`.
    __host__ __device__ static constexpr bf8_t Max() { return static_cast<bf8_t>(binary_max); }

    /// Bit pattern `binary_lowest` as a `bf8_t`.
    __host__ __device__ static constexpr bf8_t Lowest()
    {
        return static_cast<bf8_t>(binary_lowest);
    }

    /// Bit pattern `binary_qnan` as a `bf8_t`.
    __host__ __device__ static constexpr bf8_t QuietNaN()
    {
        return static_cast<bf8_t>(binary_qnan);
    }
};
/// Bit-layout traits of a floating-point format. The primary template is intentionally
/// empty; the layout constants are supplied by the per-type specializations.
template <typename T>
struct NumericUtils
{
};

/// Layout constants of 32-bit `float`.
template <>
struct NumericUtils<float>
{
    /// Unsigned integer type that holds the raw bits of a value.
    using bitwise_type = uint32_t;

    // Field widths and exponent bias.
    static constexpr int exp  = 8;
    static constexpr int mant = 23;
    static constexpr int bias = 127;

    // Masks for picking fields out of the raw bits.
    static constexpr bitwise_type nan_mask  = 0x7F800000;
    static constexpr bitwise_type head_mask = 0xFF800000; // sign + exponent bits
    static constexpr bitwise_type mant_mask = 0x7FFFFF;
    static constexpr bitwise_type exp_mask  = 0xFF; // applied after shifting by `mant`

    // Raw bit patterns of the special values.
    static constexpr bitwise_type Inf    = 0x7F800000;
    static constexpr bitwise_type NegInf = 0xFF800000;
    static constexpr bitwise_type NaN    = 0x7F800001;
    static constexpr bitwise_type Neg0   = 0x80000000;
};
/// Layout constants of 16-bit `half_t`.
///
/// All raw-bit constants use the 16-bit `bitwise_type`. The special-value patterns were
/// previously declared as `uint32_t` although every value fits in 16 bits and the masks and
/// `bitwise_type` are 16-bit; the numeric values are unchanged.
template <>
struct NumericUtils<half_t>
{
    /// Unsigned integer type that holds the raw bits of a value.
    using bitwise_type = uint16_t;

    // Field widths and exponent bias.
    static constexpr int exp  = 5;
    static constexpr int mant = 10;
    static constexpr int bias = 15;

    // Masks for picking fields out of the raw bits.
    static constexpr uint16_t nan_mask  = 0x7C00;
    static constexpr uint16_t head_mask = 0xFC00; // sign + exponent bits
    static constexpr uint16_t mant_mask = 0x3FF;
    static constexpr uint16_t exp_mask  = 0x1F; // applied after shifting by `mant`

    // Raw bit patterns of the special values.
    static constexpr uint16_t Inf    = 0x7C00;
    static constexpr uint16_t NegInf = 0xFC00;
    static constexpr uint16_t NaN    = 0x7C01;
    static constexpr uint16_t Neg0   = 0x8000;
};
/// Layout constants of `f8_t`: 4 exponent bits and 3 mantissa bits.
template <>
struct NumericUtils<f8_t>
{
    static constexpr int exp  = 4;
    static constexpr int mant = 3;

    // Exponent bias of the active "negative zero nan" mode.
    static constexpr int bias = 8; // negative zero nan mode
    // static constexpr int bias = 7; // ieee mode
};
/// Layout constants of `bf8_t`: 5 exponent bits and 2 mantissa bits.
template <>
struct NumericUtils<bf8_t>
{
    static constexpr int exp  = 5;
    static constexpr int mant = 2;

    // Exponent bias of the active "negative zero nan" mode.
    static constexpr int bias = 16; // negative zero nan mode
    // static constexpr int bias = 15; // ieee mode
};
}
// namespace ck
include/ck/utility/dynamic_buffer.hpp
View file @
3c4fb1dd
...
...
@@ -141,9 +141,35 @@ struct DynamicBuffer
else
if
constexpr
(
Op
==
InMemoryDataOperationEnum
::
Add
)
{
auto
tmp
=
this
->
template
Get
<
X
>(
i
,
is_valid_element
);
using
scalar_t
=
typename
scalar_type
<
remove_cvref_t
<
T
>>::
type
;
// handle bfloat addition
if
constexpr
(
is_same_v
<
scalar_t
,
bhalf_t
>
)
{
if
constexpr
(
is_scalar_type
<
X
>::
value
)
{
// Scalar type
auto
result
=
type_convert
<
X
>
(
type_convert
<
float
>
(
x
)
+
type_convert
<
float
>
(
tmp
));
this
->
template
Set
<
X
>(
i
,
is_valid_element
,
result
);
}
else
{
// Vector type
constexpr
auto
vector_size
=
scalar_type
<
remove_cvref_t
<
X
>>::
vector_size
;
const
vector_type
<
scalar_t
,
vector_size
>
a_vector
{
tmp
};
const
vector_type
<
scalar_t
,
vector_size
>
b_vector
{
x
};
static_for
<
0
,
vector_size
,
1
>
{}([
&
](
auto
idx
)
{
auto
result
=
type_convert
<
scalar_t
>
(
type_convert
<
float
>
(
a_vector
.
template
AsType
<
scalar_t
>()[
idx
])
+
type_convert
<
float
>
(
b_vector
.
template
AsType
<
scalar_t
>()[
idx
]));
this
->
template
Set
<
scalar_t
>(
i
+
idx
,
is_valid_element
,
result
);
});
}
}
else
{
this
->
template
Set
<
X
>(
i
,
is_valid_element
,
x
+
tmp
);
// tmp += x;
// this->template Set<X>(i, is_valid_element, tmp);
}
}
}
...
...
include/ck/utility/f8_utils.hpp
View file @
3c4fb1dd
...
...
@@ -16,59 +16,46 @@ enum class f8_rounding_mode
stochastic
};
/// Host overload: forwards to the `__builtin_clz` compiler builtin.
/// NOTE(review): `__builtin_clz` is documented as undefined for x == 0 - callers are
/// presumably expected to pass a non-zero value; confirm.
__host__ inline int clz(uint32_t x)
{
    const int leading_zeros = __builtin_clz(x);
    return leading_zeros;
}

/// Device overload: forwards to the `__clz` device intrinsic.
__device__ inline int clz(uint32_t x)
{
    const int leading_zeros = __clz(x);
    return leading_zeros;
}
}
// namespace ck
namespace
ck
::
utils
{
namespace
{
template
<
typename
T
,
bool
negative_zero_nan
,
bool
clip
,
bool
stoch
>
__host__
__device__
f8_t
run_cast_to_f8
(
T
x
,
uint32_t
rng
)
template
<
typename
X
,
typename
Y
,
bool
negative_zero_nan
,
bool
clip
,
bool
stoch
>
__host__
__device__
Y
run_cast_to_f8
(
X
x
,
uint32_t
rng
)
{
// check data type
constexpr
bool
is_half
=
std
::
is_same
<
T
,
half_t
>::
value
;
constexpr
bool
is_float
=
std
::
is_same
<
T
,
float
>::
value
;
// fp8 exponent/mantissa layout
constexpr
int
f8_exp
=
4
;
constexpr
int
f8_mant
=
3
;
// fp8/bf8 exponent/mantissa layout
constexpr
int
out_exp
=
NumericUtils
<
Y
>::
exp
;
constexpr
int
out_mant
=
NumericUtils
<
Y
>::
mant
;
//
resulting
type exponent/mantissa layout
constexpr
int
type
_exp
=
is_half
?
5
:
8
;
constexpr
int
type
_mant
=
is_half
?
10
:
23
;
//
original
type exponent/mantissa layout
constexpr
int
in
_exp
=
NumericUtils
<
X
>::
exp
;
constexpr
int
in
_mant
=
NumericUtils
<
X
>::
mant
;
int
exponent
;
int
exponent
,
bias
;
uint32_t
head
,
mantissa
,
sign
;
// nan code is same for float and half
constexpr
uint8_t
nan_code
=
0x80
;
constexpr
uint32_t
nan_mask
=
is_half
?
0x7C00
:
0x7F800000
;
constexpr
Y
nan_code
=
0x80
;
constexpr
uint32_t
nan_mask
=
NumericUtils
<
X
>::
nan_mask
;
// convert to bitwise
typedef
typename
ck
::
conditional
<
std
::
is_same
<
T
,
half_t
>::
value
,
uint16_t
,
uint32_t
>::
type
T_bitwise
;
using
T_bitwise
=
typename
NumericUtils
<
X
>::
bitwise_type
;
T_bitwise
x_bitwise
=
*
(
reinterpret_cast
<
T_bitwise
*>
(
&
x
));
// unpack the input, depends on datatype
if
constexpr
(
is_float
)
{
head
=
x_bitwise
&
0xFF800000
;
mantissa
=
x_bitwise
&
0x7FFFFF
;
exponent
=
(
head
>>
type_mant
)
&
0xFF
;
sign
=
head
>>
(
type_exp
+
type_mant
);
}
else
if
constexpr
(
is_half
)
{
head
=
x_bitwise
&
0xFC00
;
mantissa
=
x_bitwise
&
0x3FF
;
exponent
=
(
head
>>
type_mant
)
&
0x1F
;
sign
=
head
>>
(
type_exp
+
type_mant
);
}
head
=
x_bitwise
&
NumericUtils
<
X
>::
head_mask
;
mantissa
=
x_bitwise
&
NumericUtils
<
X
>::
mant_mask
;
exponent
=
(
head
>>
in_mant
)
&
NumericUtils
<
X
>::
exp_mask
;
sign
=
head
>>
(
in_exp
+
in_mant
);
bias
=
NumericUtils
<
X
>::
bias
;
uint32_t
signed_inf
=
(
sign
<<
(
type_exp
+
type_mant
))
+
(((
1
<<
type_exp
)
-
1
)
<<
type_mant
);
uint32_t
drop_mask
=
(
1
<<
(
type_mant
-
f8_mant
))
-
1
;
constexpr
int
max_exp
=
(
1
<<
f8_exp
)
-
(
negative_zero_nan
?
1
:
2
);
constexpr
int
exp_low_cutoff
=
(
1
<<
(
type_exp
-
1
))
-
(
1
<<
(
f8_exp
-
1
))
+
1
-
(
negative_zero_nan
?
1
:
0
);
uint32_t
signed_inf
=
(
sign
<<
(
in_exp
+
in_mant
))
+
(((
1
<<
in_exp
)
-
1
)
<<
in_mant
);
uint32_t
drop_mask
=
(
1
<<
(
in_mant
-
out_mant
))
-
1
;
constexpr
int
max_exp
=
(
1
<<
out_exp
)
-
(
negative_zero_nan
?
1
:
2
);
if
constexpr
(
negative_zero_nan
)
{
...
...
@@ -85,39 +72,103 @@ __host__ __device__ f8_t run_cast_to_f8(T x, uint32_t rng)
if
(
x_bitwise
==
0
)
return
0
;
exponent
-=
exp_low_cutoff
-
1
;
if
(
exponent
<=
0
)
drop_mask
=
(
1
<<
(
type_mant
-
f8_mant
+
1
-
exponent
))
-
1
;
mantissa
+=
1
<<
type_mant
;
// apply random number if needed
mantissa
+=
(
stoch
?
rng
:
mantissa
)
&
drop_mask
;
if
(
mantissa
>=
(
2
<<
type_mant
))
// First check whether the input is normal or denormal, because the two differ by the implicit 1.
// Then adjust the exponent to align with the F8 exponent and, at the same time, shift
// the mantissa. For stochastic rounding, add rng to the mantissa and truncate; for
// RNE there is no need to add rng. Finally, check whether rounding produced a carry and, if
// so, adjust the exponent and mantissa again.
// For IEEE bias mode, the bias is 2^(k-1)-1 where k is the width of exponent bits
const
int
out_bias
=
(
1
<<
(
out_exp
-
1
))
-
1
+
(
negative_zero_nan
?
1
:
0
);
const
int
out_denormal_act_exponent
=
1
-
out_bias
;
// actual exponent of f8 denormal
// act_exponent is the actual exponent of fp32/fp16 (after subtracting bias)
// out_exponent is the converted f8 exponent with bias encoding
// exponent_diff is the diff between fp32/fp16 exponent and f8 exponent,
// the difference needs to be adjusted and mantissa shifted
int
act_exponent
,
out_exponent
,
exponent_diff
;
if
(
exponent
==
0
)
{
// fp32/fp16 is in denormal.
/* fp32 denormals are below 2^-127, so they are usually not a concern here; the case that matters
is fp16. In this case the f8 result is usually denormal too, but there are exceptions: fp16 denormals
use exponent bias 15 while bf8 with NANOO uses exponent bias 16. This means that some numbers are
fp16 denormals but bf8 (NANOO) normals - the smallest bf8 (NANOO) normal is 2^-15. fp16 numbers
with exponent==0 (actual exponent -14) whose highest mantissa bit is 1 are bf8 (NANOO) normals.
In this case, the fp16 mantissa should be shifted left by 1 */
act_exponent
=
exponent
-
bias
+
1
;
exponent_diff
=
out_denormal_act_exponent
-
act_exponent
;
// actual exponent is exponent-bias+1 as it is denormal
}
else
{
// fp32/fp16 is normal with implicit 1
act_exponent
=
exponent
-
bias
;
if
(
act_exponent
<=
out_denormal_act_exponent
)
{
mantissa
>>=
1
;
exponent
++
;
/* This is the case where fp32/fp16 is normal but falls in the f8 denormal range.
For example, in fp8 nanoo mode the denormal exponent is -7, but if the fp32/fp16
actual exponent is -7, the value is actually larger because of the implicit 1.
Therefore the exponent needs to be adjusted to -6 and the mantissa shifted right by 1.
So for fp32/fp16, exponent -8 is the cut point when converting to fp8 nanoo */
exponent_diff
=
out_denormal_act_exponent
-
act_exponent
;
}
else
{
// both fp32/fp16 and f8 are in normal range
exponent_diff
=
0
;
// exponent_diff=0 does not mean there is no difference for this case,
// act_exponent could be larger. Just that it does not need shift mantissa
}
mantissa
+=
(
1
<<
in_mant
);
// Add the implicit 1 into mantissa
}
mantissa
>>=
(
type_mant
-
f8_mant
);
// check negative exponent
if
(
exponent
<=
0
)
bool
midpoint
=
(
mantissa
&
((
1
<<
(
in_mant
-
out_mant
+
exponent_diff
))
-
1
))
==
(
1
<<
(
in_mant
-
out_mant
+
exponent_diff
-
1
));
/* This part is a bit tricky. The judgment of whether it is a tie needs to be done before we
shift right as shift right could rip off some residual part and make something not midpoint look
like midpoint. For example, the fp16 number 0x1002 (0 00100 0000000010), it is larger than
midpoint, but after shift right by 4 bits, it would look like midpoint. */
if
(
exponent_diff
>
0
)
mantissa
>>=
exponent_diff
;
else
if
(
exponent_diff
==
-
1
)
mantissa
<<=
-
exponent_diff
;
bool
implicit_one
=
mantissa
&
(
1
<<
in_mant
);
// If there is no implicit 1, the f8 value is denormal and the exponent must be adjusted to the denormal exponent
out_exponent
=
(
act_exponent
+
exponent_diff
)
/*actual f8 exponent*/
+
out_bias
-
(
implicit_one
?
0
:
1
);
// Now we have the exponent and mantissa adjusted
bool
odd
=
mantissa
&
(
1
<<
(
in_mant
-
out_mant
));
// if the least significant bit that is not truncated is 1
mantissa
+=
(
stoch
?
rng
:
(
midpoint
?
(
odd
?
mantissa
:
mantissa
-
1
)
:
mantissa
))
&
drop_mask
;
// Now we deal with overflow
if
(
out_exponent
==
0
)
{
if
(
x_bitwise
==
0
)
return
0
;
if
((
1
<<
in_mant
)
&
mantissa
)
{
out_exponent
=
1
;
// denormal overflow to become normal, promote exponent
// No need to make 1 implicit now as it will be addressed later
}
}
else
{
// subnormal range; represented by a subnormal float8 (exponent 0)
// and involves loss of accuracy
mantissa
>>=
1
-
exponent
;
exponent
=
0
;
if
((
1
<<
(
in_mant
+
1
))
&
mantissa
)
{
mantissa
>>=
1
;
out_exponent
++
;
// No need to make 1 implicit now as it will be addressed later
}
}
// above range: quantize to maximum possible float of the same sign
else
if
(
exponent
>
max_exp
)
mantissa
>>=
(
in_mant
-
out_mant
);
if
(
out_exponent
>
max_exp
)
{
if
(
clip
)
{
mantissa
=
(
1
<<
f8
_mant
)
-
1
;
exponent
=
max_exp
;
mantissa
=
(
1
<<
out
_mant
)
-
1
;
out_
exponent
=
max_exp
;
}
else
{
...
...
@@ -126,125 +177,117 @@ __host__ __device__ f8_t run_cast_to_f8(T x, uint32_t rng)
}
// check if x is 0.0 or -0.0
if
(
exponent
==
0
&&
mantissa
==
0
)
return
negative_zero_nan
?
0
:
(
sign
<<
(
f8
_exp
+
f8
_mant
));
mantissa
&=
(
1
<<
f8
_mant
)
-
1
;
return
(
sign
<<
(
f8
_exp
+
f8
_mant
))
|
(
exponent
<<
f8
_mant
)
|
mantissa
;
if
(
out_
exponent
==
0
&&
mantissa
==
0
)
return
negative_zero_nan
?
0
:
(
sign
<<
(
out
_exp
+
out
_mant
));
mantissa
&=
(
1
<<
out
_mant
)
-
1
;
return
(
sign
<<
(
out
_exp
+
out
_mant
))
|
(
out_
exponent
<<
out
_mant
)
|
mantissa
;
}
template
<
typename
T
,
bool
negative_zero_nan
>
__host__
__device__
T
run_cast_from_f8
(
f8_t
x
)
template
<
typename
X
,
typename
Y
,
bool
negative_zero_nan
>
__host__
__device__
Y
run_cast_from_f8
(
X
x
)
{
// check data type
constexpr
bool
is_half
=
std
::
is_same
<
T
,
half_t
>::
value
;
constexpr
bool
is_float
=
std
::
is_same
<
T
,
float
>::
value
;
// fp8 exponent/mantissa layout
constexpr
int
f8_exp
=
4
;
constexpr
int
f8_mant
=
3
;
// fp8/bf8 exponent/mantissa layout
constexpr
int
in_exp
=
NumericUtils
<
X
>::
exp
;
constexpr
int
in_mant
=
NumericUtils
<
X
>::
mant
;
// resulting type exponent/mantissa layout
constexpr
int
type
_exp
=
is_half
?
5
:
8
;
constexpr
int
type
_mant
=
is_half
?
10
:
23
;
constexpr
int
out
_exp
=
NumericUtils
<
Y
>::
exp
;
constexpr
int
out
_mant
=
NumericUtils
<
Y
>::
mant
;
// prepare the codes
constexpr
uint8_t
nan_code
=
0x80
;
T
fInf
,
fNegInf
,
fNaN
,
fNeg0
;
if
constexpr
(
is_half
)
{
constexpr
uint16_t
ihInf
=
0x7C00
;
constexpr
uint16_t
ihNegInf
=
0xFC00
;
constexpr
uint16_t
ihNaN
=
0x7C01
;
constexpr
uint16_t
ihNeg0
=
0x8000
;
fInf
=
*
(
reinterpret_cast
<
const
half_t
*>
(
&
ihInf
));
fNegInf
=
*
(
reinterpret_cast
<
const
half_t
*>
(
&
ihNegInf
));
fNaN
=
*
(
reinterpret_cast
<
const
half_t
*>
(
&
ihNaN
));
fNeg0
=
*
(
reinterpret_cast
<
const
half_t
*>
(
&
ihNeg0
));
}
else
if
constexpr
(
is_float
)
{
constexpr
uint32_t
ifInf
=
0x7F800000
;
constexpr
uint32_t
ifNegInf
=
0xFF800000
;
constexpr
uint32_t
ifNaN
=
0x7F800001
;
constexpr
uint32_t
ifNeg0
=
0x80000000
;
fInf
=
*
(
reinterpret_cast
<
const
float
*>
(
&
ifInf
));
fNegInf
=
*
(
reinterpret_cast
<
const
float
*>
(
&
ifNegInf
));
fNaN
=
*
(
reinterpret_cast
<
const
float
*>
(
&
ifNaN
));
fNeg0
=
*
(
reinterpret_cast
<
const
float
*>
(
&
ifNeg0
));
}
constexpr
X
nan_code
=
0x80
;
Y
Inf
,
NegInf
,
NaN
,
Neg0
;
using
T_bitwise
=
typename
NumericUtils
<
Y
>::
bitwise_type
;
constexpr
T_bitwise
Inf_bitwise
=
NumericUtils
<
Y
>::
Inf
;
constexpr
T_bitwise
NegInf_bitwise
=
NumericUtils
<
Y
>::
NegInf
;
constexpr
T_bitwise
NaN_bitwise
=
NumericUtils
<
Y
>::
NaN
;
constexpr
T_bitwise
Neg0_bitwise
=
NumericUtils
<
Y
>::
Neg0
;
Inf
=
*
(
reinterpret_cast
<
const
Y
*>
(
&
Inf_bitwise
));
NegInf
=
*
(
reinterpret_cast
<
const
Y
*>
(
&
NegInf_bitwise
));
NaN
=
*
(
reinterpret_cast
<
const
Y
*>
(
&
NaN_bitwise
));
Neg0
=
*
(
reinterpret_cast
<
const
Y
*>
(
&
Neg0_bitwise
));
// check if x is 0.0
if
(
x
==
0
)
return
static_cast
<
Y
>
(
0
);
// unpack the input
uint32_t
sign
=
x
>>
(
f8
_exp
+
f8
_mant
);
uint32_t
mantissa
=
x
&
((
1
<<
f8
_mant
)
-
1
);
int
exponent
=
(
x
&
0x7F
)
>>
f8
_mant
;
uint32_t
sign
=
x
>>
(
in
_exp
+
in
_mant
);
uint32_t
mantissa
=
x
&
((
1
<<
in
_mant
)
-
1
);
int
exponent
=
(
x
&
0x7F
)
>>
in
_mant
;
constexpr
int
exp_low_cutoff
=
(
1
<<
(
type
_exp
-
1
))
-
(
1
<<
(
f8
_exp
-
1
))
+
1
-
(
negative_zero_nan
?
1
:
0
);
typename
ck
::
conditional
<
std
::
is_same
<
T
,
half_t
>::
value
,
uint16_t
,
uint32_t
>::
typ
e
retval
;
(
1
<<
(
out
_exp
-
1
))
-
(
1
<<
(
in
_exp
-
1
))
+
1
-
(
negative_zero_nan
?
1
:
0
);
T_bitwis
e
retval
;
if
constexpr
(
negative_zero_nan
)
{
if
(
x
==
nan_code
)
return
f
NaN
;
return
NaN
;
}
else
{
if
(
x
==
nan_code
)
return
fNeg0
;
if
(
exponent
==
((
1
<<
f8_exp
)
-
1
))
return
(
mantissa
==
0
)
?
(
sign
?
fNegInf
:
fInf
)
:
fNaN
;
return
Neg0
;
if
(
exponent
==
((
1
<<
in_exp
)
-
1
))
return
(
mantissa
==
0
)
?
(
sign
?
NegInf
:
Inf
)
:
NaN
;
}
if
((
NumericUtils
<
Y
>::
mant
==
10
)
&&
(
NumericUtils
<
X
>::
mant
==
2
)
&&
!
negative_zero_nan
)
{
retval
=
x
;
retval
<<=
8
;
return
*
(
reinterpret_cast
<
const
Y
*>
(
&
retval
));
}
// subnormal input
if
(
exponent
==
0
)
{
// guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above
int
sh
=
1
+
__builtin_
clz
(
mantissa
)
-
(
(
1
+
type_exp
+
type_mant
)
-
f8
_mant
);
int
sh
=
1
+
clz
(
mantissa
)
-
(
32
-
in
_mant
);
mantissa
<<=
sh
;
mantissa
&=
((
1
<<
f8_mant
)
-
1
);
exponent
+=
1
-
sh
;
mantissa
&=
((
1
<<
in_mant
)
-
1
);
}
exponent
+=
exp_low_cutoff
-
1
;
mantissa
<<=
type
_mant
-
f8
_mant
;
mantissa
<<=
out
_mant
-
in
_mant
;
// subnormal output (occurs when T=half, we=5, negative_zero_nan=true)
if
(
exponent
<=
0
)
{
mantissa
|=
1
<<
type
_mant
;
mantissa
|=
1
<<
out
_mant
;
mantissa
>>=
1
-
exponent
;
exponent
=
0
;
}
retval
=
(
sign
<<
(
type
_exp
+
type
_mant
))
|
(
exponent
<<
type
_mant
)
|
mantissa
;
return
*
(
reinterpret_cast
<
const
T
*>
(
&
retval
));
retval
=
(
sign
<<
(
out
_exp
+
out
_mant
))
|
(
exponent
<<
out
_mant
)
|
mantissa
;
return
*
(
reinterpret_cast
<
const
Y
*>
(
&
retval
));
}
}
// namespace
template
<
typename
T
,
bool
negative_zero_nan
,
bool
clip
,
bool
stoch
>
__host__
__device__
f8_t
cast_to_f8
(
T
x
,
uint32_t
rng
)
template
<
typename
X
,
typename
Y
,
bool
negative_zero_nan
,
bool
clip
,
bool
stoch
>
__host__
__device__
Y
cast_to_f8
(
X
x
,
uint32_t
rng
)
{
// check datatype
constexpr
bool
is_half
=
std
::
is_same
<
T
,
half_t
>::
value
;
constexpr
bool
is_float
=
std
::
is_same
<
T
,
float
>::
value
;
static_assert
(
is_half
||
is_float
,
"Only half and float can be casted
to f8
."
);
// check datatype
s
constexpr
bool
is_half
=
std
::
is_same
<
X
,
half_t
>::
value
;
constexpr
bool
is_float
=
std
::
is_same
<
X
,
float
>::
value
;
static_assert
(
is_half
||
is_float
,
"Only half and float can be casted."
);
return
run_cast_to_f8
<
T
,
negative_zero_nan
,
clip
,
stoch
>
(
x
,
rng
);
return
run_cast_to_f8
<
X
,
Y
,
negative_zero_nan
,
clip
,
stoch
>
(
x
,
rng
);
}
template
<
typename
T
,
bool
negative_zero_nan
>
__host__
__device__
T
cast_from_f8
(
f8_t
x
)
template
<
typename
X
,
typename
Y
,
bool
negative_zero_nan
>
__host__
__device__
Y
cast_from_f8
(
X
x
)
{
// check datatype
constexpr
bool
is_half
=
std
::
is_same
<
T
,
half_t
>::
value
;
constexpr
bool
is_float
=
std
::
is_same
<
T
,
float
>::
value
;
constexpr
bool
is_half
=
std
::
is_same
<
Y
,
half_t
>::
value
;
constexpr
bool
is_float
=
std
::
is_same
<
Y
,
float
>::
value
;
static_assert
(
is_half
||
is_float
,
"only half and float are supported."
);
// check if x is 0.0
if
(
x
==
0
)
return
static_cast
<
T
>
(
0
);
return
run_cast_from_f8
<
T
,
negative_zero_nan
>
(
x
);
return
run_cast_from_f8
<
X
,
Y
,
negative_zero_nan
>
(
x
);
}
}
// namespace ck::utils
include/ck/utility/inner_product.hpp
View file @
3c4fb1dd
...
...
@@ -72,6 +72,18 @@ inner_product<float4_t, float4_t, float>(const float4_t& a, const float4_t& b, f
c
);
}
// bhalf_t * bhalf_t accumulated into a float: both operands are first passed
// through type_convert<float>, then the work is forwarded to whichever
// inner_product overload matches the converted operand type.
template <>
__device__ void
inner_product<bhalf_t, bhalf_t, float>(const bhalf_t& a, const bhalf_t& b, float& c)
{
    const auto a_widened = type_convert<float>(a);
    const auto b_widened = type_convert<float>(b);

    inner_product(a_widened, b_widened, c);
}
// half_t * half_t accumulated into a float: both operands are first passed
// through type_convert<float>, then the work is forwarded to whichever
// inner_product overload matches the converted operand type.
template <>
__device__ void
inner_product<half_t, half_t, float>(const half_t& a, const half_t& b, float& c)
{
    const auto a_widened = type_convert<float>(a);
    const auto b_widened = type_convert<float>(b);

    inner_product(a_widened, b_widened, c);
}
template
<
>
__device__
void
inner_product
<
half2_t
,
half2_t
,
float
>
(
const
half2_t
&
a
,
const
half2_t
&
b
,
float
&
c
)
{
...
...
@@ -180,6 +192,8 @@ inner_product<int8x4_t, int8x4_t, int32_t>(const int8x4_t& a, const int8x4_t& b,
#else
c
=
__builtin_amdgcn_sdot4
(
bit_cast
<
int32_t
>
(
a
),
bit_cast
<
int32_t
>
(
b
),
c
,
false
);
#endif
#elif defined(CK_USE_AMD_V_DOT4_I32_I8_GFX11)
c
=
__builtin_amdgcn_sudot4
(
true
,
bit_cast
<
int32_t
>
(
a
),
true
,
bit_cast
<
int32_t
>
(
b
),
c
,
false
);
#else
const
vector_type
<
int8_t
,
4
>
a_vector
{
a
};
const
vector_type
<
int8_t
,
4
>
b_vector
{
b
};
...
...
include/ck/utility/inner_product_dpp8.hpp
View file @
3c4fb1dd
...
...
@@ -2,6 +2,7 @@
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "amd_gemm_dpp.hpp"
#include "data_type.hpp"
#include "type_convert.hpp"
...
...
@@ -10,6 +11,9 @@ namespace ck {
namespace
dpp8
{
/// Number of lanes that can share data using DPP8 modifiers.
constexpr
index_t
lane_group_size
=
8
;
/// Declaration only: the body of this device function template is not part of
/// this excerpt.
/// SrcLaneIdx is a compile-time integer template argument.
/// a and b are half2_t operands taken by const reference; c is a float taken
/// by non-const reference, so the definition is able to write to it.
/// NOTE(review): presumably c is an accumulator updated with a dot product of
/// a and b, with one operand read from lane SrcLaneIdx of the lane group --
/// confirm against the definition.
template
<
int
SrcLaneIdx
>
__device__
void
inline_v_dot2c_dpp8_instr
(
const
half2_t
&
a
,
const
half2_t
&
b
,
float
&
c
);
...
...
include/ck/utility/is_detected.hpp
0 → 100644
View file @
3c4fb1dd
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
// NOTE(review): this code names std::false_type, std::true_type, std::void_t
// and std::declval, so <type_traits> and <utility> have to be visible at the
// point of inclusion.
namespace ck {

namespace detail {

// Fallback case of the detection idiom, used whenever the partial
// specialization below does not match. Reports std::false_type and exposes
// the caller-supplied Fallback as `type`.
template <typename Fallback, typename Enabler, template <typename...> class Op, typename... Args>
struct detector
{
    using value_t = std::false_type;
    using type    = Fallback;
};

// Matches only when Op<Args...> is a well-formed type (std::void_t then
// yields void, which is what is_detected passes as the second argument).
// Reports std::true_type and exposes Op<Args...> as `type`.
template <typename Fallback, template <typename...> class Op, typename... Args>
struct detector<Fallback, std::void_t<Op<Args...>>, Op, Args...>
{
    using value_t = std::true_type;
    using type    = Op<Args...>;
};

} // namespace detail

// Placeholder type handed to detector as the fallback; it can be neither
// copied, assigned nor destroyed.
struct nonesuch
{
    nonesuch(const nonesuch&)       = delete;
    void operator=(const nonesuch&) = delete;
    ~nonesuch()                     = delete;
};

// std::true_type when Op<Args...> is well-formed, std::false_type otherwise.
template <template <typename...> class Op, typename... Args>
using is_detected = typename detail::detector<nonesuch, void, Op, Args...>::value_t;

// Detection probes for use with is_detected: each alias is well-formed only
// if an lvalue of Functor has a member with the corresponding name.
template <typename Functor>
using is_pack2_invocable_t = decltype(std::declval<Functor&>().is_pack2_invocable);

template <typename Functor>
using is_pack4_invocable_t = decltype(std::declval<Functor&>().is_pack4_invocable);

template <typename Functor>
using is_pack8_invocable_t = decltype(std::declval<Functor&>().is_pack8_invocable);

} // namespace ck
include/ck/utility/loop_scheduler.hpp
0 → 100644
View file @
3c4fb1dd
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
namespace ck {

// Selectable loop scheduling modes. Only the enumerator names are defined
// here; what each mode does is decided by the code that consumes them.
enum class LoopScheduler
{
    Default,
    Interwave,
};

// Compile-time default scheduler: Interwave when the preprocessor symbol
// CK_EXPERIMENTAL_DEFAULT_TO_INTER_WAVE_SCHEDULING evaluates to non-zero,
// Default otherwise (including when the symbol is not defined at all).
constexpr LoopScheduler make_default_loop_scheduler()
{
#if CK_EXPERIMENTAL_DEFAULT_TO_INTER_WAVE_SCHEDULING
    constexpr LoopScheduler selected = LoopScheduler::Interwave;
#else
    constexpr LoopScheduler selected = LoopScheduler::Default;
#endif // if CK_EXPERIMENTAL_DEFAULT_TO_INTER_WAVE_SCHEDULING
    return selected;
}

} // namespace ck
include/ck/utility/math.hpp
View file @
3c4fb1dd
...
...
@@ -150,30 +150,6 @@ __host__ __device__ constexpr T clamp(const T& x, const T& lowerbound, const T&
return
min
(
max
(
x
,
lowerbound
),
upperbound
);
}
// disallow implicit type casting
// The primary template is declared but never defined, so only the explicit
// specializations below can be called; any other argument type fails to link
// instead of being silently converted.
template <typename T>
__device__ T exp(T x);

// TODO: add f16 support using v_exp_f16

// float specialization: forwards to the __expf device intrinsic.
template <>
__device__ float exp<float>(float x)
{
    return __expf(x);
}

// double specialization: forwards to the global-namespace exp.
// The call is deliberately qualified with `::`. If this template is declared
// inside a namespace, an unqualified exp(x) finds the template above first,
// deduces T = double and calls this very specialization again, i.e. it
// recurses without end. `::exp` selects the math-library overload regardless
// of the enclosing scope.
template <>
__device__ double exp<double>(double x)
{
    return ::exp(x);
}
#ifndef __HIPCC_RTC__
// Host-only exp overloads, compiled out when __HIPCC_RTC__ is defined.

// Single precision: delegates to the global-namespace expf.
static inline __host__ float exp(float x)
{
    const float result = ::expf(x);
    return result;
}

// Double precision: delegates to std::exp.
static inline __host__ double exp(double x)
{
    const double result = std::exp(x);
    return result;
}
#endif
// greatest common divisor, aka highest common factor
__host__
__device__
constexpr
index_t
gcd
(
index_t
x
,
index_t
y
)
{
...
...
Prev
1
…
13
14
15
16
17
18
19
20
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment