Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
e1a5137e
Unverified
Commit
e1a5137e
authored
Sep 19, 2023
by
arai713
Committed by
GitHub
Sep 19, 2023
Browse files
Merge branch 'develop' into transpose_5d
parents
eb57178d
718065eb
Changes
371
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1190 additions
and
378 deletions
+1190
-378
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp
...k/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp
+12
-16
include/ck/tensor_operation/gpu/grid/gridwise_image_to_column.hpp
...ck/tensor_operation/gpu/grid/gridwise_image_to_column.hpp
+97
-0
include/ck/tensor_operation/gpu/grid/gridwise_sparse_embeddings_forward_layernorm.hpp
...gpu/grid/gridwise_sparse_embeddings_forward_layernorm.hpp
+2
-2
include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_1st.hpp
.../grid/normalization/gridwise_normalization_splitk_1st.hpp
+3
-3
include/ck/tensor_operation/gpu/thread/threadwise_contraction_dl_dpp8.hpp
...r_operation/gpu/thread/threadwise_contraction_dl_dpp8.hpp
+0
-136
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp
...operation/gpu/thread/threadwise_tensor_slice_transfer.hpp
+4
-5
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp
...tion/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp
+3
-0
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1.hpp
...tion/gpu/thread/threadwise_tensor_slice_transfer_v6r1.hpp
+2
-2
include/ck/tensor_operation/gpu/warp/dpp_gemm.hpp
include/ck/tensor_operation/gpu/warp/dpp_gemm.hpp
+538
-0
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp
+9
-1
include/ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp
...operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp
+117
-86
include/ck/utility/amd_buffer_addressing.hpp
include/ck/utility/amd_buffer_addressing.hpp
+16
-2
include/ck/utility/amd_gemm_dpp.hpp
include/ck/utility/amd_gemm_dpp.hpp
+51
-5
include/ck/utility/amd_xdlops.hpp
include/ck/utility/amd_xdlops.hpp
+2
-0
include/ck/utility/data_type.hpp
include/ck/utility/data_type.hpp
+123
-8
include/ck/utility/f8_utils.hpp
include/ck/utility/f8_utils.hpp
+109
-111
include/ck/utility/inner_product.hpp
include/ck/utility/inner_product.hpp
+12
-0
include/ck/utility/inner_product_dpp8.hpp
include/ck/utility/inner_product_dpp8.hpp
+4
-0
include/ck/utility/loop_scheduler.hpp
include/ck/utility/loop_scheduler.hpp
+26
-0
include/ck/utility/reduction_operator.hpp
include/ck/utility/reduction_operator.hpp
+60
-1
No files found.
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp
View file @
e1a5137e
...
@@ -151,8 +151,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
...
@@ -151,8 +151,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
using
ThisThreadBlock
=
ThisThreadBlock
<
BlockSize
>
;
using
ThisThreadBlock
=
ThisThreadBlock
<
BlockSize
>
;
using
GridwiseGemmPipe
=
remove_cvref_t
<
decltype
(
using
GridwiseGemmPipe
=
remove_cvref_t
<
GridwiseGemmPipeline_Selector
<
PipelineVer
,
NumGemmKPrefetchStage
>
())
>
;
decltype
(
GridwiseGemmPipeline_Selector
<
PipelineVer
,
NumGemmKPrefetchStage
>
())
>
;
__host__
__device__
static
constexpr
auto
GetABlockDescriptor_K0PerBlock_MPerBlock_K1
()
__host__
__device__
static
constexpr
auto
GetABlockDescriptor_K0PerBlock_MPerBlock_K1
()
{
{
...
@@ -331,18 +331,18 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
...
@@ -331,18 +331,18 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
c_grid_desc_m_n
);
c_grid_desc_m_n
);
}
}
using
CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
=
using
CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
=
remove_cvref_t
<
decltype
(
remove_cvref_t
<
MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
(
decltype
(
MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
(
CGridDesc_M_N
{}))
>
;
CGridDesc_M_N
{}))
>
;
using
C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
=
using
C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
=
remove_cvref_t
<
decltype
(
remove_cvref_t
<
MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
(
decltype
(
MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
(
C0GridDesc_M_N
{}))
>
;
C0GridDesc_M_N
{}))
>
;
using
C1GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
=
using
C1GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
=
remove_cvref_t
<
decltype
(
remove_cvref_t
<
MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
(
decltype
(
MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
(
C1GridDesc_M_N
{}))
>
;
C1GridDesc_M_N
{}))
>
;
using
DefaultBlock2CTileMap
=
using
DefaultBlock2CTileMap
=
...
@@ -674,14 +674,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
...
@@ -674,14 +674,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
FloatC
,
// typename Src1Data,
FloatC
,
// typename Src1Data,
FloatC
,
// typename Src2Data,
FloatC
,
// typename Src2Data,
FloatC
,
// typename DstData,
FloatC
,
// typename DstData,
decltype
(
decltype
(
c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
),
c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
),
decltype
(
c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
),
decltype
(
decltype
(
c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
),
c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
),
decltype
(
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
),
decltype
(
c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
),
decltype
(
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
),
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
>
,
// typename DimAccessOrder,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
>
,
// typename DimAccessOrder,
5
,
// index_t VectorDim,
5
,
// index_t VectorDim,
CBlockTransferScalarPerVector_NWaveNPerXdl
,
// index_t ScalarPerVector,
CBlockTransferScalarPerVector_NWaveNPerXdl
,
// index_t ScalarPerVector,
...
...
include/ck/tensor_operation/gpu/grid/gridwise_image_to_column.hpp
0 → 100644
View file @
e1a5137e
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace
ck
{
template
<
typename
InputGridDesc
,
typename
InputDataType
,
typename
OutputGridDesc
,
typename
OutputDataType
,
index_t
BlockSize
,
index_t
MPerBlock
,
index_t
KPerBlock
,
typename
ThreadClusterLengths
,
index_t
ScalarPerVector
,
typename
Block2ETileMap
>
struct
GridwiseImageToColumn
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
using
ThisThreadBlock
=
ThisThreadBlock
<
BlockSize
>
;
__device__
static
void
Run
(
const
InputGridDesc
&
in_grid_desc
,
const
InputDataType
*
__restrict__
p_in_global
,
const
OutputGridDesc
&
out_grid_desc
,
OutputDataType
*
__restrict__
p_out_global
,
const
Block2ETileMap
&
block_2_tile_map
)
{
const
auto
block_work_idx
=
block_2_tile_map
.
CalculateBottomIndex
(
make_multi_index
(
get_block_1d_id
()));
const
index_t
m_block_data_idx_on_grid
=
__builtin_amdgcn_readfirstlane
(
block_work_idx
[
I0
]
*
MPerBlock
);
const
index_t
k_block_data_idx_on_grid
=
__builtin_amdgcn_readfirstlane
(
block_work_idx
[
I1
]
*
KPerBlock
);
// Global Memory
const
auto
in_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_in_global
,
in_grid_desc
.
GetElementSpaceSize
());
auto
out_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_out_global
,
out_grid_desc
.
GetElementSpaceSize
());
auto
copy_global_to_global
=
ThreadGroupTensorSliceTransfer_v7
<
ThisThreadBlock
,
Tuple
<
InputDataType
>
,
Tuple
<
OutputDataType
>
,
decltype
(
tie
(
in_grid_desc
)),
decltype
(
tie
(
out_grid_desc
)),
tensor_operation
::
element_wise
::
PassThrough
,
Sequence
<
static_cast
<
index_t
>
(
InMemoryDataOperationEnum
::
Set
)
>
,
Sequence
<
MPerBlock
,
KPerBlock
>
,
ThreadClusterLengths
,
Sequence
<
0
,
1
>
,
Sequence
<
0
,
1
>
,
I1
,
ScalarPerVector
,
Sequence
<
true
>
,
Sequence
<
true
>>
{
in_grid_desc
,
make_tuple
(
make_multi_index
(
m_block_data_idx_on_grid
,
k_block_data_idx_on_grid
)),
out_grid_desc
,
make_tuple
(
make_multi_index
(
m_block_data_idx_on_grid
,
k_block_data_idx_on_grid
)),
tensor_operation
::
element_wise
::
PassThrough
{}};
copy_global_to_global
.
Run
(
tie
(
in_grid_desc
),
tie
(
in_global_buf
),
tie
(
out_grid_desc
),
tie
(
out_global_buf
));
}
__host__
static
constexpr
bool
CheckValidity
(
const
InputGridDesc
&
in_grid_desc
,
const
OutputGridDesc
&
out_grid_desc
)
{
if
(
in_grid_desc
.
GetLength
(
I0
)
%
MPerBlock
!=
0
||
in_grid_desc
.
GetLength
(
I1
)
%
KPerBlock
!=
0
)
return
false
;
if
(
out_grid_desc
.
GetLength
(
I0
)
%
MPerBlock
!=
0
||
out_grid_desc
.
GetLength
(
I1
)
%
KPerBlock
!=
0
)
return
false
;
return
true
;
}
};
}
// namespace ck
include/ck/tensor_operation/gpu/grid/gridwise_sparse_embeddings_forward_layernorm.hpp
View file @
e1a5137e
...
@@ -78,8 +78,8 @@ struct GridwiseSparseEmbeddingsForwardLayernorm
...
@@ -78,8 +78,8 @@ struct GridwiseSparseEmbeddingsForwardLayernorm
using
ThreadwiseWolfordDesc2D
=
decltype
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
using
ThreadwiseWolfordDesc2D
=
decltype
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
DimSubBlocks
*
DimThreadSize
>
{},
Number
<
RowSubBlocks
*
RowVectorSize
>
{})));
Number
<
DimSubBlocks
*
DimThreadSize
>
{},
Number
<
RowSubBlocks
*
RowVectorSize
>
{})));
using
ThreadwiseWolfordDescReduce
=
decltype
(
using
ThreadwiseWolfordDescReduce
=
decltype
(
make_naive_tensor_descriptor_packed
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
DimSubBlocks
*
DimThreadSize
>
{})));
make_tuple
(
Number
<
DimSubBlocks
*
DimThreadSize
>
{})));
using
ThreadwiseWelford
=
using
ThreadwiseWelford
=
ThreadwiseWelford
<
AccDataType
,
ThreadwiseWolfordDesc2D
,
ThreadwiseWolfordDescReduce
>
;
ThreadwiseWelford
<
AccDataType
,
ThreadwiseWolfordDesc2D
,
ThreadwiseWolfordDescReduce
>
;
...
...
include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_1st.hpp
View file @
e1a5137e
...
@@ -87,9 +87,9 @@ struct GridwiseNormalizationSplitK1st
...
@@ -87,9 +87,9 @@ struct GridwiseNormalizationSplitK1st
int
left_kPerBlock
=
math
::
integer_divide_ceil
(
k
,
kGridSize
);
int
left_kPerBlock
=
math
::
integer_divide_ceil
(
k
,
kGridSize
);
int
kRightmostBlock
=
kRaw
-
left_kPerBlock
*
(
kGridSize
-
1
);
int
kRightmostBlock
=
kRaw
-
left_kPerBlock
*
(
kGridSize
-
1
);
int
kPerThread
=
kRightmostBlock
<
K_BlockTileSize
int
kPerThread
=
kRightmostBlock
<
K_BlockTileSize
?
0
?
0
:
KThreadSliceSize
*
(
kRightmostBlock
/
K_BlockTileSize
);
:
KThreadSliceSize
*
(
kRightmostBlock
/
K_BlockTileSize
);
int
kPerBlockTail
=
kRightmostBlock
-
kPerThread
*
KThreadClusterSize
;
int
kPerBlockTail
=
kRightmostBlock
-
kPerThread
*
KThreadClusterSize
;
if
(
kPerBlockTail
>
0
)
if
(
kPerBlockTail
>
0
)
{
{
...
...
include/ck/tensor_operation/gpu/thread/threadwise_contraction_dl_dpp8.hpp
deleted
100644 → 0
View file @
eb57178d
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/amd_gemm_dpp.hpp"
#include "ck/utility/common_header.hpp"
#include "ck/utility/inner_product_dpp8.hpp"
#include "ck/utility/math.hpp"
namespace
ck
{
/**
* Threadwise contraction using dot instructions with DPP8 modifier.
*
* Assumptions:
* 1. `AThreadDesc_TK0_TM0_TM1_TK1`, `BThreadDesc_TK0_TN0_TN1_TK1`, `CThreadDesc_TM0_TM1_TN0_TN1`
* are known at compile-time;
* 2. `AOriginIdx`, `BOriginIdx`, `COriginIdx` are known at compile-time;
* 3. `TM0` is equal to 1 and `TN0` is equal to 1;
* 4. When `ShareA` is set (unset, respectively), `TM1` (`TN1`, respectively) is divisible by
* the size of the lane group (`dpp8::lane_group_size`).
*/
template
<
typename
FloatA
,
typename
FloatB
,
typename
FloatC
,
typename
AThreadDesc_TK0_TM0_TM1_TK1
,
typename
BThreadDesc_TK0_TN0_TN1_TK1
,
typename
CThreadDesc_TM0_TM1_TN0_TN1
,
typename
TKLengths
,
typename
TMLengths
,
typename
TNLengths
,
bool
ShareA
,
typename
enable_if
<
AThreadDesc_TK0_TM0_TM1_TK1
::
IsKnownAtCompileTime
()
&&
BThreadDesc_TK0_TN0_TN1_TK1
::
IsKnownAtCompileTime
()
&&
CThreadDesc_TM0_TM1_TN0_TN1
::
IsKnownAtCompileTime
(),
bool
>
::
type
=
false
>
struct
ThreadwiseContractionDlDpp8_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
index_t
TK0
=
TKLengths
{}[
I0
];
static
constexpr
index_t
TK1
=
TKLengths
{}[
I1
];
static
constexpr
index_t
TM0
=
TMLengths
{}[
I0
];
static
constexpr
index_t
TM1
=
TMLengths
{}[
I1
];
static
constexpr
index_t
TN0
=
TNLengths
{}[
I0
];
static
constexpr
index_t
TN1
=
TNLengths
{}[
I1
];
static_assert
(
TM0
==
1
&&
TN0
==
1
);
static_assert
((
ShareA
&&
TM1
%
dpp8
::
lane_group_size
==
0
)
||
(
!
ShareA
&&
TN1
%
dpp8
::
lane_group_size
==
0
));
static
constexpr
index_t
shared_elems_per_lane
=
ShareA
?
TM1
/
dpp8
::
lane_group_size
:
TN1
/
dpp8
::
lane_group_size
;
__device__
constexpr
ThreadwiseContractionDlDpp8_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1
()
{
static_assert
(
AThreadDesc_TK0_TM0_TM1_TK1
::
IsKnownAtCompileTime
()
&&
BThreadDesc_TK0_TN0_TN1_TK1
::
IsKnownAtCompileTime
()
&&
CThreadDesc_TM0_TM1_TN0_TN1
::
IsKnownAtCompileTime
(),
"wrong! Desc should be known at compile-time"
);
static_assert
(
TKLengths
::
Size
()
==
2
&&
TMLengths
::
Size
()
==
2
&&
TNLengths
::
Size
()
==
2
,
"wrong!"
);
}
template
<
typename
ABuffer
,
typename
AOriginIdx
,
typename
BBuffer
,
typename
BOriginIdx
,
typename
CBuffer
,
typename
COriginIdx
>
__device__
static
void
Run
(
const
ABuffer
&
a_buf
,
AOriginIdx
,
const
BBuffer
&
b_buf
,
BOriginIdx
,
CBuffer
&
c_buf
,
COriginIdx
)
{
static_assert
(
is_known_at_compile_time
<
remove_cvref_t
<
AOriginIdx
>>::
value
&&
is_known_at_compile_time
<
remove_cvref_t
<
BOriginIdx
>>::
value
&&
is_known_at_compile_time
<
remove_cvref_t
<
COriginIdx
>>::
value
,
"wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time"
);
static_assert
(
is_same
<
remove_cvref_t
<
typename
ABuffer
::
type
>
,
remove_cvref_t
<
FloatA
>>::
value
&&
is_same
<
remove_cvref_t
<
typename
BBuffer
::
type
>
,
remove_cvref_t
<
FloatB
>>::
value
&&
is_same
<
remove_cvref_t
<
typename
CBuffer
::
type
>
,
remove_cvref_t
<
FloatC
>>::
value
&&
"wrong! inconsistent type"
);
constexpr
auto
a_origin_idx
=
to_multi_index
(
AOriginIdx
{});
constexpr
auto
b_origin_idx
=
to_multi_index
(
BOriginIdx
{});
constexpr
auto
c_origin_idx
=
to_multi_index
(
COriginIdx
{});
static_for
<
0
,
TK0
,
1
>
{}([
&
](
auto
tk0
)
{
static_for
<
0
,
TM1
,
1
>
{}([
&
](
auto
tm1
)
{
static_for
<
0
,
TN1
,
1
>
{}([
&
](
auto
tn1
)
{
vector_type
<
FloatA
,
TK1
>
a_vec
;
vector_type
<
FloatB
,
TK1
>
b_vec
;
static_for
<
0
,
TK1
,
1
>
{}([
&
](
auto
tk1
)
{
constexpr
index_t
local_tm1
=
ShareA
?
tm1
%
shared_elems_per_lane
:
tm1
;
constexpr
index_t
a_offset
=
AThreadDesc_TK0_TM0_TM1_TK1
{}.
CalculateOffset
(
a_origin_idx
+
make_multi_index
(
tk0
,
0
,
local_tm1
,
tk1
));
constexpr
index_t
local_tn1
=
ShareA
?
tn1
:
tn1
%
shared_elems_per_lane
;
constexpr
index_t
b_offset
=
BThreadDesc_TK0_TN0_TN1_TK1
{}.
CalculateOffset
(
b_origin_idx
+
make_multi_index
(
tk0
,
0
,
local_tn1
,
tk1
));
a_vec
.
template
AsType
<
FloatA
>()(
tk1
)
=
a_buf
[
Number
<
a_offset
>
{}];
b_vec
.
template
AsType
<
FloatB
>()(
tk1
)
=
b_buf
[
Number
<
b_offset
>
{}];
});
using
a_vector_t
=
typename
vector_type
<
FloatA
,
TK1
>::
type
;
using
b_vector_t
=
typename
vector_type
<
FloatB
,
TK1
>::
type
;
constexpr
index_t
c_offset
=
CThreadDesc_TM0_TM1_TN0_TN1
{}.
CalculateOffset
(
c_origin_idx
+
make_multi_index
(
0
,
tm1
,
0
,
tn1
));
constexpr
int
src_lane
=
ShareA
?
(
tm1
/
shared_elems_per_lane
)
%
dpp8
::
lane_group_size
:
(
tn1
/
shared_elems_per_lane
)
%
dpp8
::
lane_group_size
;
dpp8
::
inner_product_dpp
<
a_vector_t
,
b_vector_t
,
FloatC
,
src_lane
,
ShareA
>
(
a_vec
.
template
AsType
<
a_vector_t
>()[
I0
],
b_vec
.
template
AsType
<
b_vector_t
>()[
I0
],
c_buf
(
Number
<
c_offset
>
{}));
});
});
});
}
};
}
// namespace ck
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp
View file @
e1a5137e
...
@@ -137,13 +137,12 @@ struct ThreadwiseTensorSliceTransfer_v1r3
...
@@ -137,13 +137,12 @@ struct ThreadwiseTensorSliceTransfer_v1r3
constexpr
index_t
src_offset
=
src_desc
.
CalculateOffset
(
constexpr
index_t
src_offset
=
src_desc
.
CalculateOffset
(
src_slice_origin_idx
+
idx_md
+
i
*
dst_scalar_step_in_vector
);
src_slice_origin_idx
+
idx_md
+
i
*
dst_scalar_step_in_vector
);
Src
Data
v
;
Dst
Data
v
;
// apply element-wise operation
// apply element-wise operation
element_op_
(
v
,
src_buf
[
Number
<
src_offset
>
{}]);
element_op_
(
v
,
src_buf
[
Number
<
src_offset
>
{}]);
// apply type convert
dst_vector
.
template
AsType
<
DstData
>()(
i
)
=
v
;
dst_vector
.
template
AsType
<
DstData
>()(
i
)
=
type_convert
<
DstData
>
(
v
);
});
});
const
bool
is_dst_valid
=
const
bool
is_dst_valid
=
...
@@ -1289,13 +1288,13 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic
...
@@ -1289,13 +1288,13 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic
constexpr
index_t
dst_offset
=
dst_desc
.
CalculateOffset
(
constexpr
index_t
dst_offset
=
dst_desc
.
CalculateOffset
(
dst_slice_origin_idx
+
idx_md
+
i
*
dst_scalar_step_in_vector
);
dst_slice_origin_idx
+
idx_md
+
i
*
dst_scalar_step_in_vector
);
Src
Data
v
;
Dst
Data
v
;
// apply element-wise operation
// apply element-wise operation
element_op_
(
v
,
src_buf
[
Number
<
src_offset
>
{}]);
element_op_
(
v
,
src_buf
[
Number
<
src_offset
>
{}]);
// apply type convert
// apply type convert
dst_buf
(
Number
<
dst_offset
>
{})
=
type_convert
<
DstData
>
(
v
)
;
dst_buf
(
Number
<
dst_offset
>
{})
=
v
;
});
});
});
});
}
}
...
...
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp
View file @
e1a5137e
...
@@ -129,6 +129,9 @@ struct ThreadwiseTensorSliceTransfer_v3r1
...
@@ -129,6 +129,9 @@ struct ThreadwiseTensorSliceTransfer_v3r1
constexpr
auto
src_access_lengths
=
SliceLengths
{}
/
src_scalar_per_access
;
constexpr
auto
src_access_lengths
=
SliceLengths
{}
/
src_scalar_per_access
;
static_assert
(
SliceLengths
::
At
(
SrcVectorDim
)
%
SrcScalarPerVector
==
0
,
"SliceLengths[SrcVectorDim] must be divisible by SrcScalarPerVector"
);
constexpr
auto
src_dim_access_order
=
SrcDimAccessOrder
{};
constexpr
auto
src_dim_access_order
=
SrcDimAccessOrder
{};
constexpr
auto
ordered_src_access_lengths
=
constexpr
auto
ordered_src_access_lengths
=
...
...
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1.hpp
View file @
e1a5137e
...
@@ -104,13 +104,13 @@ struct ThreadwiseTensorSliceTransfer_v6r1
...
@@ -104,13 +104,13 @@ struct ThreadwiseTensorSliceTransfer_v6r1
// apply pointwise operation
// apply pointwise operation
static_for
<
0
,
ScalarPerVector
,
1
>
{}([
&
](
auto
i
)
{
static_for
<
0
,
ScalarPerVector
,
1
>
{}([
&
](
auto
i
)
{
Src
Data
v
;
Dst
Data
v
;
// apply element-wise operation
// apply element-wise operation
element_op_
(
v
,
src_vector_container
.
template
AsType
<
SrcData
>()[
i
]);
element_op_
(
v
,
src_vector_container
.
template
AsType
<
SrcData
>()[
i
]);
// apply type convert
// apply type convert
dst_vector_container
.
template
AsType
<
DstData
>()(
i
)
=
type_convert
<
DstData
>
(
v
)
;
dst_vector_container
.
template
AsType
<
DstData
>()(
i
)
=
v
;
});
});
const
bool
is_dst_valid
=
const
bool
is_dst_valid
=
...
...
include/ck/tensor_operation/gpu/warp/dpp_gemm.hpp
0 → 100644
View file @
e1a5137e
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/amd_gemm_dpp.hpp"
#include "ck/utility/common_header.hpp"
#include "ck/utility/math.hpp"
namespace
ck
{
enum
struct
DppInstr
{
dpp8_f16_1x32x2
=
0
,
dpp8_f16_2x16x2
,
dpp8_f16_2x32x2
,
dpp8_f16_4x16x2
,
dpp8_f16_4x32x2
,
dpp8_f16_8x16x2
,
dpp8_f16_8x32x2
,
dpp8_f16_16x16x2
,
dpp8_f16_32x8x2
};
/**
* Structure representing DPP GEMM executed by a single wavefront.
*
* Each structure instantiation must contain the following fields:
* - wave_size - number of threads that execute a single DPP GEMM operation, usually equal to the
* number of threads in a wavefront;
* - lanegroup_size - number of threads (lanes) that share data using DPP instruction modifier,
* it's 8 in case of DPP8;
* - m_per_wave - size along M dimension of matrix C that is processed in a single DPP GEMM
* operation;
* - n_per_wave - size along N dimension of matrix C that is processed in a single DPP GEMM
* operation;
* - m_per_lanegroup - size along M dimension that is processed by a single lanegroup;
* - n_per_lanegroup - size along N dimension that is processed by a single lanegroup;
* - m_per_thread - size along M dimension of the tile calculated by a single thread;
* - n_per_thread - size along N dimension of the tile calculated by a single thread;
* - k_per_dpp - size along K dimension that is reduced in a single DPP GEMM operation;
* - share_a - indicates whether we share matrix A or matrix B between lanes using DPP modifiers.
*
* Not all the combinarions are supported now, for current restrictions see the static asserts
* in the DppSelector's contructor.
*/
template
<
DppInstr
instr
>
struct
dpp_type
;
template
<
>
struct
dpp_type
<
DppInstr
::
dpp8_f16_32x8x2
>
{
static
constexpr
index_t
wave_size
=
32
;
static
constexpr
index_t
lanegroup_size
=
8
;
static
constexpr
index_t
m_per_wave
=
32
;
static
constexpr
index_t
n_per_wave
=
8
;
static
constexpr
index_t
m_per_lanegroup
=
8
;
static
constexpr
index_t
n_per_lanegroup
=
8
;
static
constexpr
index_t
m_per_thread
=
8
;
static
constexpr
index_t
n_per_thread
=
1
;
static
constexpr
index_t
k_per_dpp
=
2
;
static
constexpr
bool
share_a
=
true
;
using
BaseType
=
half_t
;
template
<
index_t
MPerDpp
,
index_t
NPerDpp
,
class
ADataType
,
class
BDataType
,
class
CDataType
>
__device__
void
run
(
const
ADataType
&
a
,
const
BDataType
&
b
,
CDataType
&
reg_c
)
const
{
dpp8
::
DppLanegroupGemm
<
m_per_thread
,
n_per_thread
,
k_per_dpp
,
BaseType
,
ADataType
,
BDataType
,
CDataType
,
share_a
>
{}
.
Run
(
a
,
b
,
reg_c
);
}
};
template
<
>
struct
dpp_type
<
DppInstr
::
dpp8_f16_8x32x2
>
{
static
constexpr
index_t
wave_size
=
32
;
static
constexpr
index_t
lanegroup_size
=
8
;
static
constexpr
index_t
m_per_wave
=
8
;
static
constexpr
index_t
n_per_wave
=
32
;
static
constexpr
index_t
m_per_lanegroup
=
8
;
static
constexpr
index_t
n_per_lanegroup
=
8
;
static
constexpr
index_t
m_per_thread
=
8
;
static
constexpr
index_t
n_per_thread
=
1
;
static
constexpr
index_t
k_per_dpp
=
2
;
static
constexpr
bool
share_a
=
true
;
using
BaseType
=
half_t
;
template
<
index_t
MPerDpp
,
index_t
NPerDpp
,
class
ADataType
,
class
BDataType
,
class
CDataType
>
__device__
void
run
(
const
ADataType
&
a
,
const
BDataType
&
b
,
CDataType
&
reg_c
)
const
{
dpp8
::
DppLanegroupGemm
<
m_per_thread
,
n_per_thread
,
k_per_dpp
,
BaseType
,
ADataType
,
BDataType
,
CDataType
,
share_a
>
{}
.
Run
(
a
,
b
,
reg_c
);
}
};
template
<
>
struct
dpp_type
<
DppInstr
::
dpp8_f16_8x16x2
>
{
static
constexpr
index_t
wave_size
=
32
;
static
constexpr
index_t
lanegroup_size
=
8
;
static
constexpr
index_t
m_per_wave
=
8
;
static
constexpr
index_t
n_per_wave
=
16
;
static
constexpr
index_t
m_per_lanegroup
=
4
;
static
constexpr
index_t
n_per_lanegroup
=
8
;
static
constexpr
index_t
m_per_thread
=
4
;
static
constexpr
index_t
n_per_thread
=
1
;
static
constexpr
index_t
k_per_dpp
=
2
;
static
constexpr
bool
share_a
=
true
;
using
BaseType
=
half_t
;
template
<
index_t
MPerDpp
,
index_t
NPerDpp
,
class
ADataType
,
class
BDataType
,
class
CDataType
>
__device__
void
run
(
const
ADataType
&
a
,
const
BDataType
&
b
,
CDataType
&
reg_c
)
const
{
dpp8
::
DppLanegroupGemm
<
m_per_thread
,
n_per_thread
,
k_per_dpp
,
BaseType
,
ADataType
,
BDataType
,
CDataType
,
share_a
>
{}
.
Run
(
a
,
b
,
reg_c
);
}
};
template
<
>
struct
dpp_type
<
DppInstr
::
dpp8_f16_16x16x2
>
{
static
constexpr
index_t
wave_size
=
32
;
static
constexpr
index_t
lanegroup_size
=
8
;
static
constexpr
index_t
m_per_wave
=
16
;
static
constexpr
index_t
n_per_wave
=
16
;
static
constexpr
index_t
m_per_lanegroup
=
8
;
static
constexpr
index_t
n_per_lanegroup
=
8
;
static
constexpr
index_t
m_per_thread
=
8
;
static
constexpr
index_t
n_per_thread
=
1
;
static
constexpr
index_t
k_per_dpp
=
2
;
static
constexpr
bool
share_a
=
true
;
using
BaseType
=
half_t
;
template
<
index_t
MPerDpp
,
index_t
NPerDpp
,
class
ADataType
,
class
BDataType
,
class
CDataType
>
__device__
void
run
(
const
ADataType
&
a
,
const
BDataType
&
b
,
CDataType
&
reg_c
)
const
{
dpp8
::
DppLanegroupGemm
<
m_per_thread
,
n_per_thread
,
k_per_dpp
,
BaseType
,
ADataType
,
BDataType
,
CDataType
,
share_a
>
{}
.
Run
(
a
,
b
,
reg_c
);
}
};
template
<
>
struct
dpp_type
<
DppInstr
::
dpp8_f16_4x32x2
>
{
static
constexpr
index_t
wave_size
=
32
;
static
constexpr
index_t
lanegroup_size
=
8
;
static
constexpr
index_t
m_per_wave
=
4
;
static
constexpr
index_t
n_per_wave
=
32
;
static
constexpr
index_t
m_per_lanegroup
=
4
;
static
constexpr
index_t
n_per_lanegroup
=
8
;
static
constexpr
index_t
m_per_thread
=
4
;
static
constexpr
index_t
n_per_thread
=
1
;
static
constexpr
index_t
k_per_dpp
=
2
;
static
constexpr
bool
share_a
=
true
;
using
BaseType
=
half_t
;
template
<
index_t
MPerDpp
,
index_t
NPerDpp
,
class
ADataType
,
class
BDataType
,
class
CDataType
>
__device__
void
run
(
const
ADataType
&
a
,
const
BDataType
&
b
,
CDataType
&
reg_c
)
const
{
dpp8
::
DppLanegroupGemm
<
m_per_thread
,
n_per_thread
,
k_per_dpp
,
BaseType
,
ADataType
,
BDataType
,
CDataType
,
share_a
>
{}
.
Run
(
a
,
b
,
reg_c
);
}
};
template
<
>
struct
dpp_type
<
DppInstr
::
dpp8_f16_4x16x2
>
{
static
constexpr
index_t
wave_size
=
32
;
static
constexpr
index_t
lanegroup_size
=
8
;
static
constexpr
index_t
m_per_wave
=
4
;
static
constexpr
index_t
n_per_wave
=
16
;
static
constexpr
index_t
m_per_lanegroup
=
2
;
static
constexpr
index_t
n_per_lanegroup
=
8
;
static
constexpr
index_t
m_per_thread
=
2
;
static
constexpr
index_t
n_per_thread
=
1
;
static
constexpr
index_t
k_per_dpp
=
2
;
static
constexpr
bool
share_a
=
true
;
using
BaseType
=
half_t
;
template
<
index_t
MPerDpp
,
index_t
NPerDpp
,
class
ADataType
,
class
BDataType
,
class
CDataType
>
__device__
void
run
(
const
ADataType
&
a
,
const
BDataType
&
b
,
CDataType
&
reg_c
)
const
{
dpp8
::
DppLanegroupGemm
<
m_per_thread
,
n_per_thread
,
k_per_dpp
,
BaseType
,
ADataType
,
BDataType
,
CDataType
,
share_a
>
{}
.
Run
(
a
,
b
,
reg_c
);
}
};
template
<
>
struct
dpp_type
<
DppInstr
::
dpp8_f16_1x32x2
>
{
static
constexpr
index_t
wave_size
=
32
;
static
constexpr
index_t
lanegroup_size
=
8
;
static
constexpr
index_t
m_per_wave
=
1
;
static
constexpr
index_t
n_per_wave
=
32
;
static
constexpr
index_t
m_per_lanegroup
=
1
;
static
constexpr
index_t
n_per_lanegroup
=
8
;
static
constexpr
index_t
m_per_thread
=
1
;
static
constexpr
index_t
n_per_thread
=
1
;
static
constexpr
index_t
k_per_dpp
=
2
;
static
constexpr
bool
share_a
=
true
;
using
BaseType
=
half_t
;
template
<
index_t
MPerDpp
,
index_t
NPerDpp
,
class
ADataType
,
class
BDataType
,
class
CDataType
>
__device__
void
run
(
const
ADataType
&
a
,
const
BDataType
&
b
,
CDataType
&
reg_c
)
const
{
dpp8
::
DppLanegroupGemm
<
m_per_thread
,
n_per_thread
,
k_per_dpp
,
BaseType
,
ADataType
,
BDataType
,
CDataType
,
share_a
>
{}
.
Run
(
a
,
b
,
reg_c
);
}
};
template
<
>
struct
dpp_type
<
DppInstr
::
dpp8_f16_2x32x2
>
{
static
constexpr
index_t
wave_size
=
32
;
static
constexpr
index_t
lanegroup_size
=
8
;
static
constexpr
index_t
m_per_wave
=
2
;
static
constexpr
index_t
n_per_wave
=
32
;
static
constexpr
index_t
m_per_lanegroup
=
2
;
static
constexpr
index_t
n_per_lanegroup
=
8
;
static
constexpr
index_t
m_per_thread
=
2
;
static
constexpr
index_t
n_per_thread
=
1
;
static
constexpr
index_t
k_per_dpp
=
2
;
static
constexpr
bool
share_a
=
true
;
using
BaseType
=
half_t
;
template
<
index_t
MPerDpp
,
index_t
NPerDpp
,
class
ADataType
,
class
BDataType
,
class
CDataType
>
__device__
void
run
(
const
ADataType
&
a
,
const
BDataType
&
b
,
CDataType
&
reg_c
)
const
{
dpp8
::
DppLanegroupGemm
<
m_per_thread
,
n_per_thread
,
k_per_dpp
,
BaseType
,
ADataType
,
BDataType
,
CDataType
,
share_a
>
{}
.
Run
(
a
,
b
,
reg_c
);
}
};
template
<
>
struct
dpp_type
<
DppInstr
::
dpp8_f16_2x16x2
>
{
static
constexpr
index_t
wave_size
=
32
;
static
constexpr
index_t
lanegroup_size
=
8
;
static
constexpr
index_t
m_per_wave
=
2
;
static
constexpr
index_t
n_per_wave
=
16
;
static
constexpr
index_t
m_per_lanegroup
=
1
;
static
constexpr
index_t
n_per_lanegroup
=
8
;
static
constexpr
index_t
m_per_thread
=
1
;
static
constexpr
index_t
n_per_thread
=
1
;
static
constexpr
index_t
k_per_dpp
=
2
;
static
constexpr
bool
share_a
=
true
;
using
BaseType
=
half_t
;
template
<
index_t
MPerDpp
,
index_t
NPerDpp
,
class
ADataType
,
class
BDataType
,
class
CDataType
>
__device__
void
run
(
const
ADataType
&
a
,
const
BDataType
&
b
,
CDataType
&
reg_c
)
const
{
dpp8
::
DppLanegroupGemm
<
m_per_thread
,
n_per_thread
,
k_per_dpp
,
BaseType
,
ADataType
,
BDataType
,
CDataType
,
share_a
>
{}
.
Run
(
a
,
b
,
reg_c
);
}
};
template
<
typename
BaseType
,
index_t
MPerDpp
,
index_t
NPerDpp
>
struct
DppSelector
{
template
<
typename
BaseType_
,
index_t
MPerDpp_
,
index_t
NPerDpp_
>
static
constexpr
auto
GetDpp
();
template
<
>
static
constexpr
auto
GetDpp
<
half_t
,
8
,
32
>
()
{
return
DppInstr
::
dpp8_f16_8x32x2
;
}
template
<
>
static
constexpr
auto
GetDpp
<
half_t
,
8
,
16
>
()
{
return
DppInstr
::
dpp8_f16_8x16x2
;
}
template
<
>
static
constexpr
auto
GetDpp
<
half_t
,
16
,
16
>
()
{
return
DppInstr
::
dpp8_f16_16x16x2
;
}
template
<
>
static
constexpr
auto
GetDpp
<
half_t
,
32
,
8
>
()
{
return
DppInstr
::
dpp8_f16_32x8x2
;
}
template
<
>
static
constexpr
auto
GetDpp
<
half_t
,
1
,
32
>
()
{
return
DppInstr
::
dpp8_f16_1x32x2
;
}
template
<
>
static
constexpr
auto
GetDpp
<
half_t
,
2
,
32
>
()
{
return
DppInstr
::
dpp8_f16_2x32x2
;
}
template
<
>
static
constexpr
auto
GetDpp
<
half_t
,
2
,
16
>
()
{
return
DppInstr
::
dpp8_f16_2x16x2
;
}
template
<
>
static
constexpr
auto
GetDpp
<
half_t
,
4
,
16
>
()
{
return
DppInstr
::
dpp8_f16_4x16x2
;
}
template
<
>
static
constexpr
auto
GetDpp
<
half_t
,
4
,
32
>
()
{
return
DppInstr
::
dpp8_f16_4x32x2
;
}
static
constexpr
auto
selected_dpp
=
dpp_type
<
GetDpp
<
BaseType
,
MPerDpp
,
NPerDpp
>
()
>
{};
__host__
__device__
constexpr
DppSelector
()
{
static_assert
(
selected_dpp
.
m_per_wave
%
selected_dpp
.
m_per_lanegroup
==
0
);
static_assert
(
selected_dpp
.
n_per_wave
%
selected_dpp
.
n_per_lanegroup
==
0
);
static_assert
(
selected_dpp
.
k_per_dpp
%
2
==
0
);
static_assert
(
selected_dpp
.
wave_size
%
selected_dpp
.
lanegroup_size
==
0
);
constexpr
index_t
num_dpp_per_wave
=
selected_dpp
.
wave_size
/
selected_dpp
.
lanegroup_size
;
constexpr
index_t
num_wave_c_elems
=
selected_dpp
.
m_per_wave
*
selected_dpp
.
n_per_wave
;
constexpr
index_t
num_dpp_c_elems
=
selected_dpp
.
m_per_lanegroup
*
selected_dpp
.
n_per_lanegroup
;
static_assert
(
num_wave_c_elems
%
num_dpp_c_elems
==
0
);
static_assert
(
num_dpp_per_wave
==
num_wave_c_elems
/
num_dpp_c_elems
);
if
constexpr
(
selected_dpp
.
share_a
)
{
static_assert
(
selected_dpp
.
m_per_lanegroup
==
selected_dpp
.
m_per_thread
);
static_assert
(
selected_dpp
.
n_per_lanegroup
%
selected_dpp
.
n_per_thread
==
0
);
static_assert
(
selected_dpp
.
n_per_lanegroup
/
selected_dpp
.
n_per_thread
==
selected_dpp
.
lanegroup_size
);
}
else
{
static_assert
(
selected_dpp
.
m_per_lanegroup
%
selected_dpp
.
n_per_thread
==
0
);
static_assert
(
selected_dpp
.
m_per_lanegroup
/
selected_dpp
.
n_per_thread
==
selected_dpp
.
lanegroup_size
);
static_assert
(
selected_dpp
.
n_per_lanegroup
==
selected_dpp
.
n_per_thread
);
}
// Below checks come from the restrictions of the current implementation, could be removed
// in the future when the implementation is more generalized.
static_assert
(
selected_dpp
.
share_a
);
static_assert
(
selected_dpp
.
n_per_thread
==
1
);
static_assert
(
selected_dpp
.
m_per_lanegroup
==
selected_dpp
.
m_per_thread
);
static_assert
(
selected_dpp
.
n_per_lanegroup
==
selected_dpp
.
n_per_thread
*
selected_dpp
.
lanegroup_size
);
}
static
constexpr
index_t
GetK1PerDpp
()
{
return
selected_dpp
.
k_per_dpp
;
}
};
template
<
typename
BaseType
,
index_t
MPerDpp
,
index_t
NPerDpp
,
index_t
KPack
>
struct
DppGemm
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
constexpr
auto
I4
=
Number
<
4
>
{};
static
constexpr
auto
I5
=
Number
<
5
>
{};
using
CIndex
=
MultiIndex
<
2
>
;
using
CIndex4D
=
MultiIndex
<
4
>
;
__host__
__device__
constexpr
DppGemm
()
{
static_assert
(
KPack
%
dpp_instr
.
k_per_dpp
==
0
,
"KPack must be divisible by k_per_dpp."
);
}
__device__
static
constexpr
index_t
GetRegSizePerDpp
()
{
return
MPerDpp
*
NPerDpp
/
dpp_instr
.
wave_size
;
}
template
<
class
ADataType
,
class
BDataType
,
class
CDataType
>
__device__
void
Run
(
const
ADataType
&
p_a_wave
,
const
BDataType
&
p_b_wave
,
CDataType
&
p_c_thread
)
const
{
static_assert
(
is_same
<
BaseType
,
double
>::
value
||
is_same
<
BaseType
,
float
>::
value
||
is_same
<
BaseType
,
half_t
>::
value
||
is_same
<
BaseType
,
bhalf_t
>::
value
||
is_same
<
BaseType
,
int8_t
>::
value
||
is_same
<
BaseType
,
f8_t
>::
value
,
"base BaseType must be double, float, half, bfloat16, and int8_t!"
);
static_for
<
0
,
KPack
/
dpp_instr
.
k_per_dpp
,
1
>
{}([
&
](
auto
k
)
{
dpp_instr
.
template
run
<
MPerDpp
,
NPerDpp
>(
p_a_wave
[
k
],
p_b_wave
[
k
],
p_c_thread
);
});
}
__device__
static
auto
GetLaneIdInWave
()
{
return
get_thread_local_1d_id
()
%
dpp_instr
.
wave_size
;
}
__device__
static
auto
GetWaveId
()
{
return
get_thread_local_1d_id
()
/
dpp_instr
.
wave_size
;
}
__device__
static
auto
GetLaneIdInLaneGroup
()
{
return
get_thread_local_1d_id
()
%
dpp_instr
.
lanegroup_size
;
}
__device__
static
auto
GetLaneGroupIdInWave
()
{
return
GetLaneIdInWave
()
/
dpp_instr
.
lanegroup_size
;
}
__device__
static
auto
GetDppOpIdx
()
{
const
auto
lanegroupId
=
GetLaneGroupIdInWave
();
constexpr
auto
lanegroup_idx_1d_to_dpp_idx_2d_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_merge_transform
(
make_tuple
(
dpp_instr
.
m_per_wave
/
dpp_instr
.
m_per_lanegroup
,
dpp_instr
.
n_per_wave
/
dpp_instr
.
n_per_lanegroup
))),
make_tuple
(
Sequence
<
0
,
1
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
const
auto
dpp_idx
=
lanegroup_idx_1d_to_dpp_idx_2d_adaptor
.
CalculateBottomIndex
(
make_multi_index
(
lanegroupId
));
const
auto
m_dpp_idx
=
dpp_idx
[
I0
];
const
auto
n_dpp_idx
=
dpp_idx
[
I1
];
return
make_tuple
(
m_dpp_idx
,
n_dpp_idx
);
}
__host__
__device__
static
auto
CalculateAThreadOriginDataIndex_K_M
()
{
const
auto
laneId
=
get_thread_local_1d_id
();
const
auto
wave_row
=
laneId
/
dpp_instr
.
n_per_wave
;
auto
m_idx
=
dpp_instr
.
m_per_thread
*
wave_row
+
GetLaneIdInLaneGroup
();
return
make_tuple
(
0
,
m_idx
%
dpp_instr
.
m_per_wave
);
}
__host__
__device__
static
auto
CalculateBThreadOriginDataIndex_K_N
()
{
const
auto
laneId
=
get_thread_local_1d_id
();
return
make_tuple
(
0
,
laneId
%
dpp_instr
.
n_per_wave
);
}
__device__
static
CIndex
GetBeginOfThreadBlk
()
{
const
auto
dpp_op_idx
=
GetDppOpIdx
();
const
auto
m_dpp_op_idx
=
dpp_op_idx
[
I0
];
const
auto
n_dpp_op_idx
=
dpp_op_idx
[
I1
];
index_t
n_offset
=
n_dpp_op_idx
*
dpp_instr
.
n_per_lanegroup
+
GetLaneIdInLaneGroup
();
index_t
m_offset
=
m_dpp_op_idx
*
dpp_instr
.
m_per_lanegroup
;
return
CIndex
{
m_offset
,
n_offset
};
}
static
constexpr
auto
dpp
=
DppSelector
<
BaseType
,
MPerDpp
,
NPerDpp
>
{};
static
constexpr
auto
dpp_instr
=
dpp
.
selected_dpp
;
static
constexpr
auto
K0PerDpp
=
1
;
static
constexpr
auto
K1PerDpp
=
dpp
.
GetK1PerDpp
();
__host__
__device__
static
constexpr
auto
GetCMNThreadBlkLengths
()
{
return
make_tuple
(
Number
<
dpp_instr
.
m_per_thread
>
{},
Number
<
dpp_instr
.
n_per_thread
>
{});
}
};
}
// namespace ck
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp
View file @
e1a5137e
...
@@ -456,6 +456,7 @@ struct mfma_type<MfmaInstr::mfma_f64_16x16x4f64>
...
@@ -456,6 +456,7 @@ struct mfma_type<MfmaInstr::mfma_f64_16x16x4f64>
}
}
};
};
#if defined CK_ENABLE_FP8
template
<
>
template
<
>
struct
mfma_type
<
MfmaInstr
::
mfma_f32_32x32x16f8f8
>
struct
mfma_type
<
MfmaInstr
::
mfma_f32_32x32x16f8f8
>
{
{
...
@@ -499,6 +500,7 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x32f8f8>
...
@@ -499,6 +500,7 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x32f8f8>
intrin_mfma_f32_16x16x32f8f8
<
MPerXdlops
,
NPerXdlops
>::
Run
(
a
,
b
,
reg_c
);
intrin_mfma_f32_16x16x32f8f8
<
MPerXdlops
,
NPerXdlops
>::
Run
(
a
,
b
,
reg_c
);
}
}
};
};
#endif
template
<
typename
base_type
,
index_t
MPerXdlops
,
index_t
NPerXdlops
>
template
<
typename
base_type
,
index_t
MPerXdlops
,
index_t
NPerXdlops
>
struct
MfmaSelector
struct
MfmaSelector
...
@@ -640,6 +642,7 @@ struct MfmaSelector
...
@@ -640,6 +642,7 @@ struct MfmaSelector
}
}
#endif
#endif
#if defined CK_ENABLE_FP8
template
<
>
template
<
>
static
constexpr
auto
GetMfma
<
f8_t
,
32
,
32
>
()
static
constexpr
auto
GetMfma
<
f8_t
,
32
,
32
>
()
{
{
...
@@ -651,6 +654,7 @@ struct MfmaSelector
...
@@ -651,6 +654,7 @@ struct MfmaSelector
{
{
return
MfmaInstr
::
mfma_f32_16x16x32f8f8
;
return
MfmaInstr
::
mfma_f32_16x16x32f8f8
;
}
}
#endif
static
constexpr
auto
selected_mfma
=
mfma_type
<
GetMfma
<
base_type
,
MPerXdlops
,
NPerXdlops
>
()
>
{};
static
constexpr
auto
selected_mfma
=
mfma_type
<
GetMfma
<
base_type
,
MPerXdlops
,
NPerXdlops
>
()
>
{};
...
@@ -852,7 +856,11 @@ struct XdlopsGemm
...
@@ -852,7 +856,11 @@ struct XdlopsGemm
{
{
static_assert
(
is_same
<
base_type
,
double
>::
value
||
is_same
<
base_type
,
float
>::
value
||
static_assert
(
is_same
<
base_type
,
double
>::
value
||
is_same
<
base_type
,
float
>::
value
||
is_same
<
base_type
,
half_t
>::
value
||
is_same
<
base_type
,
bhalf_t
>::
value
||
is_same
<
base_type
,
half_t
>::
value
||
is_same
<
base_type
,
bhalf_t
>::
value
||
is_same
<
base_type
,
int8_t
>::
value
||
is_same
<
base_type
,
f8_t
>::
value
,
is_same
<
base_type
,
int8_t
>::
value
#if defined CK_ENABLE_FP8
||
is_same
<
base_type
,
f8_t
>::
value
#endif
,
"base base_type must be double, float, half, bfloat16, and int8_t!"
);
"base base_type must be double, float, half, bfloat16, and int8_t!"
);
static_for
<
0
,
KPack
/
mfma_instr
.
k_per_blk
,
1
>
{}([
&
](
auto
k
)
{
static_for
<
0
,
KPack
/
mfma_instr
.
k_per_blk
,
1
>
{}([
&
](
auto
k
)
{
...
...
include/ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp
View file @
e1a5137e
...
@@ -164,6 +164,7 @@ template <
...
@@ -164,6 +164,7 @@ template <
index_t
BK1
,
index_t
BK1
,
index_t
GemmMPerBlock
,
index_t
GemmMPerBlock
,
index_t
GemmNPerBlock
,
index_t
GemmNPerBlock
,
index_t
GemmKPerBlock
,
bool
DoPadGemmM
,
bool
DoPadGemmM
,
bool
DoPadGemmN
>
bool
DoPadGemmN
>
struct
TransformConvBwdDataToGemm_v1
struct
TransformConvBwdDataToGemm_v1
...
@@ -236,8 +237,6 @@ struct TransformConvBwdDataToGemm_v1
...
@@ -236,8 +237,6 @@ struct TransformConvBwdDataToGemm_v1
const
index_t
ConvDilationH
=
conv_filter_dilations
[
HIdx
-
NonSpatialDimsNum
];
const
index_t
ConvDilationH
=
conv_filter_dilations
[
HIdx
-
NonSpatialDimsNum
];
const
index_t
ConvDilationW
=
conv_filter_dilations
[
WIdx
-
NonSpatialDimsNum
];
const
index_t
ConvDilationW
=
conv_filter_dilations
[
WIdx
-
NonSpatialDimsNum
];
const
index_t
AK0
=
K
/
AK1
;
// n_do_ho_wo_k for 3d or n_ho_wo_k for 2d
// n_do_ho_wo_k for 3d or n_ho_wo_k for 2d
const
auto
out_grid_desc
=
const
auto
out_grid_desc
=
make_out_grid_desc
<
NDimSpatial
,
ALayout
,
ConvBwdDataSpecialization
>
(
make_out_grid_desc
<
NDimSpatial
,
ALayout
,
ConvBwdDataSpecialization
>
(
...
@@ -247,6 +246,8 @@ struct TransformConvBwdDataToGemm_v1
...
@@ -247,6 +246,8 @@ struct TransformConvBwdDataToGemm_v1
ck
::
tensor_operation
::
device
::
ConvolutionBackwardDataSpecialization
::
ck
::
tensor_operation
::
device
::
ConvolutionBackwardDataSpecialization
::
Filter1x1Stride1Pad0
)
Filter1x1Stride1Pad0
)
{
{
const
index_t
AK0
=
math
::
integer_divide_ceil
(
K
,
AK1
);
// A: output tensor
// A: output tensor
const
auto
out_gemmak0_gemmmraw_gemmak1_grid_desc
=
transform_tensor_descriptor
(
const
auto
out_gemmak0_gemmmraw_gemmak1_grid_desc
=
transform_tensor_descriptor
(
out_grid_desc
,
out_grid_desc
,
...
@@ -332,7 +333,7 @@ struct TransformConvBwdDataToGemm_v1
...
@@ -332,7 +333,7 @@ struct TransformConvBwdDataToGemm_v1
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
,
4
>
{},
Sequence
<
5
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
,
4
>
{},
Sequence
<
5
>
{}));
const
auto
out_n_ydotslice_htildeslice_xdotslice_wtildeslice_
ak0_ak1
_grid_desc
=
const
auto
out_n_ydotslice_htildeslice_xdotslice_wtildeslice_
k
_grid_desc
=
transform_tensor_descriptor
(
transform_tensor_descriptor
(
out_n_ydot_htilde_xdot_wtilde_k_grid_desc
,
out_n_ydot_htilde_xdot_wtilde_k_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_tuple
(
make_pass_through_transform
(
N
),
...
@@ -340,7 +341,7 @@ struct TransformConvBwdDataToGemm_v1
...
@@ -340,7 +341,7 @@ struct TransformConvBwdDataToGemm_v1
make_slice_transform
(
HTilde
,
IHTildeSliceBegin
,
HTildeSlice
),
make_slice_transform
(
HTilde
,
IHTildeSliceBegin
,
HTildeSlice
),
make_slice_transform
(
XDot
,
I0
,
XDotSlice
),
make_slice_transform
(
XDot
,
I0
,
XDotSlice
),
make_slice_transform
(
WTilde
,
IWTildeSliceBegin
,
WTildeSlice
),
make_slice_transform
(
WTilde
,
IWTildeSliceBegin
,
WTildeSlice
),
make_
unmerge_transform
(
make_tuple
(
AK0
,
AK1
)
)),
make_
pass_through_transform
(
K
)),
make_tuple
(
Sequence
<
0
>
{},
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
2
>
{},
...
@@ -352,21 +353,30 @@ struct TransformConvBwdDataToGemm_v1
...
@@ -352,21 +353,30 @@ struct TransformConvBwdDataToGemm_v1
Sequence
<
2
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{},
Sequence
<
4
>
{},
Sequence
<
5
,
6
>
{}));
Sequence
<
5
>
{}));
const
auto
out_gemmak0_gemmmraw_gemmak1_grid_desc
=
transform_tensor_descriptor
(
const
auto
out_gemmk_gemmmraw_grid_desc
=
transform_tensor_descriptor
(
out_n_ydotslice_htildeslice_xdotslice_wtildeslice_ak0_ak1_grid_desc
,
out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
YDotSlice
,
XDotSlice
,
AK0
)),
make_tuple
(
make_merge_transform
(
make_tuple
(
YDotSlice
,
XDotSlice
,
K
)),
make_merge_transform
(
make_tuple
(
N
,
HTildeSlice
,
WTildeSlice
)),
make_merge_transform
(
make_tuple
(
N
,
HTildeSlice
,
WTildeSlice
))),
make_pass_through_transform
(
AK1
)),
make_tuple
(
Sequence
<
1
,
3
,
5
>
{},
Sequence
<
0
,
2
,
4
>
{}),
make_tuple
(
Sequence
<
1
,
3
,
5
>
{},
Sequence
<
0
,
2
,
4
>
{},
Sequence
<
6
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
const
auto
out_gemm
ak0
_gemmm_
gemmak1
_grid_desc
=
const
auto
out_gemm
k
_gemmm_
padded
_grid_desc
=
ck
::
tensor_operation
::
device
::
PadTensorDescriptor
(
ck
::
tensor_operation
::
device
::
PadTensorDescriptor
(
out_gemmak0_gemmmraw_gemmak1_grid_desc
,
out_gemmk_gemmmraw_grid_desc
,
make_tuple
(
AK0
,
GemmMPerBlock
,
AK1
),
make_tuple
(
GemmKPerBlock
,
GemmMPerBlock
),
Sequence
<
false
,
DoPadGemmM
,
false
>
{});
Sequence
<
true
,
DoPadGemmM
>
{});
const
index_t
AK0
=
out_gemmk_gemmm_padded_grid_desc
.
GetLength
(
I0
)
/
AK1
;
const
auto
out_gemmak0_gemmm_gemmak1_grid_desc
=
transform_tensor_descriptor
(
out_gemmk_gemmm_padded_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
AK0
,
AK1
)),
make_pass_through_transform
(
out_gemmk_gemmm_padded_grid_desc
.
GetLength
(
I1
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
out_gemmak0_gemmm_gemmak1_grid_desc
;
return
out_gemmak0_gemmm_gemmak1_grid_desc
;
}
}
...
@@ -411,7 +421,7 @@ struct TransformConvBwdDataToGemm_v1
...
@@ -411,7 +421,7 @@ struct TransformConvBwdDataToGemm_v1
Sequence
<
7
>
{}));
Sequence
<
7
>
{}));
const
auto
const
auto
out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_
ak0_ak1
_grid_desc
=
out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_
k
_grid_desc
=
transform_tensor_descriptor
(
transform_tensor_descriptor
(
out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_k_grid_desc
,
out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_k_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_tuple
(
make_pass_through_transform
(
N
),
...
@@ -421,7 +431,7 @@ struct TransformConvBwdDataToGemm_v1
...
@@ -421,7 +431,7 @@ struct TransformConvBwdDataToGemm_v1
make_slice_transform
(
HTilde
,
IHTildeSliceBegin
,
HTildeSlice
),
make_slice_transform
(
HTilde
,
IHTildeSliceBegin
,
HTildeSlice
),
make_slice_transform
(
XDot
,
I0
,
XDotSlice
),
make_slice_transform
(
XDot
,
I0
,
XDotSlice
),
make_slice_transform
(
WTilde
,
IWTildeSliceBegin
,
WTildeSlice
),
make_slice_transform
(
WTilde
,
IWTildeSliceBegin
,
WTildeSlice
),
make_
unmerge_transform
(
make_tuple
(
AK0
,
AK1
)
)),
make_
pass_through_transform
(
K
)),
make_tuple
(
Sequence
<
0
>
{},
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
2
>
{},
...
@@ -437,22 +447,31 @@ struct TransformConvBwdDataToGemm_v1
...
@@ -437,22 +447,31 @@ struct TransformConvBwdDataToGemm_v1
Sequence
<
4
>
{},
Sequence
<
4
>
{},
Sequence
<
5
>
{},
Sequence
<
5
>
{},
Sequence
<
6
>
{},
Sequence
<
6
>
{},
Sequence
<
7
,
8
>
{}));
Sequence
<
7
>
{}));
const
auto
out_gemm
ak0
_gemmmraw_
gemmak1_
grid_desc
=
transform_tensor_descriptor
(
const
auto
out_gemm
k
_gemmmraw_grid_desc
=
transform_tensor_descriptor
(
out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_
ak0_ak1
_grid_desc
,
out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_
k
_grid_desc
,
make_tuple
(
make_tuple
(
make_merge_transform
(
make_tuple
(
ZDotSlice
,
YDotSlice
,
XDotSlice
,
AK0
)),
make_merge_transform
(
make_tuple
(
ZDotSlice
,
YDotSlice
,
XDotSlice
,
K
)),
make_merge_transform
(
make_tuple
(
N
,
DTildeSlice
,
HTildeSlice
,
WTildeSlice
)),
make_merge_transform
(
make_tuple
(
N
,
DTildeSlice
,
HTildeSlice
,
WTildeSlice
))),
make_pass_through_transform
(
AK1
)),
make_tuple
(
Sequence
<
1
,
3
,
5
,
7
>
{},
Sequence
<
0
,
2
,
4
,
6
>
{}),
make_tuple
(
Sequence
<
1
,
3
,
5
,
7
>
{},
Sequence
<
0
,
2
,
4
,
6
>
{},
Sequence
<
8
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
const
auto
out_gemm
ak0
_gemmm_
gemmak1
_grid_desc
=
const
auto
out_gemm
k
_gemmm_
padded
_grid_desc
=
ck
::
tensor_operation
::
device
::
PadTensorDescriptor
(
ck
::
tensor_operation
::
device
::
PadTensorDescriptor
(
out_gemmak0_gemmmraw_gemmak1_grid_desc
,
out_gemmk_gemmmraw_grid_desc
,
make_tuple
(
AK0
,
GemmMPerBlock
,
AK1
),
make_tuple
(
GemmKPerBlock
,
GemmMPerBlock
),
Sequence
<
false
,
DoPadGemmM
,
false
>
{});
Sequence
<
true
,
DoPadGemmM
>
{});
const
index_t
AK0
=
out_gemmk_gemmm_padded_grid_desc
.
GetLength
(
I0
)
/
AK1
;
const
auto
out_gemmak0_gemmm_gemmak1_grid_desc
=
transform_tensor_descriptor
(
out_gemmk_gemmm_padded_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
AK0
,
AK1
)),
make_pass_through_transform
(
out_gemmk_gemmm_padded_grid_desc
.
GetLength
(
I1
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
out_gemmak0_gemmm_gemmak1_grid_desc
;
return
out_gemmak0_gemmm_gemmak1_grid_desc
;
}
}
...
@@ -505,8 +524,6 @@ struct TransformConvBwdDataToGemm_v1
...
@@ -505,8 +524,6 @@ struct TransformConvBwdDataToGemm_v1
const
index_t
ConvDilationH
=
conv_filter_dilations
[
HIdx
-
NonSpatialDimsNum
];
const
index_t
ConvDilationH
=
conv_filter_dilations
[
HIdx
-
NonSpatialDimsNum
];
const
index_t
ConvDilationW
=
conv_filter_dilations
[
WIdx
-
NonSpatialDimsNum
];
const
index_t
ConvDilationW
=
conv_filter_dilations
[
WIdx
-
NonSpatialDimsNum
];
const
index_t
BK0
=
K
/
BK1
;
// assume packed
// assume packed
// k_y_x_c for 2d or k_z_y_x_c for 3d
// k_y_x_c for 2d or k_z_y_x_c for 3d
const
auto
wei_grid_desc
=
make_wei_grid_desc
<
BLayout
>
(
K
,
Z
,
Y
,
X
,
C
);
const
auto
wei_grid_desc
=
make_wei_grid_desc
<
BLayout
>
(
K
,
Z
,
Y
,
X
,
C
);
...
@@ -515,6 +532,8 @@ struct TransformConvBwdDataToGemm_v1
...
@@ -515,6 +532,8 @@ struct TransformConvBwdDataToGemm_v1
ck
::
tensor_operation
::
device
::
ConvolutionBackwardDataSpecialization
::
ck
::
tensor_operation
::
device
::
ConvolutionBackwardDataSpecialization
::
Filter1x1Stride1Pad0
)
Filter1x1Stride1Pad0
)
{
{
const
index_t
BK0
=
math
::
integer_divide_ceil
(
K
,
BK1
);
// B: weight tensor
// B: weight tensor
const
auto
wei_gemmbk0_gemmnraw_gemmbk1_grid_desc
=
const
auto
wei_gemmbk0_gemmnraw_gemmbk1_grid_desc
=
transform_tensor_descriptor
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
K
,
C
)),
transform_tensor_descriptor
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
K
,
C
)),
...
@@ -566,43 +585,49 @@ struct TransformConvBwdDataToGemm_v1
...
@@ -566,43 +585,49 @@ struct TransformConvBwdDataToGemm_v1
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
,
4
>
{},
Sequence
<
5
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
,
4
>
{},
Sequence
<
5
>
{}));
const
auto
wei_bk0_bk1_ydotslice_xdotslice_c_grid_desc
=
const
auto
wei_k_ydotslice_xdotslice_c_grid_desc
=
transform_tensor_descriptor
(
transform_tensor_descriptor
(
wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc
,
wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
K
),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
BK0
,
BK1
)),
make_slice_transform
(
YDot
,
I0
,
YDotSlice
),
make_slice_transform
(
YDot
,
I0
,
YDotSlice
),
make_slice_transform
(
XDot
,
I0
,
XDotSlice
),
make_slice_transform
(
XDot
,
I0
,
XDotSlice
),
make_freeze_transform
(
i_ytilde
),
make_freeze_transform
(
i_ytilde
),
make_freeze_transform
(
i_xtilde
),
make_freeze_transform
(
i_xtilde
),
make_pass_through_transform
(
C
)),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
1
>
{},
Sequence
<
3
>
{},
Sequence
<
3
>
{},
Sequence
<
2
>
{},
Sequence
<
2
>
{},
Sequence
<
4
>
{},
Sequence
<
4
>
{},
Sequence
<
5
>
{}),
Sequence
<
5
>
{}),
make_tuple
(
Sequence
<
0
>
{},
make_tuple
(
Sequence
<
0
,
1
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<>
{},
Sequence
<>
{},
Sequence
<>
{},
Sequence
<>
{},
Sequence
<
3
>
{}));
Sequence
<
4
>
{}));
const
auto
wei_gemmbk0_gemmnraw_gemmbk1_grid_desc
=
transform_tensor_descriptor
(
const
auto
wei_gemmk_gemmnraw_grid_desc
=
transform_tensor_descriptor
(
wei_bk0_bk1_ydotslice_xdotslice_c_grid_desc
,
wei_k_ydotslice_xdotslice_c_grid_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
YDotSlice
,
XDotSlice
,
BK0
)),
make_tuple
(
make_merge_transform
(
make_tuple
(
YDotSlice
,
XDotSlice
,
K
)),
make_pass_through_transform
(
C
),
make_pass_through_transform
(
C
)),
make_pass_through_transform
(
BK1
)),
make_tuple
(
Sequence
<
1
,
2
,
0
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
2
,
3
,
0
>
{},
Sequence
<
4
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
const
auto
wei_gemm
bk0
_gemmn_
gemmbk1
_grid_desc
=
const
auto
wei_gemm
k
_gemmn_
padded
_grid_desc
=
ck
::
tensor_operation
::
device
::
PadTensorDescriptor
(
ck
::
tensor_operation
::
device
::
PadTensorDescriptor
(
wei_gemmbk0_gemmnraw_gemmbk1_grid_desc
,
wei_gemmk_gemmnraw_grid_desc
,
make_tuple
(
wei_gemmbk0_gemmnraw_gemmbk1_grid_desc
.
GetLength
(
I0
),
make_tuple
(
GemmKPerBlock
,
GemmNPerBlock
),
GemmNPerBlock
,
Sequence
<
true
,
DoPadGemmN
>
{});
BK1
),
Sequence
<
false
,
DoPadGemmN
,
false
>
{});
const
index_t
BK0
=
wei_gemmk_gemmn_padded_grid_desc
.
GetLength
(
I0
)
/
BK1
;
const
auto
wei_gemmbk0_gemmn_gemmbk1_grid_desc
=
transform_tensor_descriptor
(
wei_gemmk_gemmn_padded_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
BK0
,
BK1
)),
make_pass_through_transform
(
wei_gemmk_gemmn_padded_grid_desc
.
GetLength
(
I1
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
wei_gemmbk0_gemmn_gemmbk1_grid_desc
;
return
wei_gemmbk0_gemmn_gemmbk1_grid_desc
;
}
}
...
@@ -631,10 +656,10 @@ struct TransformConvBwdDataToGemm_v1
...
@@ -631,10 +656,10 @@ struct TransformConvBwdDataToGemm_v1
Sequence
<
5
,
6
>
{},
Sequence
<
5
,
6
>
{},
Sequence
<
7
>
{}));
Sequence
<
7
>
{}));
const
auto
wei_
bk0_bk1
_zdotslice_ydotslice_xdotslice_c_grid_desc
=
const
auto
wei_
gemmk
_zdotslice_ydotslice_xdotslice_c_grid_desc
=
transform_tensor_descriptor
(
transform_tensor_descriptor
(
wei_k_zdot_ztilde_ydot_ytilde_xdot_xtilde_c_grid_desc
,
wei_k_zdot_ztilde_ydot_ytilde_xdot_xtilde_c_grid_desc
,
make_tuple
(
make_
unmerge_transform
(
make_tuple
(
BK0
,
BK1
)
),
make_tuple
(
make_
pass_through_transform
(
K
),
make_slice_transform
(
ZDot
,
I0
,
ZDotSlice
),
make_slice_transform
(
ZDot
,
I0
,
ZDotSlice
),
make_slice_transform
(
YDot
,
I0
,
YDotSlice
),
make_slice_transform
(
YDot
,
I0
,
YDotSlice
),
make_slice_transform
(
XDot
,
I0
,
XDotSlice
),
make_slice_transform
(
XDot
,
I0
,
XDotSlice
),
...
@@ -650,33 +675,39 @@ struct TransformConvBwdDataToGemm_v1
...
@@ -650,33 +675,39 @@ struct TransformConvBwdDataToGemm_v1
Sequence
<
4
>
{},
Sequence
<
4
>
{},
Sequence
<
6
>
{},
Sequence
<
6
>
{},
Sequence
<
7
>
{}),
Sequence
<
7
>
{}),
make_tuple
(
Sequence
<
0
,
1
>
{},
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{},
Sequence
<>
{},
Sequence
<>
{},
Sequence
<>
{},
Sequence
<>
{},
Sequence
<>
{},
Sequence
<>
{},
Sequence
<
5
>
{}));
Sequence
<
4
>
{}));
const
auto
wei_gemmbk0_gemmnraw_gemmbk1_grid_desc
=
transform_tensor_descriptor
(
const
auto
wei_gemmk_gemmnraw_grid_desc
=
transform_tensor_descriptor
(
wei_bk0_bk1_zdotslice_ydotslice_xdotslice_c_grid_desc
,
wei_gemmk_zdotslice_ydotslice_xdotslice_c_grid_desc
,
make_tuple
(
make_tuple
(
make_merge_transform
(
make_tuple
(
ZDotSlice
,
YDotSlice
,
XDotSlice
,
K
)),
make_merge_transform
(
make_tuple
(
ZDotSlice
,
YDotSlice
,
XDotSlice
,
BK0
)),
make_pass_through_transform
(
C
)),
make_pass_through_transform
(
C
),
make_tuple
(
Sequence
<
1
,
2
,
3
,
0
>
{},
Sequence
<
4
>
{}),
make_pass_through_transform
(
BK1
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
make_tuple
(
Sequence
<
2
,
3
,
4
,
0
>
{},
Sequence
<
5
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
const
auto
wei_gemm
bk0
_gemmn_
gemmbk1
_grid_desc
=
const
auto
wei_gemm
k
_gemmn_
padded
_grid_desc
=
ck
::
tensor_operation
::
device
::
PadTensorDescriptor
(
ck
::
tensor_operation
::
device
::
PadTensorDescriptor
(
wei_gemmbk0_gemmnraw_gemmbk1_grid_desc
,
wei_gemmk_gemmnraw_grid_desc
,
make_tuple
(
wei_gemmbk0_gemmnraw_gemmbk1_grid_desc
.
GetLength
(
I0
),
make_tuple
(
GemmKPerBlock
,
GemmNPerBlock
),
GemmNPerBlock
,
Sequence
<
true
,
DoPadGemmN
>
{});
BK1
),
Sequence
<
false
,
DoPadGemmN
,
false
>
{});
return
wei_gemmbk0_gemmn_gemmbk1_grid_desc
;
const
index_t
BK0
=
wei_gemmk_gemmn_padded_grid_desc
.
GetLength
(
I0
)
/
BK1
;
const
auto
wei_gemmbk0_gemm_gemmbk1_grid_desc
=
transform_tensor_descriptor
(
wei_gemmk_gemmn_padded_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
BK0
,
BK1
)),
make_pass_through_transform
(
wei_gemmk_gemmn_padded_grid_desc
.
GetLength
(
I1
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
wei_gemmbk0_gemm_gemmbk1_grid_desc
;
}
}
else
else
{
{
...
...
include/ck/utility/amd_buffer_addressing.hpp
View file @
e1a5137e
...
@@ -1127,7 +1127,7 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
...
@@ -1127,7 +1127,7 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
uint32_t
src_addr_shift
=
src_thread_element_valid
?
0
:
0x80000000
;
uint32_t
src_addr_shift
=
src_thread_element_valid
?
0
:
0x80000000
;
#if defined CK_ENABLE_FP8
if
constexpr
(
is_same
<
scalar_t
,
f8_t
>::
value
)
if
constexpr
(
is_same
<
scalar_t
,
f8_t
>::
value
)
{
{
auto
tmp
=
amd_buffer_load_impl
<
int8_t
,
vector_size
,
coherence
>
(
auto
tmp
=
amd_buffer_load_impl
<
int8_t
,
vector_size
,
coherence
>
(
...
@@ -1136,10 +1136,14 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
...
@@ -1136,10 +1136,14 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
}
}
else
else
{
{
#endif
return
amd_buffer_load_impl
<
scalar_t
,
vector_size
,
coherence
>
(
return
amd_buffer_load_impl
<
scalar_t
,
vector_size
,
coherence
>
(
src_wave_buffer_resource
,
src_addr_shift
+
src_thread_addr_offset
,
0
);
src_wave_buffer_resource
,
src_addr_shift
+
src_thread_addr_offset
,
0
);
#if defined CK_ENABLE_FP8
}
}
#endif
#else
#else
#if defined CK_ENABLE_FP8
if
constexpr
(
is_same
<
scalar_t
,
f8_t
>::
value
)
if
constexpr
(
is_same
<
scalar_t
,
f8_t
>::
value
)
{
{
auto
tmp
=
amd_buffer_load_impl
<
int8_t
,
vector_size
,
coherence
>
(
auto
tmp
=
amd_buffer_load_impl
<
int8_t
,
vector_size
,
coherence
>
(
...
@@ -1148,11 +1152,14 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
...
@@ -1148,11 +1152,14 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
}
}
else
else
{
{
#endif
vector_t
tmp
=
amd_buffer_load_impl
<
scalar_t
,
vector_size
,
coherence
>
(
vector_t
tmp
=
amd_buffer_load_impl
<
scalar_t
,
vector_size
,
coherence
>
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
0
);
src_wave_buffer_resource
,
src_thread_addr_offset
,
0
);
return
src_thread_element_valid
?
tmp
:
vector_t
(
0
);
return
src_thread_element_valid
?
tmp
:
vector_t
(
0
);
#if defined CK_ENABLE_FP8
}
}
#endif
#endif
#endif
}
}
// buffer_load requires:
// buffer_load requires:
...
@@ -1209,7 +1216,7 @@ __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::t
...
@@ -1209,7 +1216,7 @@ __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::t
#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
uint32_t
dst_addr_shift
=
dst_thread_element_valid
?
0
:
0x80000000
;
uint32_t
dst_addr_shift
=
dst_thread_element_valid
?
0
:
0x80000000
;
#if defined CK_ENABLE_FP8
if
constexpr
(
is_same
<
scalar_t
,
f8_t
>::
value
)
if
constexpr
(
is_same
<
scalar_t
,
f8_t
>::
value
)
{
{
auto
tmp
=
auto
tmp
=
...
@@ -1219,12 +1226,16 @@ __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::t
...
@@ -1219,12 +1226,16 @@ __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::t
}
}
else
else
{
{
#endif
amd_buffer_store_impl
<
scalar_t
,
vector_size
,
coherence
>
(
amd_buffer_store_impl
<
scalar_t
,
vector_size
,
coherence
>
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_addr_shift
+
dst_thread_addr_offset
,
0
);
src_thread_data
,
dst_wave_buffer_resource
,
dst_addr_shift
+
dst_thread_addr_offset
,
0
);
#if defined CK_ENABLE_FP8
}
}
#endif
#else
#else
if
(
dst_thread_element_valid
)
if
(
dst_thread_element_valid
)
{
{
#if defined CK_ENABLE_FP8
if
constexpr
(
is_same
<
scalar_t
,
f8_t
>::
value
)
if
constexpr
(
is_same
<
scalar_t
,
f8_t
>::
value
)
{
{
auto
tmp
=
bit_cast
<
typename
vector_type_maker
<
int8_t
,
vector_size
>::
type
::
type
>
(
auto
tmp
=
bit_cast
<
typename
vector_type_maker
<
int8_t
,
vector_size
>::
type
::
type
>
(
...
@@ -1234,9 +1245,12 @@ __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::t
...
@@ -1234,9 +1245,12 @@ __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::t
}
}
else
else
{
{
#endif
amd_buffer_store_impl
<
scalar_t
,
vector_size
,
coherence
>
(
amd_buffer_store_impl
<
scalar_t
,
vector_size
,
coherence
>
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
0
);
src_thread_data
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
0
);
#if defined CK_ENABLE_FP8
}
}
#endif
}
}
#endif
#endif
}
}
...
...
include/ck/utility/amd_gemm_dpp.hpp
View file @
e1a5137e
...
@@ -5,17 +5,63 @@
...
@@ -5,17 +5,63 @@
#include "ck/utility/common_header.hpp"
#include "ck/utility/common_header.hpp"
#include "ck/utility/math.hpp"
#include "ck/utility/math.hpp"
#include "ck/utility/
amd_gemm
_dpp.hpp"
#include "ck/utility/
inner_product
_dpp
8
.hpp"
namespace
ck
{
namespace
ck
{
namespace
dpp8
{
namespace
dpp8
{
/// Number of lanes that can share data using DPP8 modifiers.
template
<
class
ABDataType
>
constexpr
index_t
lane_group_size
=
8
;
struct
dpp_datatypes
;
__device__
index_t
get_lane_group_local_idx
()
{
return
threadIdx
.
x
/
lane_group_size
;
}
template
<
>
__device__
index_t
get_thread_idx_in_lane_group
()
{
return
threadIdx
.
x
%
lane_group_size
;
}
struct
dpp_datatypes
<
half_t
>
{
// Dot product of `half2_t` and `half2_t` to get `float`. Reducing 2 elements from K in a
// single instruction.
using
a_dtype
=
half_t
;
using
b_dtype
=
half_t
;
using
c_dtype
=
float
;
static
constexpr
index_t
k_per_instr
=
2
;
};
template
<
index_t
MPerThread
,
index_t
NPerThread
,
index_t
KPerThread
,
class
BaseInputType
,
class
AVecDataType
,
class
BVecDataType
,
class
CVecDataType
,
bool
ShareA
>
struct
DppLanegroupGemm
{
using
datatypes_conf
=
dpp_datatypes
<
BaseInputType
>
;
using
ADataType
=
typename
datatypes_conf
::
a_dtype
;
using
BDataType
=
typename
datatypes_conf
::
b_dtype
;
using
CDataType
=
typename
datatypes_conf
::
c_dtype
;
__device__
void
Run
(
const
AVecDataType
&
a_vec
,
const
BVecDataType
&
b_vec
,
CVecDataType
&
c_vec
)
{
constexpr
index_t
num_c_elems_per_thread
=
ShareA
?
MPerThread
:
NPerThread
;
const
vector_type
<
ADataType
,
KPerThread
>
a_vector
{
a_vec
};
const
vector_type
<
BDataType
,
KPerThread
>
b_vector
{
b_vec
};
static_for
<
0
,
num_c_elems_per_thread
,
1
>
{}([
&
](
auto
c_idx
)
{
float
c
=
c_vec
.
template
AsType
<
CDataType
>()(
c_idx
);
// Next `c_idx` implies that we need to pull data from the next lane.
constexpr
index_t
source_lane
=
c_idx
;
static_for
<
0
,
KPerThread
/
datatypes_conf
::
k_per_instr
,
1
>
{}([
&
](
auto
k_chunk
)
{
const
auto
a_k_vec
=
a_vector
.
template
AsType
<
AVecDataType
>()[
k_chunk
];
const
auto
b_k_vec
=
b_vector
.
template
AsType
<
BVecDataType
>()[
k_chunk
];
ck
::
dpp8
::
inner_product_dpp
<
AVecDataType
,
BVecDataType
,
CDataType
,
source_lane
,
ShareA
>
(
a_k_vec
,
b_k_vec
,
c
);
});
c_vec
.
template
AsType
<
CDataType
>()(
c_idx
)
=
c
;
});
}
};
}
// namespace dpp8
}
// namespace dpp8
...
...
include/ck/utility/amd_xdlops.hpp
View file @
e1a5137e
...
@@ -355,6 +355,7 @@ struct intrin_mfma_f64_16x16x4f64<16, 16>
...
@@ -355,6 +355,7 @@ struct intrin_mfma_f64_16x16x4f64<16, 16>
}
}
};
};
#if defined CK_ENABLE_FP8
template
<
index_t
MPerWave
,
index_t
NPerWave
>
template
<
index_t
MPerWave
,
index_t
NPerWave
>
struct
intrin_mfma_f32_32x32x16f8f8
;
struct
intrin_mfma_f32_32x32x16f8f8
;
...
@@ -417,5 +418,6 @@ struct intrin_mfma_f32_16x16x32f8f8<16, 16>
...
@@ -417,5 +418,6 @@ struct intrin_mfma_f32_16x16x32f8f8<16, 16>
#endif
#endif
}
}
};
};
#endif
}
// namespace ck
}
// namespace ck
#endif
#endif
include/ck/utility/data_type.hpp
View file @
e1a5137e
...
@@ -12,7 +12,12 @@ using half_t = _Float16;
...
@@ -12,7 +12,12 @@ using half_t = _Float16;
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
using
int4_t
=
_BitInt
(
4
);
using
int4_t
=
_BitInt
(
4
);
#endif
#endif
using
f8_t
=
uint8_t
;
#if defined CK_ENABLE_FP8
using
f8_t
=
_BitInt
(
8
);
#endif
#if defined CK_ENABLE_BF8
using
bf8_t
=
unsigned
_BitInt
(
8
);
#endif
// vector_type
// vector_type
template
<
typename
T
,
index_t
N
>
template
<
typename
T
,
index_t
N
>
...
@@ -143,14 +148,24 @@ struct scalar_type<int4_t>
...
@@ -143,14 +148,24 @@ struct scalar_type<int4_t>
};
};
#endif
#endif
#if defined CK_ENABLE_FP8
template
<
>
template
<
>
struct
scalar_type
<
f8_t
>
struct
scalar_type
<
f8_t
>
{
{
using
type
=
f8_t
;
using
type
=
f8_t
;
static
constexpr
index_t
vector_size
=
1
;
static
constexpr
index_t
vector_size
=
1
;
};
};
#endif
#if defined CK_ENABLE_BF8
template
<
>
struct
scalar_type
<
bf8_t
>
{
using
type
=
bf8_t
;
static
constexpr
index_t
vector_size
=
1
;
};
#endif
//
template
<
typename
T
>
template
<
typename
T
>
struct
vector_type
<
T
,
1
>
struct
vector_type
<
T
,
1
>
{
{
...
@@ -953,12 +968,24 @@ using int8x32_t = typename vector_type<int8_t, 32>::type;
...
@@ -953,12 +968,24 @@ using int8x32_t = typename vector_type<int8_t, 32>::type;
using
int8x64_t
=
typename
vector_type
<
int8_t
,
64
>::
type
;
using
int8x64_t
=
typename
vector_type
<
int8_t
,
64
>::
type
;
// f8
// f8
#if defined CK_ENABLE_FP8
using
f8x2_t
=
typename
vector_type
<
f8_t
,
2
>::
type
;
using
f8x2_t
=
typename
vector_type
<
f8_t
,
2
>::
type
;
using
f8x4_t
=
typename
vector_type
<
f8_t
,
4
>::
type
;
using
f8x4_t
=
typename
vector_type
<
f8_t
,
4
>::
type
;
using
f8x8_t
=
typename
vector_type
<
f8_t
,
8
>::
type
;
using
f8x8_t
=
typename
vector_type
<
f8_t
,
8
>::
type
;
using
f8x16_t
=
typename
vector_type
<
f8_t
,
16
>::
type
;
using
f8x16_t
=
typename
vector_type
<
f8_t
,
16
>::
type
;
using
f8x32_t
=
typename
vector_type
<
f8_t
,
32
>::
type
;
using
f8x32_t
=
typename
vector_type
<
f8_t
,
32
>::
type
;
using
f8x64_t
=
typename
vector_type
<
f8_t
,
64
>::
type
;
using
f8x64_t
=
typename
vector_type
<
f8_t
,
64
>::
type
;
#endif
// bf8
#if defined CK_ENABLE_BF8
using
bf8x2_t
=
typename
vector_type
<
bf8_t
,
2
>::
type
;
using
bf8x4_t
=
typename
vector_type
<
bf8_t
,
4
>::
type
;
using
bf8x8_t
=
typename
vector_type
<
bf8_t
,
8
>::
type
;
using
bf8x16_t
=
typename
vector_type
<
bf8_t
,
16
>::
type
;
using
bf8x32_t
=
typename
vector_type
<
bf8_t
,
32
>::
type
;
using
bf8x64_t
=
typename
vector_type
<
bf8_t
,
64
>::
type
;
#endif
template
<
typename
T
>
template
<
typename
T
>
struct
NumericLimits
struct
NumericLimits
...
@@ -1006,21 +1033,109 @@ struct NumericLimits<int4_t>
...
@@ -1006,21 +1033,109 @@ struct NumericLimits<int4_t>
};
};
#endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
#endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
#if defined CK_ENABLE_FP8
template
<
>
template
<
>
struct
NumericLimits
<
f8_t
>
struct
NumericLimits
<
f8_t
>
{
{
// negative zero nan mode with exp bias = 8
static
constexpr
uint8_t
binary_min
=
0x08
;
// 0b00001000
static
constexpr
uint8_t
binary_min
=
0x08
;
// 0b00001000
static
constexpr
uint8_t
binary_max
=
0x77
;
// 0b01110111
static
constexpr
uint8_t
binary_max
=
0x7F
;
// 0b01111111
static
constexpr
uint8_t
binary_lowest
=
0xF7
;
// 0b11110111
static
constexpr
uint8_t
binary_lowest
=
0xFF
;
// 0b11111111
static
constexpr
uint8_t
binary_qnan
=
0x80
;
// 0b10000000
// ieee mode with exp bias = 7
// static constexpr uint8_t binary_min = 0x08; // 0b00001000
// static constexpr uint8_t binary_max = 0x77; // 0b01110111
// static constexpr uint8_t binary_lowest = 0xF7; // 0b11110111
// static constexpr uint8_t binary_qnan = 0x79; // any sign, exp=1111, mant!=0
__host__
__device__
static
constexpr
f8_t
Min
()
{
return
f8_t
(
binary_min
);
}
__host__
__device__
static
constexpr
f8_t
Max
()
{
return
f8_t
(
binary_max
);
}
__host__
__device__
static
constexpr
f8_t
Lowest
()
{
return
f8_t
(
binary_lowest
);
}
__host__
__device__
static
constexpr
f8_t
QuietNaN
()
{
return
f8_t
(
binary_qnan
);
}
};
#endif
#if defined CK_ENABLE_BF8
template
<
>
struct
NumericLimits
<
bf8_t
>
{
// negative zero nan mode with exp bias = 16
static
constexpr
uint8_t
binary_min
=
0x04
;
// 0b00000100
static
constexpr
uint8_t
binary_max
=
0x7F
;
// 0b01111111
static
constexpr
uint8_t
binary_lowest
=
0xFF
;
// 0b11111111
static
constexpr
uint8_t
binary_qnan
=
0x80
;
// 0b10000000
static
constexpr
uint8_t
binary_qnan
=
0x80
;
// 0b10000000
// ieee mode with exp bias = 15
// static constexpr uint8_t binary_min = 0x04; // 0b00000100
// static constexpr uint8_t binary_max = 0x7B; // 0b01111011
// static constexpr uint8_t binary_lowest = 0xFB; // 0b11111011
// static constexpr uint8_t binary_qnan = 0x79; // any sign, exp=1111, mant!=
__host__
__device__
static
constexpr
f8_t
Min
()
{
return
b
it_cast
<
f8_t
>
(
binary_min
);
}
__host__
__device__
static
constexpr
b
f8_t
Min
()
{
return
bf8_t
(
binary_min
);
}
__host__
__device__
static
constexpr
f8_t
Max
()
{
return
b
it_cast
<
f8_t
>
(
binary_max
);
}
__host__
__device__
static
constexpr
b
f8_t
Max
()
{
return
bf8_t
(
binary_max
);
}
__host__
__device__
static
constexpr
f8_t
Lowest
()
{
return
b
it_cast
<
f8_t
>
(
binary_lowest
);
}
__host__
__device__
static
constexpr
b
f8_t
Lowest
()
{
return
bf8_t
(
binary_lowest
);
}
__host__
__device__
static
constexpr
f8_t
QuietNaN
()
{
return
b
it_cast
<
f8_t
>
(
binary_qnan
);
}
__host__
__device__
static
constexpr
b
f8_t
QuietNaN
()
{
return
bf8_t
(
binary_qnan
);
}
};
};
#endif
template
<
typename
T
>
struct
NumericUtils
{
};
template
<
>
struct
NumericUtils
<
float
>
{
static
constexpr
int
exp
=
8
;
static
constexpr
int
mant
=
23
;
static
constexpr
uint32_t
nan_mask
=
0x7F800000
;
static
constexpr
uint32_t
head_mask
=
0xFF800000
;
static
constexpr
uint32_t
mant_mask
=
0x7FFFFF
;
static
constexpr
uint32_t
exp_mask
=
0xFF
;
static
constexpr
uint32_t
Inf
=
0x7F800000
;
static
constexpr
uint32_t
NegInf
=
0xFF800000
;
static
constexpr
uint32_t
NaN
=
0x7F800001
;
static
constexpr
uint32_t
Neg0
=
0x80000000
;
using
bitwise_type
=
uint32_t
;
};
template
<
>
struct
NumericUtils
<
half_t
>
{
static
constexpr
int
exp
=
5
;
static
constexpr
int
mant
=
10
;
static
constexpr
uint16_t
nan_mask
=
0x7C00
;
static
constexpr
uint16_t
head_mask
=
0xFC00
;
static
constexpr
uint16_t
mant_mask
=
0x3FF
;
static
constexpr
uint16_t
exp_mask
=
0x1F
;
static
constexpr
uint32_t
Inf
=
0x7C00
;
static
constexpr
uint32_t
NegInf
=
0xFC00
;
static
constexpr
uint32_t
NaN
=
0x7C01
;
static
constexpr
uint32_t
Neg0
=
0x8000
;
using
bitwise_type
=
uint16_t
;
};
#if defined CK_ENABLE_FP8
template
<
>
struct
NumericUtils
<
f8_t
>
{
static
constexpr
int
exp
=
4
;
static
constexpr
int
mant
=
3
;
};
#endif
#if defined CK_ENABLE_BF8
template
<
>
struct
NumericUtils
<
bf8_t
>
{
static
constexpr
int
exp
=
5
;
static
constexpr
int
mant
=
2
;
};
#endif
}
// namespace ck
}
// namespace ck
include/ck/utility/f8_utils.hpp
View file @
e1a5137e
...
@@ -5,6 +5,9 @@
...
@@ -5,6 +5,9 @@
#include "ck/utility/data_type.hpp"
#include "ck/utility/data_type.hpp"
// these conversions are disabled if native conversions available
#if !defined(__gfx940__) && !defined(__gfx941__) && !defined(__gfx942__)
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
namespace
ck
{
namespace
ck
{
// fp8 rounding modes
// fp8 rounding modes
...
@@ -22,53 +25,38 @@ namespace ck::utils {
...
@@ -22,53 +25,38 @@ namespace ck::utils {
namespace
{
namespace
{
template
<
typename
T
,
bool
negative_zero_nan
,
bool
clip
,
bool
stoch
>
template
<
typename
X
,
typename
Y
,
bool
negative_zero_nan
,
bool
clip
,
bool
stoch
>
__host__
__device__
f8_t
run_cast_to_f8
(
T
x
,
uint32_t
rng
)
__host__
__device__
Y
run_cast_to_f8
(
X
x
,
uint32_t
rng
)
{
{
//
check data type
//
fp8/bf8 exponent/mantissa layout
constexpr
bool
is_half
=
std
::
is_same
<
T
,
half_t
>::
value
;
constexpr
int
out_exp
=
NumericUtils
<
Y
>::
exp
;
constexpr
bool
is_float
=
std
::
is_same
<
T
,
float
>::
value
;
constexpr
int
out_mant
=
NumericUtils
<
Y
>::
mant
;
// fp8 exponent/mantissa layout
// original type exponent/mantissa layout
constexpr
int
f8_exp
=
4
;
constexpr
int
in_exp
=
NumericUtils
<
X
>::
exp
;
constexpr
int
f8_mant
=
3
;
constexpr
int
in_mant
=
NumericUtils
<
X
>::
mant
;
// resulting type exponent/mantissa layout
constexpr
int
type_exp
=
is_half
?
5
:
8
;
constexpr
int
type_mant
=
is_half
?
10
:
23
;
int
exponent
;
int
exponent
;
uint32_t
head
,
mantissa
,
sign
;
uint32_t
head
,
mantissa
,
sign
;
// nan code is same for float and half
// nan code is same for float and half
constexpr
uint8_t
nan_code
=
0x80
;
constexpr
Y
nan_code
=
0x80
;
constexpr
uint32_t
nan_mask
=
is_half
?
0x7C00
:
0x7F800000
;
constexpr
uint32_t
nan_mask
=
NumericUtils
<
X
>::
nan_mask
;
// convert to bitwise
// convert to bitwise
typedef
typename
std
::
conditional
<
std
::
is_same
<
T
,
half_t
>::
value
,
uint16_t
,
uint32_t
>::
type
using
T_bitwise
=
typename
NumericUtils
<
X
>::
bitwise_type
;
T_bitwise
;
T_bitwise
x_bitwise
=
*
(
reinterpret_cast
<
T_bitwise
*>
(
&
x
));
T_bitwise
x_bitwise
=
*
(
reinterpret_cast
<
T_bitwise
*>
(
&
x
));
// unpack the input, depends on datatype
// unpack the input, depends on datatype
if
constexpr
(
is_float
)
head
=
x_bitwise
&
NumericUtils
<
X
>::
head_mask
;
{
mantissa
=
x_bitwise
&
NumericUtils
<
X
>::
mant_mask
;
head
=
x_bitwise
&
0xFF800000
;
exponent
=
(
head
>>
in_mant
)
&
NumericUtils
<
X
>::
exp_mask
;
mantissa
=
x_bitwise
&
0x7FFFFF
;
sign
=
head
>>
(
in_exp
+
in_mant
);
exponent
=
(
head
>>
type_mant
)
&
0xFF
;
sign
=
head
>>
(
type_exp
+
type_mant
);
uint32_t
signed_inf
=
(
sign
<<
(
in_exp
+
in_mant
))
+
(((
1
<<
in_exp
)
-
1
)
<<
in_mant
);
}
uint32_t
drop_mask
=
(
1
<<
(
in_mant
-
out_mant
))
-
1
;
else
if
constexpr
(
is_half
)
constexpr
int
max_exp
=
(
1
<<
out_exp
)
-
(
negative_zero_nan
?
1
:
2
);
{
head
=
x_bitwise
&
0xFC00
;
mantissa
=
x_bitwise
&
0x3FF
;
exponent
=
(
head
>>
type_mant
)
&
0x1F
;
sign
=
head
>>
(
type_exp
+
type_mant
);
}
uint32_t
signed_inf
=
(
sign
<<
(
type_exp
+
type_mant
))
+
(((
1
<<
type_exp
)
-
1
)
<<
type_mant
);
uint32_t
drop_mask
=
(
1
<<
(
type_mant
-
f8_mant
))
-
1
;
constexpr
int
max_exp
=
(
1
<<
f8_exp
)
-
(
negative_zero_nan
?
1
:
2
);
constexpr
int
exp_low_cutoff
=
constexpr
int
exp_low_cutoff
=
(
1
<<
(
type
_exp
-
1
))
-
(
1
<<
(
f8
_exp
-
1
))
+
1
-
(
negative_zero_nan
?
1
:
0
);
(
1
<<
(
in
_exp
-
1
))
-
(
1
<<
(
out
_exp
-
1
))
+
1
-
(
negative_zero_nan
?
1
:
0
);
if
constexpr
(
negative_zero_nan
)
if
constexpr
(
negative_zero_nan
)
{
{
...
@@ -81,22 +69,35 @@ __host__ __device__ f8_t run_cast_to_f8(T x, uint32_t rng)
...
@@ -81,22 +69,35 @@ __host__ __device__ f8_t run_cast_to_f8(T x, uint32_t rng)
return
signed_inf
+
(
mantissa
!=
0
?
1
:
0
);
return
signed_inf
+
(
mantissa
!=
0
?
1
:
0
);
}
}
// if input is half and output is bf8
if
((
NumericUtils
<
X
>::
mant
==
10
)
&&
(
NumericUtils
<
Y
>::
mant
==
2
)
&&
negative_zero_nan
&&
exponent
==
0
)
{
exponent
+=
1
;
while
(
mantissa
<
(
1
<<
in_mant
))
{
mantissa
<<=
1
;
exponent
-=
1
;
}
mantissa
&=
~
(
1
<<
in_mant
);
}
// check if x is 0.0
// check if x is 0.0
if
(
x_bitwise
==
0
)
if
(
x_bitwise
==
0
)
return
0
;
return
0
;
exponent
-=
exp_low_cutoff
-
1
;
exponent
-=
exp_low_cutoff
-
1
;
if
(
exponent
<=
0
)
if
(
exponent
<=
0
)
drop_mask
=
(
1
<<
(
type
_mant
-
f8
_mant
+
1
-
exponent
))
-
1
;
drop_mask
=
(
1
<<
(
in
_mant
-
out
_mant
+
1
-
exponent
))
-
1
;
mantissa
+=
1
<<
type
_mant
;
mantissa
+=
1
<<
in
_mant
;
// apply random number if needed
// apply random number if needed
mantissa
+=
(
stoch
?
rng
:
mantissa
)
&
drop_mask
;
mantissa
+=
(
stoch
?
rng
:
mantissa
)
&
drop_mask
;
if
(
mantissa
>=
(
2
<<
type
_mant
))
if
(
mantissa
>=
(
2
<<
in
_mant
))
{
{
mantissa
>>=
1
;
mantissa
>>=
1
;
exponent
++
;
exponent
++
;
}
}
mantissa
>>=
(
type
_mant
-
f8
_mant
);
mantissa
>>=
(
in
_mant
-
out
_mant
);
// check negative exponent
// check negative exponent
if
(
exponent
<=
0
)
if
(
exponent
<=
0
)
...
@@ -116,7 +117,7 @@ __host__ __device__ f8_t run_cast_to_f8(T x, uint32_t rng)
...
@@ -116,7 +117,7 @@ __host__ __device__ f8_t run_cast_to_f8(T x, uint32_t rng)
{
{
if
(
clip
)
if
(
clip
)
{
{
mantissa
=
(
1
<<
f8
_mant
)
-
1
;
mantissa
=
(
1
<<
out
_mant
)
-
1
;
exponent
=
max_exp
;
exponent
=
max_exp
;
}
}
else
else
...
@@ -127,124 +128,121 @@ __host__ __device__ f8_t run_cast_to_f8(T x, uint32_t rng)
...
@@ -127,124 +128,121 @@ __host__ __device__ f8_t run_cast_to_f8(T x, uint32_t rng)
// check if x is 0.0 or -0.0
// check if x is 0.0 or -0.0
if
(
exponent
==
0
&&
mantissa
==
0
)
if
(
exponent
==
0
&&
mantissa
==
0
)
return
negative_zero_nan
?
0
:
(
sign
<<
(
f8
_exp
+
f8
_mant
));
return
negative_zero_nan
?
0
:
(
sign
<<
(
out
_exp
+
out
_mant
));
mantissa
&=
(
1
<<
f8
_mant
)
-
1
;
mantissa
&=
(
1
<<
out
_mant
)
-
1
;
return
(
sign
<<
(
f8
_exp
+
f8
_mant
))
|
(
exponent
<<
f8
_mant
)
|
mantissa
;
return
(
sign
<<
(
out
_exp
+
out
_mant
))
|
(
exponent
<<
out
_mant
)
|
mantissa
;
}
}
template
<
typename
T
,
bool
negative_zero_nan
>
template
<
typename
X
,
typename
Y
,
bool
negative_zero_nan
>
__host__
__device__
T
run_cast_from_f8
(
f8_t
x
)
__host__
__device__
Y
run_cast_from_f8
(
X
x
)
{
{
// check data type
// fp8/bf8 exponent/mantissa layout
constexpr
bool
is_half
=
std
::
is_same
<
T
,
half_t
>::
value
;
constexpr
int
in_exp
=
NumericUtils
<
X
>::
exp
;
constexpr
bool
is_float
=
std
::
is_same
<
T
,
float
>::
value
;
constexpr
int
in_mant
=
NumericUtils
<
X
>::
mant
;
// fp8 exponent/mantissa layout
constexpr
int
f8_exp
=
4
;
constexpr
int
f8_mant
=
3
;
// resulting type exponent/mantissa layout
// resulting type exponent/mantissa layout
constexpr
int
type
_exp
=
is_half
?
5
:
8
;
constexpr
int
out
_exp
=
NumericUtils
<
Y
>::
exp
;
constexpr
int
type
_mant
=
is_half
?
10
:
23
;
constexpr
int
out
_mant
=
NumericUtils
<
Y
>::
mant
;
// prepare the codes
// prepare the codes
constexpr
uint8_t
nan_code
=
0x80
;
constexpr
X
nan_code
=
0x80
;
T
fInf
,
fNegInf
,
fNaN
,
fNeg0
;
Y
Inf
,
NegInf
,
NaN
,
Neg0
;
if
constexpr
(
is_half
)
using
T_bitwise
=
typename
NumericUtils
<
Y
>::
bitwise_type
;
{
constexpr
uint16_t
ihInf
=
0x7C00
;
constexpr
T_bitwise
Inf_bitwise
=
NumericUtils
<
Y
>::
Inf
;
constexpr
uint16_t
ihNegInf
=
0xFC00
;
constexpr
T_bitwise
NegInf_bitwise
=
NumericUtils
<
Y
>::
NegInf
;
constexpr
uint16_t
ihNaN
=
0x7C01
;
constexpr
T_bitwise
NaN_bitwise
=
NumericUtils
<
Y
>::
NaN
;
constexpr
uint16_t
ihNeg0
=
0x8000
;
constexpr
T_bitwise
Neg0_bitwise
=
NumericUtils
<
Y
>::
Neg0
;
fInf
=
*
(
reinterpret_cast
<
const
half_t
*>
(
&
ihInf
));
fNegInf
=
*
(
reinterpret_cast
<
const
half_t
*>
(
&
ihNegInf
));
Inf
=
*
(
reinterpret_cast
<
const
Y
*>
(
&
Inf_bitwise
));
fNaN
=
*
(
reinterpret_cast
<
const
half_t
*>
(
&
ihNaN
));
NegInf
=
*
(
reinterpret_cast
<
const
Y
*>
(
&
NegInf_bitwise
));
fNeg0
=
*
(
reinterpret_cast
<
const
half_t
*>
(
&
ihNeg0
));
NaN
=
*
(
reinterpret_cast
<
const
Y
*>
(
&
NaN_bitwise
));
}
Neg0
=
*
(
reinterpret_cast
<
const
Y
*>
(
&
Neg0_bitwise
));
else
if
constexpr
(
is_float
)
{
// check if x is 0.0
constexpr
uint32_t
ifInf
=
0x7F800000
;
if
(
x
==
0
)
constexpr
uint32_t
ifNegInf
=
0xFF800000
;
return
static_cast
<
Y
>
(
0
);
constexpr
uint32_t
ifNaN
=
0x7F800001
;
constexpr
uint32_t
ifNeg0
=
0x80000000
;
fInf
=
*
(
reinterpret_cast
<
const
float
*>
(
&
ifInf
));
fNegInf
=
*
(
reinterpret_cast
<
const
float
*>
(
&
ifNegInf
));
fNaN
=
*
(
reinterpret_cast
<
const
float
*>
(
&
ifNaN
));
fNeg0
=
*
(
reinterpret_cast
<
const
float
*>
(
&
ifNeg0
));
}
// unpack the input
// unpack the input
uint32_t
sign
=
x
>>
(
f8
_exp
+
f8
_mant
);
uint32_t
sign
=
x
>>
(
in
_exp
+
in
_mant
);
uint32_t
mantissa
=
x
&
((
1
<<
f8
_mant
)
-
1
);
uint32_t
mantissa
=
x
&
((
1
<<
in
_mant
)
-
1
);
int
exponent
=
(
x
&
0x7F
)
>>
f8
_mant
;
int
exponent
=
(
x
&
0x7F
)
>>
in
_mant
;
constexpr
int
exp_low_cutoff
=
constexpr
int
exp_low_cutoff
=
(
1
<<
(
type
_exp
-
1
))
-
(
1
<<
(
f8
_exp
-
1
))
+
1
-
(
negative_zero_nan
?
1
:
0
);
(
1
<<
(
out
_exp
-
1
))
-
(
1
<<
(
in
_exp
-
1
))
+
1
-
(
negative_zero_nan
?
1
:
0
);
typename
std
::
conditional
<
std
::
is_same
<
T
,
half_t
>::
value
,
uint16_t
,
uint32_t
>::
typ
e
retval
;
T_bitwis
e
retval
;
if
constexpr
(
negative_zero_nan
)
if
constexpr
(
negative_zero_nan
)
{
{
if
(
x
==
nan_code
)
if
(
x
==
nan_code
)
return
f
NaN
;
return
NaN
;
}
}
else
else
{
{
if
(
x
==
nan_code
)
if
(
x
==
nan_code
)
return
fNeg0
;
return
Neg0
;
if
(
exponent
==
((
1
<<
f8_exp
)
-
1
))
if
(
exponent
==
((
1
<<
in_exp
)
-
1
))
return
(
mantissa
==
0
)
?
(
sign
?
fNegInf
:
fInf
)
:
fNaN
;
return
(
mantissa
==
0
)
?
(
sign
?
NegInf
:
Inf
)
:
NaN
;
}
if
((
NumericUtils
<
Y
>::
mant
==
10
)
&&
(
NumericUtils
<
X
>::
mant
==
2
)
&&
!
negative_zero_nan
)
{
retval
=
x
;
retval
<<=
8
;
return
*
(
reinterpret_cast
<
const
Y
*>
(
&
retval
));
}
}
// subnormal input
// subnormal input
if
(
exponent
==
0
)
if
(
exponent
==
0
)
{
{
// guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above
// guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above
int
sh
=
1
+
__builtin_clz
(
mantissa
)
-
((
1
+
type_exp
+
type_mant
)
-
f8_mant
);
exponent
++
;
mantissa
<<=
sh
;
while
(
mantissa
<
(
1
<<
in_mant
))
mantissa
&=
((
1
<<
f8_mant
)
-
1
);
{
exponent
+=
1
-
sh
;
mantissa
<<=
1
;
exponent
--
;
}
mantissa
&=
((
1
<<
in_mant
)
-
1
);
}
}
exponent
+=
exp_low_cutoff
-
1
;
exponent
+=
exp_low_cutoff
-
1
;
mantissa
<<=
type
_mant
-
f8
_mant
;
mantissa
<<=
out
_mant
-
in
_mant
;
// subnormal output (occurs when T=half, we=5, negative_zero_nan=true)
// subnormal output (occurs when T=half, we=5, negative_zero_nan=true)
if
(
exponent
<=
0
)
if
(
exponent
<=
0
)
{
{
mantissa
|=
1
<<
type
_mant
;
mantissa
|=
1
<<
out
_mant
;
mantissa
>>=
1
-
exponent
;
mantissa
>>=
1
-
exponent
;
exponent
=
0
;
exponent
=
0
;
}
}
retval
=
(
sign
<<
(
type
_exp
+
type
_mant
))
|
(
exponent
<<
type
_mant
)
|
mantissa
;
retval
=
(
sign
<<
(
out
_exp
+
out
_mant
))
|
(
exponent
<<
out
_mant
)
|
mantissa
;
return
*
(
reinterpret_cast
<
const
T
*>
(
&
retval
));
return
*
(
reinterpret_cast
<
const
Y
*>
(
&
retval
));
}
}
}
// namespace
}
// namespace
template
<
typename
T
,
bool
negative_zero_nan
,
bool
clip
,
bool
stoch
>
template
<
typename
X
,
typename
Y
,
bool
negative_zero_nan
,
bool
clip
,
bool
stoch
>
__host__
__device__
f8_t
cast_to_f8
(
T
x
,
uint32_t
rng
)
__host__
__device__
Y
cast_to_f8
(
X
x
,
uint32_t
rng
)
{
{
// check datatype
// check datatype
s
constexpr
bool
is_half
=
std
::
is_same
<
T
,
half_t
>::
value
;
constexpr
bool
is_half
=
std
::
is_same
<
X
,
half_t
>::
value
;
constexpr
bool
is_float
=
std
::
is_same
<
T
,
float
>::
value
;
constexpr
bool
is_float
=
std
::
is_same
<
X
,
float
>::
value
;
static_assert
(
is_half
||
is_float
,
"Only half and float can be casted
to f8
."
);
static_assert
(
is_half
||
is_float
,
"Only half and float can be casted."
);
return
run_cast_to_f8
<
T
,
negative_zero_nan
,
clip
,
stoch
>
(
x
,
rng
);
return
run_cast_to_f8
<
X
,
Y
,
negative_zero_nan
,
clip
,
stoch
>
(
x
,
rng
);
}
}
template
<
typename
T
,
bool
negative_zero_nan
>
template
<
typename
X
,
typename
Y
,
bool
negative_zero_nan
>
__host__
__device__
T
cast_from_f8
(
f8_t
x
)
__host__
__device__
Y
cast_from_f8
(
X
x
)
{
{
// check datatype
// check datatype
constexpr
bool
is_half
=
std
::
is_same
<
T
,
half_t
>::
value
;
constexpr
bool
is_half
=
std
::
is_same
<
Y
,
half_t
>::
value
;
constexpr
bool
is_float
=
std
::
is_same
<
T
,
float
>::
value
;
constexpr
bool
is_float
=
std
::
is_same
<
Y
,
float
>::
value
;
static_assert
(
is_half
||
is_float
,
"only half and float are supported."
);
static_assert
(
is_half
||
is_float
,
"only half and float are supported."
);
// check if x is 0.0
return
run_cast_from_f8
<
X
,
Y
,
negative_zero_nan
>
(
x
);
if
(
x
==
0
)
return
static_cast
<
T
>
(
0
);
return
run_cast_from_f8
<
T
,
negative_zero_nan
>
(
x
);
}
}
}
// namespace ck::utils
}
// namespace ck::utils
#endif // #if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
#endif // #if !defined(__gfx940__) && !defined(__gfx941__) && !defined(__gfx942__)
include/ck/utility/inner_product.hpp
View file @
e1a5137e
...
@@ -72,6 +72,18 @@ inner_product<float4_t, float4_t, float>(const float4_t& a, const float4_t& b, f
...
@@ -72,6 +72,18 @@ inner_product<float4_t, float4_t, float>(const float4_t& a, const float4_t& b, f
c
);
c
);
}
}
template
<
>
__device__
void
inner_product
<
bhalf_t
,
bhalf_t
,
float
>
(
const
bhalf_t
&
a
,
const
bhalf_t
&
b
,
float
&
c
)
{
inner_product
(
type_convert
<
float
>
(
a
),
type_convert
<
float
>
(
b
),
c
);
}
template
<
>
__device__
void
inner_product
<
half_t
,
half_t
,
float
>
(
const
half_t
&
a
,
const
half_t
&
b
,
float
&
c
)
{
inner_product
(
type_convert
<
float
>
(
a
),
type_convert
<
float
>
(
b
),
c
);
}
template
<
>
template
<
>
__device__
void
inner_product
<
half2_t
,
half2_t
,
float
>
(
const
half2_t
&
a
,
const
half2_t
&
b
,
float
&
c
)
__device__
void
inner_product
<
half2_t
,
half2_t
,
float
>
(
const
half2_t
&
a
,
const
half2_t
&
b
,
float
&
c
)
{
{
...
...
include/ck/utility/inner_product_dpp8.hpp
View file @
e1a5137e
...
@@ -2,6 +2,7 @@
...
@@ -2,6 +2,7 @@
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
#include "amd_gemm_dpp.hpp"
#include "amd_gemm_dpp.hpp"
#include "data_type.hpp"
#include "data_type.hpp"
#include "type_convert.hpp"
#include "type_convert.hpp"
...
@@ -10,6 +11,9 @@ namespace ck {
...
@@ -10,6 +11,9 @@ namespace ck {
namespace
dpp8
{
namespace
dpp8
{
/// Number of lanes that can share data using DPP8 modifiers.
constexpr
index_t
lane_group_size
=
8
;
template
<
int
SrcLaneIdx
>
template
<
int
SrcLaneIdx
>
__device__
void
inline_v_dot2c_dpp8_instr
(
const
half2_t
&
a
,
const
half2_t
&
b
,
float
&
c
);
__device__
void
inline_v_dot2c_dpp8_instr
(
const
half2_t
&
a
,
const
half2_t
&
b
,
float
&
c
);
...
...
include/ck/utility/loop_scheduler.hpp
0 → 100644
View file @
e1a5137e
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
namespace
ck
{
enum
struct
LoopScheduler
{
Default
,
Interwave
,
};
constexpr
LoopScheduler
make_default_loop_scheduler
()
{
#if CK_EXPERIMENTAL_DEFAULT_TO_INTER_WAVE_SCHEDULING
return
LoopScheduler
::
Interwave
;
#else
return
LoopScheduler
::
Default
;
#endif // if CK_EXPERIMENTAL_DEFAULT_TO_INTER_WAVE_SCHEDULING
}
}
// namespace ck
include/ck/utility/reduction_operator.hpp
View file @
e1a5137e
...
@@ -116,7 +116,15 @@ struct Max
...
@@ -116,7 +116,15 @@ struct Max
template
<
typename
T
>
template
<
typename
T
>
__host__
__device__
static
constexpr
T
GetIdentityValue
()
__host__
__device__
static
constexpr
T
GetIdentityValue
()
{
{
return
NumericLimits
<
T
>::
Lowest
();
if
constexpr
(
is_same_v
<
T
,
bhalf_t
>
)
{
float
val
=
NumericLimits
<
float
>::
Lowest
();
return
type_convert
<
bhalf_t
>
(
val
);
}
else
{
return
NumericLimits
<
T
>::
Lowest
();
}
};
};
__host__
__device__
static
constexpr
bool
__host__
__device__
static
constexpr
bool
...
@@ -138,6 +146,15 @@ struct Max
...
@@ -138,6 +146,15 @@ struct Max
a
=
b
;
a
=
b
;
}
}
__host__
__device__
inline
constexpr
void
operator
()(
bhalf_t
&
a
,
bhalf_t
b
)
const
{
float
a_
=
type_convert
<
float
>
(
a
);
float
b_
=
type_convert
<
float
>
(
b
);
if
(
a_
<
b_
)
a
=
b
;
}
template
<
typename
T
>
template
<
typename
T
>
__host__
__device__
inline
constexpr
void
operator
()(
T
&
a
,
T
b
,
bool
&
changed
)
const
__host__
__device__
inline
constexpr
void
operator
()(
T
&
a
,
T
b
,
bool
&
changed
)
const
{
{
...
@@ -152,6 +169,18 @@ struct Max
...
@@ -152,6 +169,18 @@ struct Max
changed
=
true
;
changed
=
true
;
}
}
}
}
__host__
__device__
inline
constexpr
void
operator
()(
bhalf_t
&
a
,
bhalf_t
b
,
bool
&
changed
)
const
{
float
a_
=
type_convert
<
float
>
(
a
);
float
b_
=
type_convert
<
float
>
(
b
);
if
(
a_
<
b_
)
{
a
=
b
;
changed
=
true
;
}
}
};
};
struct
Min
struct
Min
...
@@ -159,6 +188,15 @@ struct Min
...
@@ -159,6 +188,15 @@ struct Min
template
<
typename
T
>
template
<
typename
T
>
__host__
__device__
static
constexpr
T
GetIdentityValue
()
__host__
__device__
static
constexpr
T
GetIdentityValue
()
{
{
if
constexpr
(
is_same_v
<
T
,
bhalf_t
>
)
{
float
val
=
NumericLimits
<
float
>::
Max
();
return
type_convert
<
bhalf_t
>
(
val
);
}
else
{
return
NumericLimits
<
T
>::
Max
();
}
return
NumericLimits
<
T
>::
Max
();
return
NumericLimits
<
T
>::
Max
();
};
};
...
@@ -181,6 +219,15 @@ struct Min
...
@@ -181,6 +219,15 @@ struct Min
a
=
b
;
a
=
b
;
}
}
__host__
__device__
inline
constexpr
void
operator
()(
bhalf_t
&
a
,
bhalf_t
b
)
const
{
float
a_
=
type_convert
<
float
>
(
a
);
float
b_
=
type_convert
<
float
>
(
b
);
if
(
a_
>
b_
)
a
=
b
;
}
template
<
typename
T
>
template
<
typename
T
>
__host__
__device__
inline
constexpr
void
operator
()(
T
&
a
,
T
b
,
bool
&
changed
)
const
__host__
__device__
inline
constexpr
void
operator
()(
T
&
a
,
T
b
,
bool
&
changed
)
const
{
{
...
@@ -195,6 +242,18 @@ struct Min
...
@@ -195,6 +242,18 @@ struct Min
changed
=
true
;
changed
=
true
;
}
}
}
}
__host__
__device__
inline
constexpr
void
operator
()(
bhalf_t
&
a
,
bhalf_t
b
,
bool
&
changed
)
const
{
float
a_
=
type_convert
<
float
>
(
a
);
float
b_
=
type_convert
<
float
>
(
b
);
if
(
a_
>
b_
)
{
a
=
b
;
changed
=
true
;
}
}
};
};
struct
AMax
struct
AMax
...
...
Prev
1
…
3
4
5
6
7
8
9
10
11
…
19
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment