gaoqiong / composable_kernel_ROCM · Commits · 3552041a

Commit 3552041a, authored Jul 26, 2024 by danyao12

Merge branch 'develop' into ck_tile/fa_bwd_opt

Parents: e8927110, 733f33af
Changes: 273
Showing 20 changed files, with 3749 additions and 1214 deletions (+3749, -1214):

include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp    +1694  -0
include/ck/tensor_operation/gpu/warp/smfmac_xdlops_gemm.hpp                                +409   -0
include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp    +45    -45
include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp              +997   -925
include/ck/utility/amd_smfmac.hpp                                                          +28    -0
include/ck/utility/math_v2.hpp                                                             +2     -2
include/ck/utility/reduction_operator.hpp                                                  +19    -3
include/ck_tile/core/arch/amd_buffer_addressing.hpp                                        +333   -185
include/ck_tile/core/arch/arch.hpp                                                         +3     -5
include/ck_tile/core/config.hpp                                                            +9     -0
include/ck_tile/core/numeric/bfloat16.hpp                                                  +4     -1
include/ck_tile/core/numeric/float8.hpp                                                    +2     -2
include/ck_tile/core/numeric/half.hpp                                                      +1     -1
include/ck_tile/core/numeric/math.hpp                                                      +1     -1
include/ck_tile/core/tensor/buffer_view.hpp                                                +34    -11
include/ck_tile/core/tensor/load_tile.hpp                                                  +13    -6
include/ck_tile/core/tensor/null_tile_window.hpp                                           +2     -0
include/ck_tile/core/tensor/tensor_view.hpp                                                +15    -9
include/ck_tile/core/tensor/tile_elementwise.hpp                                           +48    -8
include/ck_tile/core/tensor/tile_window.hpp                                                +90    -10
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp  (new file, 0 → 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_ab_scale_selector.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3.hpp"
#define DEBUG_LOG 0

namespace ck {

// Currently we do not have an elegant way to put the single-LDS-buffer and double-LDS-buffer
// pipelines into the same kernel function. Blockers:
// 1. Two separate declarations of the __shared__ pointer are the key to making sure data accesses
//    operate on two distinct LDS chunks.
// 2. Occupied __shared__ memory is not released until the whole shader ends, i.e. A/B and C may not
//    reuse the same LDS buffer when __shared__ is declared inside the block GEMM pipeline.
template <typename GridwiseGemm,
          bool HasMainKBlockLoop,
          InMemoryDataOperationEnum CGlobalMemoryDataOperation,
          index_t MinimumOccupancy = 1,
          TailNumber TailNum       = TailNumber::Full>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
#endif
        // __attribute__((amdgpu_waves_per_eu(1, 1)))
        kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

    GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
        karg.p_a_grid,
        karg.p_b_grid,
        karg.p_ds_grid,
        karg.p_c_grid,
        karg.p_a_scale_grid,
        karg.p_b_scale_grid,
        p_shared,
        karg,
        karg.a_element_op,
        karg.b_element_op,
        karg.c_element_op);
#else
    ignore = karg;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
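
// Illustrative host-side launch sketch (not part of this header; the concrete template
// configuration, stream handle and launch syntax below are assumptions shown only to clarify how
// the kernel, Argument and grid-size helpers fit together):
//
//   using GridwiseGemm = GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3</* tuning parameters */>;
//   GridwiseGemm::Argument karg{/* A/B/Ds/C pointers, M, N, K, strides, scale pointers,
//                                  k_batch, element-wise ops */};
//   const auto [grid_x, grid_y, grid_z] = GridwiseGemm::CalculateGridSize(M, N, KBatch);
//   kernel_gemm_xdl_cshuffle_v3<GridwiseGemm, true, InMemoryDataOperationEnum::Set>
//       <<<dim3(grid_x, grid_y, grid_z), dim3(BlockSize), 0, stream>>>(karg);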

template <typename ALayout,
          typename BLayout,
          typename DsLayout,
          typename CLayout,
          typename ADataType,
          typename BDataType,
          typename AccDataType,
          typename CShuffleDataType,
          typename DsDataType,
          typename CDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation,
          tensor_operation::device::GemmSpecialization GemmSpec,
          index_t BlockSize,
          index_t ScaleBlockM,
          index_t ScaleBlockN,
          index_t ScaleBlockK,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t AK1Value,
          index_t BK1Value,
          index_t MPerXdl,
          index_t NPerXdl,
          index_t MXdlPerWave,
          index_t NXdlPerWave,
          typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          index_t ABlockTransferSrcVectorDim,
          index_t ABlockTransferSrcScalarPerVector,
          index_t ABlockTransferDstScalarPerVector_AK1,
          bool AThreadTransferSrcResetCoordinateAfterRun,
          index_t ABlockLdsExtraM,
          typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          index_t BBlockTransferSrcVectorDim,
          index_t BBlockTransferSrcScalarPerVector,
          index_t BBlockTransferDstScalarPerVector_BK1,
          bool BThreadTransferSrcResetCoordinateAfterRun,
          index_t BBlockLdsExtraN,
          index_t CShuffleMXdlPerWavePerShuffle,
          index_t CShuffleNXdlPerWavePerShuffle,
          typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          typename CDEShuffleBlockTransferScalarPerVectors,
          BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
          BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
          typename ComputeTypeA                       = CDataType,
          typename ComputeTypeB                       = ComputeTypeA,
          typename LDSTypeA                           = ADataType,
          typename LDSTypeB                           = BDataType>
struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
{
    using AScaleType = float;
    using BScaleType = float;

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};
    static constexpr auto I4 = Number<4>{};
    static constexpr auto I5 = Number<5>{};
    static constexpr auto I6 = Number<6>{};
    static constexpr auto I7 = Number<7>{};

    static constexpr auto CShuffleBlockTransferScalarPerVector_NPerBlock =
        CDEShuffleBlockTransferScalarPerVectors{}[I0];

    // K1 should be Number<...>
    static constexpr auto AK0Number = Number<KPerBlock / AK1Value>{};
    static constexpr auto BK0Number = Number<KPerBlock / BK1Value>{};
    static constexpr auto AK1Number = Number<AK1Value>{};
    static constexpr auto BK1Number = Number<BK1Value>{};

    static constexpr index_t NumDTensor = DsDataType::Size();

    static constexpr auto MakeDsGridPointer()
    {
        return generate_tuple(
            [&](auto i) {
                using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;

                return static_cast<const DDataType*>(nullptr);
            },
            Number<NumDTensor>{});
    }

    using DsGridPointer = decltype(MakeDsGridPointer());

    static constexpr index_t KPack =
        math::max(math::lcm(AK1Number, BK1Number),
                  MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl, ComputeTypeB>::selected_mfma
                      .k_per_blk);

    using ThisThreadBlock = ThisThreadBlock<BlockSize>;

    __host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch)
    {
        return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, KBatch);
    }

    __host__ static auto CalculateMPadded(index_t M)
    {
        return math::integer_least_multiple(M, MPerBlock);
    }

    __host__ static auto CalculateNPadded(index_t N)
    {
        return math::integer_least_multiple(N, NPerBlock);
    }

    __host__ static auto CalculateKPadded(index_t K)
    {
        return math::integer_divide_ceil(K, KPerBlock) * KPerBlock;
    }

    __host__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1)
    {
        auto K_t = K_Batch * KPerBlock;
        return (K + K_t - 1) / K_t * (KPerBlock / AK1Value);
    }

    __host__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1)
    {
        auto K_t = K_Batch * KPerBlock;
        return (K + K_t - 1) / K_t * (KPerBlock / BK1Value);
    }

    __host__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1)
    {
        auto K_t = K_Batch * KPerBlock;
        return (K + K_t - 1) / K_t * KPerBlock;
    }

    __host__ static auto CalculateKRead(index_t K, index_t K_Batch = 1)
    {
        constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
        auto K_t                = K_Batch * KReadVec;
        return (K + K_t - 1) / K_t * KReadVec;
    }
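
    // Worked example of the padding helpers above (hypothetical sizes, for illustration only):
    // with KPerBlock = 64, AK1Value = BK1Value = 8, K = 1000 and K_Batch = 2:
    //   K_t for the K helpers          = 2 * 64 = 128
    //   CalculateKPadded(1000, 2)      = ceil(1000 / 128) * 64       = 8 * 64 = 512
    //   CalculateAK0Padded(1000, 2)    = ceil(1000 / 128) * (64 / 8) = 8 * 8  = 64
    //   CalculateKRead(1000, 2)        = ceil(1000 / (2 * lcm(8, 8))) * 8 = 63 * 8 = 504
    // i.e. each K-split reads a KRead-sized slice rounded up to the vector read width.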

    __host__ static auto CalculateMBlock(index_t M)
    {
        return math::integer_divide_ceil(M, MPerBlock);
    }

    __host__ static auto CalculateNBlock(index_t N)
    {
        return math::integer_divide_ceil(N, NPerBlock);
    }

    template <index_t MNXdlPerWave, index_t MNWaves, index_t MNPerXdl, typename TileDesc_K0_MN_K1>
    __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&)
    {
        constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{});
        constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{});

        return transform_tensor_descriptor(
            TileDesc_K0_MN_K1{},
            make_tuple(make_merge_transform_v3_division_mod(make_tuple(Number<K0>{}, Number<K1>{})),
                       make_unmerge_transform(make_tuple(
                           Number<MNXdlPerWave>{}, Number<MNWaves>{}, Number<MNPerXdl>{}))),
            make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
            make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}));
    }

    __device__ static auto MakeAGridDescriptor_AK0_M_AK1(
        index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0)
    {
        const auto a_grid_desc_mraw_kraw = [&]() {
            if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
            {
                return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
            }
            else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
            {
                return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
            }
        }();

        using GemmSpecialization = tensor_operation::device::GemmSpecialization;

        if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
                     GemmSpec == GemmSpecialization::MNKPadding)
        {
            // pad both M and K
            const auto a_grid_desc_m_k =
                transform_tensor_descriptor(a_grid_desc_mraw_kraw,
                                            make_tuple(make_right_pad_transform(M, MPad - M),
                                                       make_right_pad_transform(K, KPad - K)),
                                            make_tuple(Sequence<0>{}, Sequence<1>{}),
                                            make_tuple(Sequence<0>{}, Sequence<1>{}));

            const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
                a_grid_desc_m_k,
                make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)),
                           make_pass_through_transform(MPad)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return a_grid_desc_ak0_m_ak1;
        }
        else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
                          GemmSpec == GemmSpecialization::MNPadding)
        {
            // pad M, but not K
            const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
                a_grid_desc_mraw_kraw,
                make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)),
                           make_right_pad_transform(M, MPad - M)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return a_grid_desc_ak0_m_ak1;
        }
        else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
                          GemmSpec == GemmSpecialization::NKPadding)
        {
            // pad K, but not M
            const auto a_grid_desc_m_k =
                transform_tensor_descriptor(a_grid_desc_mraw_kraw,
                                            make_tuple(make_pass_through_transform(M),
                                                       make_right_pad_transform(K, KPad - K)),
                                            make_tuple(Sequence<0>{}, Sequence<1>{}),
                                            make_tuple(Sequence<0>{}, Sequence<1>{}));

            const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
                a_grid_desc_m_k,
                make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)),
                           make_pass_through_transform(M)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return a_grid_desc_ak0_m_ak1;
        }
        else
        {
            // not pad M or K
            const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
                a_grid_desc_mraw_kraw,
                make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)),
                           make_pass_through_transform(M)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return a_grid_desc_ak0_m_ak1;
        }
    }
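
    // Illustration of the descriptor above (hypothetical values, not part of the original
    // source): for a row-major A with M = 128, K = 512, AK1Value = 8 and AK0 = 64, the raw
    // (M, K) view is re-expressed as (AK0, M, AK1) = (64, 128, 8); the K axis is unmerged into
    // AK0 * AK1 so that AK1 contiguous elements can be moved per vectorized load.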

    __device__ static auto MakeBGridDescriptor_BK0_N_BK1(
        index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
    {
        const auto b_grid_desc_nraw_kraw = [&]() {
            if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB));
            }
            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1));
            }
        }();

        using GemmSpecialization = tensor_operation::device::GemmSpecialization;

        if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
                     GemmSpec == GemmSpecialization::MNKPadding)
        {
            // pad both N and K
            const auto b_grid_desc_n_k =
                transform_tensor_descriptor(b_grid_desc_nraw_kraw,
                                            make_tuple(make_right_pad_transform(N, NPad - N),
                                                       make_right_pad_transform(K, KPad - K)),
                                            make_tuple(Sequence<0>{}, Sequence<1>{}),
                                            make_tuple(Sequence<0>{}, Sequence<1>{}));

            const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
                b_grid_desc_n_k,
                make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)),
                           make_pass_through_transform(NPad)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return b_grid_desc_bk0_n_bk1;
        }
        else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
                          GemmSpec == GemmSpecialization::MNPadding)
        {
            // pad N, but not K
            const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
                b_grid_desc_nraw_kraw,
                make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)),
                           make_right_pad_transform(N, NPad - N)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return b_grid_desc_bk0_n_bk1;
        }
        else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
                          GemmSpec == GemmSpecialization::MKPadding)
        {
            // pad K, but not N
            const auto b_grid_desc_n_k =
                transform_tensor_descriptor(b_grid_desc_nraw_kraw,
                                            make_tuple(make_pass_through_transform(N),
                                                       make_right_pad_transform(K, KPad - K)),
                                            make_tuple(Sequence<0>{}, Sequence<1>{}),
                                            make_tuple(Sequence<0>{}, Sequence<1>{}));

            const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
                b_grid_desc_n_k,
                make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)),
                           make_pass_through_transform(N)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return b_grid_desc_bk0_n_bk1;
        }
        else
        {
            // not pad N or K
            const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
                b_grid_desc_nraw_kraw,
                make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)),
                           make_pass_through_transform(N)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return b_grid_desc_bk0_n_bk1;
        }
    }

    template <typename ABlockDesc_AK0_M_AK1>
    __host__ __device__ static constexpr auto
    MakeAMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&)
    {
        constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);

        return MakeGemmMmaTileDescriptor<MXdlPerWave, MWaves, MPerXdl>(ABlockDesc_AK0_M_AK1{});
    }

    template <typename BBlockDesc_BK0_N_BK1>
    __host__ __device__ static constexpr auto
    MakeBMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&)
    {
        constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);

        return MakeGemmMmaTileDescriptor<NXdlPerWave, NWaves, NPerXdl>(BBlockDesc_BK0_N_BK1{});
    }

    template <typename ELayout>
    __host__ __device__ static auto
    MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC)
    {
        const auto c_grid_desc_mraw_nraw = [&]() {
            if constexpr(is_same<tensor_layout::gemm::RowMajor, ELayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
            }
            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ELayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
            }
        }();

        using GemmSpecialization = tensor_operation::device::GemmSpecialization;

        if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
                     GemmSpec == GemmSpecialization::MNKPadding)
        {
            // pad M and N
            return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
                                               make_tuple(make_right_pad_transform(M, MPad - M),
                                                          make_right_pad_transform(N, NPad - N)),
                                               make_tuple(Sequence<0>{}, Sequence<1>{}),
                                               make_tuple(Sequence<0>{}, Sequence<1>{}));
        }
        else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
                          GemmSpec == GemmSpecialization::MKPadding)
        {
            // pad M, but not N
            return transform_tensor_descriptor(
                c_grid_desc_mraw_nraw,
                make_tuple(make_right_pad_transform(M, MPad - M), make_pass_through_transform(N)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));
        }
        else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
                          GemmSpec == GemmSpecialization::NKPadding)
        {
            // pad N, but not M
            return transform_tensor_descriptor(
                c_grid_desc_mraw_nraw,
                make_tuple(make_pass_through_transform(M), make_right_pad_transform(N, NPad - N)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));
        }
        else
        {
            // not pad M or N
            return c_grid_desc_mraw_nraw;
        }
    }

    __host__ __device__ static auto MakeDsGridDescriptor_M_N(
        index_t M, index_t MPad, index_t N, index_t NPad, std::array<index_t, NumDTensor> StrideDs)
    {
        return generate_tuple(
            [&](auto i) {
                using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;

                return MakeCGridDescriptor_M_N<DLayout>(M, MPad, N, NPad, StrideDs[i]);
            },
            Number<NumDTensor>{});
    }

    template <typename DsGridDesc>
    __device__ static constexpr auto MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
        const DsGridDesc& ds_grid_desc_m_n, index_t MBlock, index_t NBlock)
    {
        return generate_tuple(
            [&](auto i) {
                return MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                    ds_grid_desc_m_n[i], MBlock, NBlock);
            },
            Number<NumDTensor>{});
    }

    using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N(0, 0, 0, 0, {}))>;

    struct Problem
    {
        __host__ Problem(index_t M_,
                         index_t N_,
                         index_t K_,
                         index_t StrideA_,
                         index_t StrideB_,
                         std::array<index_t, NumDTensor> StrideDs_,
                         index_t StrideC_,
                         index_t KBatch_)
            : M{M_},
              N{N_},
              K{K_},
              StrideA{StrideA_},
              StrideB{StrideB_},
              StrideDs{StrideDs_},
              StrideC{StrideC_},
              KBatch{KBatch_},
              MPadded{CalculateMPadded(M_)},
              NPadded{CalculateNPadded(N_)},
              KRead{CalculateKRead(K_, KBatch_)},
              KPadded{CalculateKPadded(K_, KBatch_)},
              AK0{CalculateAK0Padded(K_, KBatch_)},
              BK0{CalculateBK0Padded(K_, KBatch_)},
              MBlock{CalculateMBlock(M_)},
              NBlock{CalculateNBlock(N_)}
        {
        }

        __host__ void Print() const
        {
            std::cout << "problem {"
                      << "M:" << M << ", "
                      << "N:" << N << ", "
                      << "K:" << K << ", "
                      << "SA:" << StrideA << ", "
                      << "SB:" << StrideB << ", "
                      << "SC:" << StrideC << ", "
                      << "MP:" << MPadded << ", "
                      << "NP:" << NPadded << ", "
                      << "KRead:" << KRead << ", "
                      << "KP:" << KPadded << ", "
                      << "AK0:" << AK0 << ", "
                      << "BK0:" << BK0 << ", "
                      << "MBlock: " << MBlock << ", "
                      << "NBlock: " << NBlock << "}" << std::endl;
        }

        index_t M;
        index_t N;
        index_t K;
        index_t StrideA;
        index_t StrideB;
        std::array<index_t, NumDTensor> StrideDs;
        index_t StrideC;
        index_t KBatch;
        index_t MPadded;
        index_t NPadded;
        index_t KRead;
        index_t KPadded;
        index_t AK0;
        index_t BK0;
        index_t MBlock;
        index_t NBlock;
    };
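
    // Illustrative use of Problem (hypothetical shapes; assumes MPerBlock = NPerBlock = 128 and
    // KPerBlock = 64): Problem{512, 384, 2048, StrideA, StrideB, StrideDs, StrideC, /*KBatch*/ 1}
    // would record MPadded = 512, NPadded = 384, MBlock = 4 and NBlock = 3, which the kernel
    // later uses to size the block-to-C-tile map.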

    // Argument
    struct Argument : public tensor_operation::device::BaseArgument, public Problem
    {
        __host__ Argument(const ADataType* p_a_grid_,
                          const BDataType* p_b_grid_,
                          std::array<const void*, NumDTensor> p_ds_grid_,
                          CDataType* p_c_grid_,
                          index_t M_,
                          index_t N_,
                          index_t K_,
                          index_t StrideA_,
                          index_t StrideB_,
                          std::array<index_t, NumDTensor> StrideDs_,
                          index_t StrideC_,
                          const AScaleType* p_a_scale_grid_,
                          const BScaleType* p_b_scale_grid_,
                          index_t k_batch_,
                          AElementwiseOperation a_element_op_,
                          BElementwiseOperation b_element_op_,
                          CElementwiseOperation c_element_op_)
            : Problem{M_, N_, K_, StrideA_, StrideB_, StrideDs_, StrideC_, k_batch_},
              p_a_grid{p_a_grid_},
              p_b_grid{p_b_grid_},
              p_ds_grid{},
              p_c_grid{p_c_grid_},
              p_a_scale_grid{p_a_scale_grid_},
              p_b_scale_grid{p_b_scale_grid_},
              a_element_op{a_element_op_},
              b_element_op{b_element_op_},
              c_element_op{c_element_op_}
        {
            // populate pointer, desc for Ds
            static_for<0, NumDTensor, 1>{}([&](auto i) {
                using DDataType_ = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;

                // D pointer
                p_ds_grid(i) = static_cast<const DDataType_*>(p_ds_grid_[i]);
            });
        }

        const ADataType* p_a_grid;
        const BDataType* p_b_grid;
        DsGridPointer p_ds_grid;
        CDataType* p_c_grid;
        const AScaleType* p_a_scale_grid;
        const BScaleType* p_b_scale_grid;

        const AElementwiseOperation a_element_op;
        const BElementwiseOperation b_element_op;
        const CElementwiseOperation c_element_op;
    };

    struct SplitKBatchOffset
    {
        __device__ SplitKBatchOffset(Argument& karg)
        {
            if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
            {
                a_k_split_offset = blockIdx.z * karg.KRead;
            }
            else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
            {
                a_k_split_offset = blockIdx.z * karg.KRead * karg.M;
            }

            if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
            {
                b_k_split_offset = blockIdx.z * karg.KRead * karg.N;
            }
            else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
            {
                b_k_split_offset = blockIdx.z * karg.KRead;
            }

            if(blockIdx.z < static_cast<uint32_t>(karg.KBatch - 1))
            {
                karg.K = karg.KRead;
            }
            else
            {
                karg.K = karg.K - karg.KRead * (karg.KBatch - 1);
            }
        }

        index_t a_k_split_offset;
        index_t b_k_split_offset;
    };
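
    // Illustration of the split-K offsets (hypothetical values): with row-major A and B,
    // KRead = 512, KBatch = 4 and blockIdx.z = 2, the constructor yields
    // a_k_split_offset = 2 * 512 = 1024 elements along A's K axis and
    // b_k_split_offset = 2 * 512 * N for B; since this is not the last split, karg.K is also
    // clamped to KRead = 512 for this work-group, while the last split handles the remainder.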

    __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
    {
        // A matrix in LDS memory, dst of blockwise copy
        if constexpr(ABlockLdsExtraM)
        {
            return make_naive_tensor_descriptor(
                make_tuple(AK0Number, Number<MPerBlock>{}, AK1Number),
                make_tuple(AK1Number, Number<KPerBlock + ABlockLdsExtraM>{}, I1));
        }
        // The xor tensor transformation requests more (unnecessary) VGPR usage and can cause
        // register spills in some cases.
        else if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
        {
            constexpr auto MLdsLayer = 32 * 4 / KPerBlock / sizeof(LDSTypeA) < 1
                                           ? 1
                                           : 32 * 4 / KPerBlock / sizeof(LDSTypeA);

            constexpr auto a_lds_block_desc = make_naive_tensor_descriptor(
                make_tuple(
                    AK0Number * Number<MLdsLayer>{}, Number<MPerBlock / MLdsLayer>{}, AK1Number),
                make_tuple(AK1Number, Number<KPerBlock * MLdsLayer>{}, I1));

            constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
                a_lds_block_desc,
                make_tuple(make_xor_with_modulo_transform(make_tuple(
                               Number<MPerBlock / MLdsLayer>{}, Number<AK0Number * MLdsLayer>{})),
                           make_pass_through_transform(AK1Number)),
                make_tuple(Sequence<1, 0>{}, Sequence<2>{}),
                make_tuple(Sequence<1, 0>{}, Sequence<2>{}));

            constexpr auto a_lds_block_desc_ak0_mldslayer_m_ak1 = transform_tensor_descriptor(
                a_lds_block_desc_permuted,
                make_tuple(make_unmerge_transform(make_tuple(AK0Number, Number<MLdsLayer>{})),
                           make_pass_through_transform(Number<MPerBlock / MLdsLayer>{}),
                           make_pass_through_transform(AK1Number)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{}));

            constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor(
                a_lds_block_desc_ak0_mldslayer_m_ak1,
                make_tuple(make_pass_through_transform(AK0Number),
                           make_merge_transform_v3_division_mod(
                               make_tuple(Number<MPerBlock / MLdsLayer>{}, Number<MLdsLayer>{})),
                           make_pass_through_transform(AK1Number)),
                make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));

            return a_lds_block_desc_ak0_m_ak1;
        }
        else // ColumnMajor A
        {
            // The kfold and mpair dimensions are not always required; more dimensions in the
            // merge_transform increase the difficulty of generating immarg offsets for the
            // compiler.
            constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
            constexpr auto M1 = MPerBlock / M0;

            constexpr auto KThreadWrite     = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
            constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite;
            constexpr auto KThreadRead      = 64 / MPerXdl;
            constexpr auto K0PerThreadRead  = AK0Number / KThreadRead;

            constexpr auto kfold = (AK1Number * M0 * sizeof(LDSTypeA) > 128)
                                       ? 1
                                       : 128 / (AK1Number * M0 * sizeof(LDSTypeA));
            constexpr auto KThreadReadPerm =
                (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
                    ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
                    : KThreadRead;

            // 1<=mpair<=n0
            constexpr auto mpair = (AK1Number * MPerXdl * sizeof(LDSTypeA) > 128)
                                       ? 1
                                       : ((128 / (AK1Number * MPerXdl * sizeof(LDSTypeA))) > M0
                                              ? M0
                                              : 128 / (AK1Number * MPerXdl * sizeof(LDSTypeA)));

            constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed(
                make_tuple(Number<KThreadWrite / kfold / KThreadReadPerm>{},
                           Number<K0PerThreadWrite>{},
                           Number<KThreadReadPerm * M1>{},
                           Number<kfold * M0 / mpair>{},
                           Number<mpair>{},
                           AK1Number));

            constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
                a_lds_block_desc,
                make_tuple(
                    make_pass_through_transform(Number<KThreadWrite / kfold / KThreadReadPerm>{}),
                    make_pass_through_transform(Number<K0PerThreadWrite>{}),
                    make_xor_with_modulo_transform(
                        make_tuple(Number<KThreadReadPerm * M1>{}, Number<kfold * M0 / mpair>{})),
                    make_pass_through_transform(Number<mpair>{}),
                    make_pass_through_transform(AK1Number)),
                make_tuple(
                    Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}),
                make_tuple(
                    Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}));

            constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor(
                a_lds_block_desc_permuted,
                make_tuple(
                    make_pass_through_transform(Number<KThreadWrite / kfold / KThreadReadPerm>{}),
                    make_pass_through_transform(Number<K0PerThreadWrite>{}),
                    make_unmerge_transform(make_tuple(Number<KThreadReadPerm>{}, Number<M1>{})),
                    make_unmerge_transform(make_tuple(Number<kfold>{}, Number<M0 / mpair>{})),
                    make_pass_through_transform(Number<mpair>{}),
                    make_pass_through_transform(AK1Number)),
                make_tuple(Sequence<0>{},
                           Sequence<1>{},
                           Sequence<2>{},
                           Sequence<3>{},
                           Sequence<4>{},
                           Sequence<5>{}),
                make_tuple(Sequence<1>{},
                           Sequence<2>{},
                           Sequence<0, 3>{},
                           Sequence<4, 5>{},
                           Sequence<6>{},
                           Sequence<7>{}));

            constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor(
                a_lds_block_desc_unmerged,
                make_tuple(make_merge_transform_v3_division_mod(
                               make_tuple(Number<KThreadReadPerm>{},
                                          Number<KThreadWrite / kfold / KThreadReadPerm>{},
                                          Number<kfold>{},
                                          Number<K0PerThreadWrite>{})),
                           make_merge_transform_v3_division_mod(
                               make_tuple(Number<M0 / mpair>{}, Number<mpair>{}, Number<M1>{})),
                           make_pass_through_transform(AK1Number)),
                make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));

            return a_lds_block_desc_ak0_m_ak1;
        }
    }
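
    // Illustration of the row-major branch above (hypothetical config, not part of the original
    // source): with KPerBlock = 32 and sizeof(LDSTypeA) = 2 (fp16), MLdsLayer = (32 * 4) / 32 / 2
    // = 2, so two M sub-tiles share one 128-byte LDS row; the xor-with-modulo permutation then
    // staggers those rows so that reads issued for the MFMA stage are less likely to land on the
    // same LDS banks.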

    __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
    {
        // B matrix in LDS memory, dst of blockwise copy
        if constexpr(BBlockLdsExtraN)
        {
            return make_naive_tensor_descriptor(
                make_tuple(BK0Number, Number<NPerBlock>{}, BK1Number),
                make_tuple(BK1Number, Number<KPerBlock + BBlockLdsExtraN>{}, I1));
        }
        else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
        {
            // NLdsLayer * K0 as logical Bank
            constexpr auto NLdsLayer = 32 * 4 / KPerBlock / sizeof(LDSTypeB) < 1
                                           ? 1
                                           : 32 * 4 / KPerBlock / sizeof(LDSTypeB);

            constexpr auto b_lds_block_desc = make_naive_tensor_descriptor(
                make_tuple(
                    BK0Number * Number<NLdsLayer>{}, Number<NPerBlock / NLdsLayer>{}, BK1Number),
                make_tuple(BK1Number, Number<KPerBlock * NLdsLayer>{}, I1));

            constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
                b_lds_block_desc,
                make_tuple(make_xor_with_modulo_transform(make_tuple(
                               Number<NPerBlock / NLdsLayer>{}, Number<BK0Number * NLdsLayer>{})),
                           make_pass_through_transform(BK1Number)),
                make_tuple(Sequence<1, 0>{}, Sequence<2>{}),
                make_tuple(Sequence<1, 0>{}, Sequence<2>{}));

            constexpr auto b_lds_block_desc_bk0_nldslayer_n_bk1 = transform_tensor_descriptor(
                b_lds_block_desc_permuted,
                make_tuple(make_unmerge_transform(make_tuple(BK0Number, Number<NLdsLayer>{})),
                           make_pass_through_transform(Number<NPerBlock / NLdsLayer>{}),
                           make_pass_through_transform(BK1Number)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{}));

            constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor(
                b_lds_block_desc_bk0_nldslayer_n_bk1,
                make_tuple(make_pass_through_transform(BK0Number),
                           make_merge_transform_v3_division_mod(
                               make_tuple(Number<NPerBlock / NLdsLayer>{}, Number<NLdsLayer>{})),
                           make_pass_through_transform(BK1Number)),
                make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));

            return b_lds_block_desc_bk0_n_bk1;
        }
        else // RowMajor B
        {
            constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1);
            constexpr auto N1 = NPerBlock / N0;

            constexpr auto KThreadWrite     = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0);
            constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite;
            constexpr auto KThreadRead      = 64 / NPerXdl;
            constexpr auto K0PerThreadRead  = BK0Number / KThreadRead;

            constexpr auto kfold = (BK1Number * N0 * sizeof(LDSTypeB) > 128)
                                       ? 1
                                       : 128 / (BK1Number * N0 * sizeof(LDSTypeB));
            constexpr auto KThreadReadPerm =
                (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
                    ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
                    : KThreadRead;

            // 1<=npair<=n0
            constexpr auto npair = (BK1Number * NPerXdl * sizeof(LDSTypeB) > 128)
                                       ? 1
                                       : ((128 / (BK1Number * NPerXdl * sizeof(LDSTypeB))) > N0
                                              ? N0
                                              : 128 / (BK1Number * NPerXdl * sizeof(LDSTypeB)));

            constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed(
                make_tuple(Number<KThreadWrite / kfold / KThreadReadPerm>{},
                           Number<K0PerThreadWrite>{},
                           Number<KThreadReadPerm * N1>{},
                           Number<kfold * N0 / npair>{},
                           Number<npair>{},
                           BK1Number));

            constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
                b_lds_block_desc,
                make_tuple(
                    make_pass_through_transform(Number<KThreadWrite / kfold / KThreadReadPerm>{}),
                    make_pass_through_transform(Number<K0PerThreadWrite>{}),
                    make_xor_with_modulo_transform(
                        make_tuple(Number<KThreadReadPerm * N1>{}, Number<kfold * N0 / npair>{})),
                    make_pass_through_transform(Number<npair>{}),
                    make_pass_through_transform(BK1Number)),
                make_tuple(
                    Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}),
                make_tuple(
                    Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}));

            constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor(
                b_lds_block_desc_permuted,
                make_tuple(
                    make_pass_through_transform(Number<KThreadWrite / kfold / KThreadReadPerm>{}),
                    make_pass_through_transform(Number<K0PerThreadWrite>{}),
                    make_unmerge_transform(make_tuple(Number<KThreadReadPerm>{}, Number<N1>{})),
                    make_unmerge_transform(make_tuple(Number<kfold>{}, Number<N0 / npair>{})),
                    make_pass_through_transform(Number<npair>{}),
                    make_pass_through_transform(BK1Number)),
                make_tuple(Sequence<0>{},
                           Sequence<1>{},
                           Sequence<2>{},
                           Sequence<3>{},
                           Sequence<4>{},
                           Sequence<5>{}),
                make_tuple(Sequence<1>{},
                           Sequence<2>{},
                           Sequence<0, 3>{},
                           Sequence<4, 5>{},
                           Sequence<6>{},
                           Sequence<7>{}));

            constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor(
                b_lds_block_desc_unmerged,
                make_tuple(make_merge_transform_v3_division_mod(
                               make_tuple(Number<KThreadReadPerm>{},
                                          Number<KThreadWrite / kfold / KThreadReadPerm>{},
                                          Number<kfold>{},
                                          Number<K0PerThreadWrite>{})),
                           make_merge_transform_v3_division_mod(
                               make_tuple(Number<N0 / npair>{}, Number<npair>{}, Number<N1>{})),
                           make_pass_through_transform(BK1Number)),
                make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));

            return b_lds_block_desc_bk0_n_bk1;
        }
    }

    __device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
    {
        constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
        constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);

        constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
            make_naive_tensor_descriptor_packed(
                make_tuple(I1,
                           Number<CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl>{},
                           I1,
                           Number<CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>{}));

        return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
    }

    using BlockwiseGemmPipe = remove_cvref_t<
        decltype(BlockGemmABScalePipeline_Selector<
                 BlkGemmPipelineVer,
                 BlkGemmPipeSched,
                 BlockSize,
                 LDSTypeA,
                 LDSTypeB,
                 ComputeTypeA,
                 AccDataType,
                 decltype(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()),
                 decltype(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()),
                 decltype(MakeAMmaTileDescriptor_M0_M1_M2_K(
                     GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1())),
                 decltype(MakeBMmaTileDescriptor_N0_N1_N2_K(
                     GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1())),
                 ABlockTransferSrcScalarPerVector,
                 BBlockTransferSrcScalarPerVector,
                 MPerBlock,
                 NPerBlock,
                 KPerBlock,
                 MPerXdl,
                 NPerXdl,
                 MXdlPerWave,
                 NXdlPerWave,
                 KPack>())>;

    __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
    {
        // LDS allocation for A and B: be careful of alignment
        constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
        constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();

        // lds max alignment
        constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);

        constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
            a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);

        constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
            b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);

        // LDS allocation for C shuffle in LDS
        constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
            GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();

        constexpr auto c_block_size =
            c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();

        return math::max((a_block_space_size_aligned * sizeof(LDSTypeA) +
                          b_block_space_size_aligned * sizeof(LDSTypeB)),
                         c_block_size * sizeof(CShuffleDataType));
    }
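
    // Worked example (hypothetical tile config, fp16 A/B): with MPerBlock = NPerBlock = 128 and
    // KPerBlock = 64, the A and B LDS tiles each need roughly 128 * 64 * 2 B = 16 KiB (plus
    // alignment padding), i.e. about 32 KiB during the GEMM stage; the C-shuffle staging buffer
    // reuses the same allocation afterwards, which is why the function returns the maximum of
    // the two requirements rather than their sum.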

    // block_id to matrix tile idx (m0, n0) mapping is controlled by {M01, N01}
    __host__ static constexpr bool CheckValidity(const Argument& karg)
    {
        static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
                          (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
                      "Invalid tuning param!");

        if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
                       GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
                       GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
                       GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
        {
            if(!(karg.M % MPerBlock == 0))
            {
#if DEBUG_LOG
                std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " "
                          << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
                          << std::endl;
#endif // DEBUG_LOG
                return false;
            }
        }

        if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
                       GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
                       GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
                       GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
        {
            if(!(karg.N % NPerBlock == 0))
            {
#if DEBUG_LOG
                std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " "
                          << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
                          << std::endl;
#endif // DEBUG_LOG
                return false;
            }
        }

        if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding ||
                       GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
                       GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
                       GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
        {
            auto K_t = karg.KBatch * KPerBlock;
            if(!(karg.K % K_t == 0))
            {
#if DEBUG_LOG
                std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
                          << karg.K << " " << __FILE__ << ":" << __LINE__
                          << ", in function: " << __func__ << std::endl;
#endif // DEBUG_LOG
                return false;
            }
        }
        else
        {
            constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
            auto K_t                = karg.KBatch * KReadVec;
            auto KReadPadSplited    = math::integer_divide_ceil(karg.K, K_t) * KReadVec;
            if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K)
            {
                return false;
            }
        }

        if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
        {
            if(karg.K % ABlockTransferSrcScalarPerVector != 0)
            {
#if DEBUG_LOG
                std::cout << "Arg K (" << karg.K
                          << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
                          << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
                          << __LINE__ << ", in function: " << __func__ << std::endl;
#endif // DEBUG_LOG
                return false;
            }
        }
        else
        {
            if(karg.M % ABlockTransferSrcScalarPerVector != 0)
            {
#if DEBUG_LOG
                std::cout << "Arg M (" << karg.M
                          << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
                          << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
                          << __LINE__ << ", in function: " << __func__ << std::endl;
#endif // DEBUG_LOG
                return false;
            }
        }

        if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
        {
            if(karg.N % BBlockTransferSrcScalarPerVector != 0)
            {
#if DEBUG_LOG
                std::cout << "Arg N (" << karg.N
                          << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
                          << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
                          << __LINE__ << ", in function: " << __func__ << std::endl;
#endif // DEBUG_LOG
                return false;
            }
        }
        else
        {
            if(karg.K % BBlockTransferSrcScalarPerVector != 0)
            {
#if DEBUG_LOG
                std::cout << "Arg K (" << karg.K
                          << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
                          << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
                          << __LINE__ << ", in function: " << __func__ << std::endl;
#endif // DEBUG_LOG
                return false;
            }
        }

        if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
        {
            if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
            {
#if DEBUG_LOG
                std::cout << "Arg N (" << karg.N
                          << ") value is not a multiple of "
                             "CShuffleBlockTransferScalarPerVector_NPerBlock ("
                          << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " << __FILE__
                          << ":" << __LINE__ << ", in function: " << __func__ << std::endl;
#endif // DEBUG_LOG
                return false;
            }
        }
        else
        {
            if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
            {
#if DEBUG_LOG
                std::cout << "Arg M (" << karg.M
                          << ") value is not a multiple of "
                             "CShuffleBlockTransferScalarPerVector_NPerBlock ("
                          << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " << __FILE__
                          << ":" << __LINE__ << ", in function: " << __func__ << std::endl;
#endif // DEBUG_LOG
                return false;
            }
        }

        // check gridwise gemm pipeline
        const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value);

        if constexpr(BlkGemmPipelineVer != BlockGemmPipelineVersion::v1)
        {
            if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages)
            {
                return false;
            }
        }

        // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
        return true;
    }

    __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
    {
        const index_t num_loop = K / KPerBlock;

        return BlockwiseGemmPipe::BlockHasHotloop(num_loop);
    }

    __host__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K)
    {
        const index_t num_loop = K / KPerBlock;

        return BlockwiseGemmPipe::BlockLoopTailNum(num_loop);
    }
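
    // Illustration (hypothetical values): with KPerBlock = 64, a per-split K of 448 gives
    // num_loop = 7; these two helpers simply forward that loop count to the selected pipeline,
    // which decides whether a main (hot) K loop exists and which tail-handling variant
    // (e.g. TailNumber::Full vs. TailNumber::Odd) to instantiate at launch time.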

    template <typename CGridDesc>
    __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
        const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock)
    {
        const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
            c_grid_desc_m_n,
            make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
                       make_unmerge_transform(make_tuple(NBlock, Number<NPerBlock>{}))),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));

        return c_grid_desc_mblock_mperblock_nblock_nperblock;
    }

    // return block_id to C matrix tile idx (m0, n0) mapping
    // if arch = gfx942
    using Block2CTileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>;
    // using Block2CTileMap = BlockToCTileMap_3DGrid_KSplit<MPerBlock, NPerBlock>;

    template <bool HasMainKBlockLoop,
              InMemoryDataOperationEnum CGlobalMemoryDataOperation,
              TailNumber TailNum = TailNumber::Odd>
    __device__ static void Run(const ADataType* p_a_grid,
                               const BDataType* p_b_grid,
                               DsGridPointer& p_ds_grid,
                               CDataType* p_c_grid,
                               const AScaleType* p_a_scale_grid,
                               const BScaleType* p_b_scale_grid,
                               void* p_shared,
                               const Problem& problem,
                               AElementwiseOperation a_element_op,
                               BElementwiseOperation b_element_op,
                               CElementwiseOperation c_element_op)
    {
        const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
            problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
        const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
            problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
        const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
            problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);

        const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor(
            make_tuple(math::integer_divide_ceil(problem.M, ScaleBlockM),
                       math::integer_divide_ceil(problem.K, ScaleBlockK)),
            make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK), 1));

        const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor(
            make_tuple(math::integer_divide_ceil(problem.N, ScaleBlockN),
                       math::integer_divide_ceil(problem.K, ScaleBlockK)),
            make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK), 1));

        const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
            MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                c_grid_desc_m_n, problem.MBlock, problem.NBlock);

        const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
        const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
        auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
        const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
        const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_b_scale_grid, b_scale_grid_desc_bn_ak.GetElementSpaceSize());

        // divide block work by [M, N]
        const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4};

        const auto block_work_idx =
            block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));

        if(!block_2_ctile_map.ValidCTileIndex(
               block_work_idx,
               make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
                          c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
        {
            return;
        }

        const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
        const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);

        // HACK: this force m/n_block_data_idx_on_grid into SGPR
        const index_t m_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);

        const index_t n_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock);

        // lds max alignment
        constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);

        // A matrix in LDS memory, dst of blockwise copy
        constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();

        // B matrix in LDS memory, dst of blockwise copy
        constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();

        // A matrix blockwise copy
        auto a_blockwise_copy =
            ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
                                                AElementwiseOperation,
                                                ck::tensor_operation::element_wise::PassThrough,
                                                InMemoryDataOperationEnum::Set,
                                                Sequence<AK0Number, MPerBlock, AK1Number>,
                                                ABlockTransferThreadClusterLengths_AK0_M_AK1,
                                                ABlockTransferThreadClusterArrangeOrder,
                                                ADataType,
                                                LDSTypeA,
                                                decltype(a_grid_desc_ak0_m_ak1),
                                                decltype(a_block_desc_ak0_m_ak1),
                                                ABlockTransferSrcAccessOrder,
                                                Sequence<0, 1, 2>,
                                                ABlockTransferSrcVectorDim,
                                                2,
                                                ABlockTransferSrcScalarPerVector,
                                                ABlockTransferDstScalarPerVector_AK1,
                                                1,
                                                1,
                                                AThreadTransferSrcResetCoordinateAfterRun,
                                                true,
                                                BlockwiseGemmPipe::GlobalBufferNum>(
                a_grid_desc_ak0_m_ak1,
                make_multi_index(0, m_block_data_idx_on_grid, 0),
                a_element_op,
                a_block_desc_ak0_m_ak1,
                make_multi_index(0, 0, 0),
                ck::tensor_operation::element_wise::PassThrough{});

        // B matrix blockwise copy
        auto b_blockwise_copy =
            ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
                                                BElementwiseOperation,
                                                ck::tensor_operation::element_wise::PassThrough,
                                                InMemoryDataOperationEnum::Set,
                                                Sequence<BK0Number, NPerBlock, BK1Number>,
                                                BBlockTransferThreadClusterLengths_BK0_N_BK1,
                                                BBlockTransferThreadClusterArrangeOrder,
                                                BDataType,
                                                LDSTypeB,
                                                decltype(b_grid_desc_bk0_n_bk1),
                                                decltype(b_block_desc_bk0_n_bk1),
                                                BBlockTransferSrcAccessOrder,
                                                Sequence<0, 1, 2>,
                                                BBlockTransferSrcVectorDim,
                                                2,
                                                BBlockTransferSrcScalarPerVector,
                                                BBlockTransferDstScalarPerVector_BK1,
                                                1,
                                                1,
                                                BThreadTransferSrcResetCoordinateAfterRun,
                                                true,
                                                BlockwiseGemmPipe::GlobalBufferNum>(
                b_grid_desc_bk0_n_bk1,
                make_multi_index(0, n_block_data_idx_on_grid, 0),
                b_element_op,
                b_block_desc_bk0_n_bk1,
                make_multi_index(0, 0, 0),
                ck::tensor_operation::element_wise::PassThrough{});

        // LDS allocation for A and B: be careful of alignment
        constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
            a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);

        // Cast after lds
        auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<LDSTypeA*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());

        auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<LDSTypeB*>(p_shared) +
                a_block_space_size_aligned * sizeof(LDSTypeA) / sizeof(LDSTypeB),
            b_block_desc_bk0_n_bk1.GetElementSpaceSize());

        constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
        constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0);

        // Blockwise GEMM pipeline
        static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
        auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
        auto c_thread_buf            = blockwise_gemm_pipeline.GetCThreadBuffer();

        const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
            (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
            KPerBlock);

        const index_t ScaleSliceSizeM = 1;
        const index_t ScaleSliceSizeN = 1;
        const index_t ScaleSliceSizeK = 1;

        constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed(
            make_tuple(Number<ScaleSliceSizeM>{}, Number<ScaleSliceSizeK>{}));

        constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed(
            make_tuple(Number<ScaleSliceSizeM>{}, Number<ScaleSliceSizeK>{}));

        auto a_scale_thread_copy =
            ThreadwiseTensorSliceTransfer_v2<AScaleType,
                                             AScaleType,
                                             decltype(a_scale_grid_desc_am_ak),
                                             decltype(a_scale_thread_desc),
                                             Sequence<ScaleSliceSizeM, ScaleSliceSizeK>,
                                             Sequence<0, 1>,
                                             1,
                                             1,
                                             1,
                                             false>(
                a_scale_grid_desc_am_ak,
                make_multi_index(block_m_id * MPerBlock / ScaleBlockM, 0));

        auto b_scale_thread_copy =
            ThreadwiseTensorSliceTransfer_v2<BScaleType,
                                             BScaleType,
                                             decltype(b_scale_grid_desc_bn_ak),
                                             decltype(b_scale_thread_desc),
                                             Sequence<ScaleSliceSizeN, ScaleSliceSizeK>,
                                             Sequence<0, 1>,
                                             1,
                                             1,
                                             1,
                                             false>(
                b_scale_grid_desc_bn_ak,
                make_multi_index(block_n_id * NPerBlock / ScaleBlockN, 0));

        constexpr auto a_scale_thread_slice_copy_step = make_multi_index(0, 1);
        constexpr auto b_scale_thread_slice_copy_step = make_multi_index(0, 1);

        const index_t num_k_block_per_scale = ScaleBlockK / KPerBlock;

        blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
            a_grid_desc_ak0_m_ak1,
            a_block_desc_ak0_m_ak1,
            a_blockwise_copy,
            a_grid_buf,
            a_block_buf,
            a_block_slice_copy_step,
            b_grid_desc_bk0_n_bk1,
            b_block_desc_bk0_n_bk1,
            b_blockwise_copy,
            b_grid_buf,
            b_block_buf,
            b_block_slice_copy_step,
            c_thread_buf,
            a_scale_grid_desc_am_ak,
            a_scale_thread_desc,
            a_scale_thread_copy,
            a_scale_grid_buf,
            a_scale_thread_slice_copy_step,
            b_scale_grid_desc_bn_ak,
            b_scale_thread_desc,
            b_scale_thread_copy,
            b_scale_grid_buf,
            b_scale_thread_slice_copy_step,
            num_k_block_main_loop,
            num_k_block_per_scale);

        // shuffle C and write out
        {
            static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
                              NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
                          "wrong!");

            constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
            constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);

            // TODO: hacky, fix it!
            constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
                blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

            // TODO: hacky, fix it!
            // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
            constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
                blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

            constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
            constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
            constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
            constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
            constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
            constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
            constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
            constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);

            constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
                GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();

            auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
                static_cast<CShuffleDataType*>(p_shared),
                c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

            constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
                c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
                make_tuple(
                    make_freeze_transform(I0),
                    make_unmerge_transform(make_tuple(
                        Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
                        M1,                                      // M1 = MWave
                        M2,                                      // M2 * M3 * M4 = MPerXdl
                        M3,
                        M4)),
                    make_freeze_transform(I0),
                    make_unmerge_transform(make_tuple(
                        Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
                        N1,                                      // N1 = NWave
                        N2))),                                   // N2 = NPerXdl
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                make_tuple(
                    Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{}));

            // calculate origin of thread output tensor on global memory
            // blockwise GEMM c matrix starting index
            const auto c_thread_mtx_on_block =
                blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);

            const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
            const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];

            const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
                make_single_stage_tensor_adaptor(
                    make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
                    make_tuple(Sequence<0, 1, 2, 3, 4>{}),
                    make_tuple(Sequence<0>{}));

            const auto m_thread_data_on_block_idx =
                m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
                    make_multi_index(m_thread_data_on_block));

            const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
                make_single_stage_tensor_adaptor(
                    make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
                    make_tuple(Sequence<0, 1, 2>{}),
                    make_tuple(Sequence<0>{}));

            const auto n_thread_data_on_block_idx =
                n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
                    make_multi_index(n_thread_data_on_block));

            // shuffle: threadwise copy C from VGPR to LDS
            auto c_thread_copy_vgpr_to_lds =
                ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
                                                   CShuffleDataType,
                                                   decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
                                                   decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
                                                   ck::tensor_operation::element_wise::PassThrough,
                                                   Sequence<CShuffleMXdlPerWavePerShuffle,
                                                            CShuffleNXdlPerWavePerShuffle,
                                                            I1,
                                                            I1,
                                                            M2,
                                                            I1,
                                                            M4,
                                                            I1>,
                                                   Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
                                                   7,
                                                   1,
                                                   InMemoryDataOperationEnum::Set,
                                                   1,
                                                   true>{
                    c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                    make_multi_index(0,
                                     0,
                                     m_thread_data_on_block_idx[I1],
                                     n_thread_data_on_block_idx[I1],
                                     m_thread_data_on_block_idx[I2],
                                     m_thread_data_on_block_idx[I3],
                                     m_thread_data_on_block_idx[I4],
                                     n_thread_data_on_block_idx[I2]),
                    ck::tensor_operation::element_wise::PassThrough{}};

            using EDataType = CDataType;

            const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
                problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);

            const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
                MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                    ds_grid_desc_m_n, problem.MBlock, problem.NBlock);

            const auto ds_grid_buf = generate_tuple(
                [&](auto i) {
                    return make_dynamic_buffer<AddressSpaceEnum::Global>(
                        p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
                },
                Number<NumDTensor>{});

            // tuple of reference to C/Ds tensor descriptors
            const auto c_ds_desc_refs = concat_tuple_of_reference(
                tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
                generate_tie(
                    [&](auto i) -> const auto& // return type should be reference
                    { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
                    Number<NumDTensor>{}));

            // tuple of reference to C/Ds tensor descriptors
            const auto c_ds_buf_refs = concat_tuple_of_reference(
                tie(c_shuffle_block_buf),
                generate_tie(
                    [&](auto i) -> const auto& // return type should be reference
                    { return ds_grid_buf[i]; },
                    Number<NumDTensor>{}));

            // tuple of starting index of C/Ds blockwise copy
            const auto idx_c_ds_block_begin = container_concat(
                make_tuple(make_multi_index(0, 0, 0, 0)),
                generate_tuple(
                    [&](auto) {
                        return make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0);
                    },
                    Number<NumDTensor>{}));

            const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
                c_grid_desc_mblock_mperblock_nblock_nperblock;

            using CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock =
                CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;

            const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;

            auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3<
                ThisThreadBlock,
                decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
                Tuple<EDataType>,
                decltype(c_ds_desc_refs),
                decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
                CElementwiseOperation,
                Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make Sequence
                                                                            // support arbitray type
                Sequence<1,
                         CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
                         1,
                         CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
                CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
                Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
                Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder,
                Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder,
                3,                    // index_t SrcVectorDim,
                3,                    // index_t DstVectorDim,
                CDEShuffleBlockTransferScalarPerVectors,
                CShuffleBlockTransferScalarPerVector_NPerBlock,
                sequence_merge_t<
                    Sequence<true>,
                    uniform_sequence_gen_t<NumDTensor,
                                           false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
                Sequence<false>>                    // ThreadTransferDstResetCoordinateAfterRunFlags
                {c_ds_desc_refs,
                 idx_c_ds_block_begin,
                 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
                 make_tuple(make_multi_index(block_m_id, 0, block_n_id, 0)),
                 c_element_op};

            // space filling curve for threadwise C in VGPR
            constexpr auto sfc_c_vgpr =
                SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
                                  Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
                                  Sequence<CShuffleMXdlPerWavePerShuffle,
                                           CShuffleNXdlPerWavePerShuffle,
                                           1,
                                           1,
                                           M2,
                                           1,
                                           M4,
                                           1>>{};

            constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();

            // space filling curve for shuffled blockwise C/D/E
constexpr
auto
sfc_cde_block
=
SpaceFillingCurve
<
Sequence
<
1
,
MPerBlock
,
1
,
NPerBlock
>
,
Sequence
<
0
,
2
,
1
,
3
>
,
Sequence
<
1
,
CShuffleMXdlPerWavePerShuffle
*
MWave
*
MPerXdl
,
1
,
CShuffleNXdlPerWavePerShuffle
*
NWave
*
NPerXdl
>>
{};
static_assert
(
num_access
==
sfc_cde_block
.
GetNumOfAccess
(),
"wrong!"
);
static_for
<
0
,
num_access
,
1
>
{}([
&
](
auto
access_id
)
{
// make sure it's safe to write to LDS
block_sync_lds
();
// each thread write its data from VGPR to LDS
c_thread_copy_vgpr_to_lds
.
Run
(
c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2
,
sfc_c_vgpr
.
GetIndexTupleOfNumber
(
access_id
),
c_thread_buf
,
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2
,
c_shuffle_block_buf
);
// make sure it's safe to read from LDS
block_sync_lds
();
// each block copy its data from LDS to global
cde_block_copy_lds_and_global
.
Run
(
c_ds_desc_refs
,
c_ds_buf_refs
,
tie
(
e_grid_desc_mblock_mperblock_nblock_nperblock
),
tie
(
c_grid_buf
));
if
constexpr
(
access_id
<
num_access
-
1
)
{
constexpr
auto
cde_lds_and_global_step
=
sfc_cde_block
.
GetForwardStep
(
access_id
);
// move on Ds
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
cde_block_copy_lds_and_global
.
MoveSrcSliceWindow
(
c_ds_desc_refs
,
i
+
I1
,
cde_lds_and_global_step
);
});
// move on E
cde_block_copy_lds_and_global
.
MoveDstSliceWindow
(
tie
(
e_grid_desc_mblock_mperblock_nblock_nperblock
),
I0
,
cde_lds_and_global_step
);
}
});
}
}
};
}
// namespace ck
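The epilogue above drains the accumulators in num_access passes: each pass writes one shuffle tile from VGPR into LDS, synchronizes, streams that tile (together with the fused D tensors) out to global memory, and then advances the slice windows along the space-filling curve. A minimal host-side sketch of that loop structure, with hypothetical stand-in helpers (barrier, write_tile_to_lds, copy_tile_to_global and advance_windows are not CK APIs, they only mirror the calls above):

// Hypothetical sketch, not CK code: barrier() stands in for block_sync_lds(),
// the other helpers stand in for the threadwise/blockwise transfer objects above.
#include <cstdio>

namespace epilogue_sketch {
inline void barrier() { /* block_sync_lds() in the real kernel */ }
inline void write_tile_to_lds(int id) { std::printf("VGPR -> LDS, tile %d\n", id); }
inline void copy_tile_to_global(int id) { std::printf("LDS -> global, tile %d\n", id); }
inline void advance_windows() { /* MoveSrcSliceWindow / MoveDstSliceWindow */ }

inline void run_epilogue(int num_access)
{
    for(int access_id = 0; access_id < num_access; ++access_id)
    {
        barrier();                      // safe to overwrite the LDS staging tile
        write_tile_to_lds(access_id);   // per-thread VGPR -> LDS copy
        barrier();                      // safe to read what other lanes wrote
        copy_tile_to_global(access_id); // cooperative LDS -> global store (C, Ds -> E)
        if(access_id < num_access - 1)
            advance_windows();          // step along the space-filling curve
    }
}
} // namespace epilogue_sketch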
include/ck/tensor_operation/gpu/warp/smfmac_xdlops_gemm.hpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/common_header.hpp"
#include "ck/utility/math.hpp"
#include "ck/utility/amd_smfmac.hpp"

namespace ck {

enum struct SmfmacInstr
{
    smfmac_f32_16x16x32f16 = 0,
    smfmac_f32_32x32x16f16,
    smfmac_f32_16x16x32bf16,
    smfmac_f32_32x32x16bf16,
};

template <SmfmacInstr instr>
struct smfmac_type;

template <>
struct smfmac_type<SmfmacInstr::smfmac_f32_16x16x32f16>
{
    static constexpr index_t group_size          = 4;
    static constexpr index_t num_groups_per_blk  = 1;
    static constexpr index_t num_regs_per_blk    = 4;
    static constexpr index_t num_threads_per_blk = 16;
    static constexpr index_t wave_size           = 64;
    static constexpr index_t num_input_blks      = 4;
    static constexpr index_t num_output_blks     = 1;
    static constexpr index_t m_per_blk           = 16;
    static constexpr index_t n_per_blk           = 16;
    static constexpr index_t k_per_blk           = 8;
    static constexpr bool is_k_reduction         = true;

    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
    __device__ void run(const FloatA& a, const FloatB& b, const int32_t& idx, FloatC& reg_c) const
    {
        intrin_smfmac_f32_16x16x32f16<MPerXdlops, NPerXdlops>::Run(a, b, idx, reg_c);
    }
};

template <>
struct smfmac_type<SmfmacInstr::smfmac_f32_32x32x16f16>
{
    static constexpr index_t group_size          = 4;
    static constexpr index_t num_groups_per_blk  = 4;
    static constexpr index_t num_regs_per_blk    = 16;
    static constexpr index_t num_threads_per_blk = 32;
    static constexpr index_t wave_size           = 64;
    static constexpr index_t num_input_blks      = 2;
    static constexpr index_t num_output_blks     = 1;
    static constexpr index_t m_per_blk           = 32;
    static constexpr index_t n_per_blk           = 32;
    static constexpr index_t k_per_blk           = 16;
    static constexpr bool is_k_reduction         = true;

    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
    __device__ void run(const FloatA& a, const FloatB& b, const int32_t& idx, FloatC& reg_c) const
    {
        intrin_smfmac_f32_32x32x16f16<MPerXdlops, NPerXdlops>::Run(a, b, idx, reg_c);
    }
};

template <>
struct smfmac_type<SmfmacInstr::smfmac_f32_16x16x32bf16>
{
    static constexpr index_t group_size          = 4;
    static constexpr index_t num_groups_per_blk  = 1;
    static constexpr index_t num_regs_per_blk    = 4;
    static constexpr index_t num_threads_per_blk = 16;
    static constexpr index_t wave_size           = 64;
    static constexpr index_t num_input_blks      = 4;
    static constexpr index_t num_output_blks     = 1;
    static constexpr index_t m_per_blk           = 16;
    static constexpr index_t n_per_blk           = 16;
    static constexpr index_t k_per_blk           = 8;
    static constexpr bool is_k_reduction         = true;

    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
    __device__ void run(const FloatA& a, const FloatB& b, const int32_t& idx, FloatC& reg_c) const
    {
        intrin_smfmac_f32_16x16x32bf16<MPerXdlops, NPerXdlops>::Run(a, b, idx, reg_c);
    }
};

template <>
struct smfmac_type<SmfmacInstr::smfmac_f32_32x32x16bf16>
{
    static constexpr index_t group_size          = 4;
    static constexpr index_t num_groups_per_blk  = 4;
    static constexpr index_t num_regs_per_blk    = 16;
    static constexpr index_t num_threads_per_blk = 32;
    static constexpr index_t wave_size           = 64;
    static constexpr index_t num_input_blks      = 2;
    static constexpr index_t num_output_blks     = 1;
    static constexpr index_t m_per_blk           = 32;
    static constexpr index_t n_per_blk           = 32;
    static constexpr index_t k_per_blk           = 16;
    static constexpr bool is_k_reduction         = true;

    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
    __device__ void run(const FloatA& a, const FloatB& b, const int32_t& idx, FloatC& reg_c) const
    {
        intrin_smfmac_f32_32x32x16bf16<MPerXdlops, NPerXdlops>::Run(a, b, idx, reg_c);
    }
};

template <typename base_type,
          index_t MPerXdlops,
          index_t NPerXdlops,
          typename additional_type = base_type>
struct SmfmacSelector
{
    template <typename base_type_,
              index_t MPerXdlops_,
              index_t NPerXdlops_,
              typename additional_type_ = base_type_>
    static constexpr auto GetSmfmac();

    template <>
    static constexpr auto GetSmfmac<half_t, 16, 16>()
    {
        return SmfmacInstr::smfmac_f32_16x16x32f16;
    }

    template <>
    static constexpr auto GetSmfmac<half_t, 32, 32>()
    {
        return SmfmacInstr::smfmac_f32_32x32x16f16;
    }

    template <>
    static constexpr auto GetSmfmac<bhalf_t, 16, 16>()
    {
        return SmfmacInstr::smfmac_f32_16x16x32bf16;
    }

    template <>
    static constexpr auto GetSmfmac<bhalf_t, 32, 32>()
    {
        return SmfmacInstr::smfmac_f32_32x32x16bf16;
    }

    static constexpr auto selected_smfmac =
        smfmac_type<GetSmfmac<base_type, MPerXdlops, NPerXdlops, additional_type>()>{};

    __host__ __device__ constexpr SmfmacSelector()
    {
        static_assert(selected_smfmac.group_size * selected_smfmac.num_groups_per_blk ==
                          selected_smfmac.num_regs_per_blk,
                      "wrong! num_regs_per_blk");

        static_assert(selected_smfmac.num_threads_per_blk == selected_smfmac.n_per_blk,
                      "n_per_blk != num_threads_per_blk");

        static_assert(selected_smfmac.num_regs_per_blk * selected_smfmac.num_input_blks ==
                          selected_smfmac.m_per_blk,
                      "m_per_blk != num_input_blks * num_regs_per_blk");

        static_assert(selected_smfmac.num_output_blks == selected_smfmac.num_input_blks ||
                          selected_smfmac.num_output_blks == 1,
                      "incorrect num_output_blks");

        static_assert(selected_smfmac.num_regs_per_blk * selected_smfmac.wave_size ==
                          selected_smfmac.m_per_blk * selected_smfmac.n_per_blk,
                      "num_regs_per_blk incorrect");

        static_assert(selected_smfmac.is_k_reduction ||
                          (selected_smfmac.num_input_blks == selected_smfmac.num_output_blks),
                      "is_k_reduction wrong!");
    }

    static constexpr index_t GetKPerXdlops()
    {
        return (selected_smfmac.is_k_reduction ? selected_smfmac.num_input_blks : 1) *
               selected_smfmac.k_per_blk;
    }

    static constexpr index_t GetK1PerXdlops() { return selected_smfmac.k_per_blk; }
};

template <typename base_type,
          index_t MPerXdlops,
          index_t NPerXdlops,
          index_t KPack,
          typename additional_type = base_type>
struct SparseXdlopsGemm
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};
    static constexpr auto I4 = Number<4>{};
    static constexpr auto I5 = Number<5>{};

    using CIndex   = MultiIndex<2>;
    using CIndex4D = MultiIndex<4>;

    __device__ static constexpr index_t GetNumBlks() { return smfmac_instr.num_output_blks; }

    __device__ static constexpr index_t GetNumXdlops()
    {
        return MPerXdlops * NPerXdlops /
               (smfmac_instr.m_per_blk * smfmac_instr.n_per_blk * smfmac_instr.num_output_blks);
    }

    __host__ __device__ constexpr SparseXdlopsGemm()
    {
        static_assert(NPerXdlops == 16 || NPerXdlops == 32,
                      "Only support GemmNPerXdlops == 16 or 32 for smfmac xdlops");

        static_assert(MPerXdlops == 16 || MPerXdlops == 32,
                      "Only support GemmMPerXdlops == 16 or 32 for smfmac xdlops");

        static_assert(KPack % smfmac_instr.k_per_blk == 0, "KPack cannot be divided by k_per_blk");
    }

    // XDL output supporting C = A * B
    // M2_N2 -> M2_M3_M4_N2
    template <typename CDesc_M0_N0_M1_N1_M2_N2>
    __host__ __device__ static constexpr auto
    MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CDesc_M0_N0_M1_N1_M2_N2& c_desc_m0_n0_m1_n1_m2_n2)
    {
        const auto M0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I0);
        const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I1);
        const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I2);
        const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I3);

        return transform_tensor_descriptor(
            c_desc_m0_n0_m1_n1_m2_n2,
            make_tuple(make_pass_through_transform(M0),
                       make_pass_through_transform(N0),
                       make_pass_through_transform(M1),
                       make_pass_through_transform(N1),
                       make_unmerge_transform(make_tuple(Number<smfmac_instr.num_groups_per_blk>{},
                                                         Number<smfmac_instr.num_input_blks>{},
                                                         Number<smfmac_instr.group_size>{})),
                       make_pass_through_transform(Number<smfmac_instr.num_threads_per_blk>{})),
            make_tuple(Sequence<0>{},
                       Sequence<1>{},
                       Sequence<2>{},
                       Sequence<3>{},
                       Sequence<4>{},
                       Sequence<5>{}),
            make_tuple(Sequence<0>{},
                       Sequence<1>{},
                       Sequence<2>{},
                       Sequence<3>{},
                       Sequence<4, 5, 6>{},
                       Sequence<7>{}));
    }

    template <typename CDesc_G_M0_N0_M1_N1_M2_N2>
    __host__ __device__ static constexpr auto MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
        const CDesc_G_M0_N0_M1_N1_M2_N2& c_desc_g_m0_n0_m1_n1_m2_n2)
    {
        const auto G  = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I0);
        const auto M0 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I1);
        const auto N0 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I2);
        const auto M1 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I3);
        const auto N1 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I4);

        return transform_tensor_descriptor(
            c_desc_g_m0_n0_m1_n1_m2_n2,
            make_tuple(make_pass_through_transform(G),
                       make_pass_through_transform(M0),
                       make_pass_through_transform(N0),
                       make_pass_through_transform(M1),
                       make_pass_through_transform(N1),
                       make_unmerge_transform(make_tuple(smfmac_instr.num_groups_per_blk,
                                                         smfmac_instr.num_input_blks,
                                                         smfmac_instr.group_size)),
                       make_pass_through_transform(smfmac_instr.num_threads_per_blk)),
            make_tuple(Sequence<0>{},
                       Sequence<1>{},
                       Sequence<2>{},
                       Sequence<3>{},
                       Sequence<4>{},
                       Sequence<5>{},
                       Sequence<6>{}),
            make_tuple(Sequence<0>{},
                       Sequence<1>{},
                       Sequence<2>{},
                       Sequence<3>{},
                       Sequence<4>{},
                       Sequence<5, 6, 7>{},
                       Sequence<8>{}));
    }

    __device__ static constexpr index_t GetRegSizePerXdlops()
    {
        return MPerXdlops * NPerXdlops / smfmac_instr.wave_size;
    }

    __device__ static constexpr index_t GetWaveSize() { return smfmac_instr.wave_size; }

    template <class FloatA, class FloatB, class Idx, class FloatC>
    __device__ void
    Run(const FloatA& p_a_wave, const FloatB& p_b_wave, const Idx& idx, FloatC& p_c_thread) const
    {
        static_assert(is_same<base_type, half_t>::value || is_same<base_type, bhalf_t>::value,
                      "base_type must be half or bfloat16!");

        static_for<0, KPack / smfmac_instr.k_per_blk, 1>{}([&](auto k) {
            smfmac_instr.template run<MPerXdlops, NPerXdlops>(
                p_a_wave[k], p_b_wave[k], idx[k], p_c_thread);
        });
    }

    __device__ static auto GetLaneId() { return get_thread_local_1d_id() % smfmac_instr.wave_size; }

    __device__ static auto GetBlkIdx()
    {
        const auto laneId = GetLaneId();

        constexpr auto threadidx_to_blk_idx_adaptor = make_single_stage_tensor_adaptor(
            make_tuple(make_merge_transform(
                make_tuple(1, smfmac_instr.num_input_blks, smfmac_instr.num_threads_per_blk))),
            make_tuple(Sequence<0, 1, 2>{}),
            make_tuple(Sequence<0>{}));

        const auto blk_idx =
            threadidx_to_blk_idx_adaptor.CalculateBottomIndex(make_multi_index(laneId));

        const auto blk_id = blk_idx[I1];
        const auto blk_td = blk_idx[I2];

        return make_tuple(blk_id, blk_td);
    }

    __host__ __device__ static auto CalculateAThreadOriginDataIndex()
    {
        const auto laneId  = GetLaneId();
        const auto blk_idx = GetBlkIdx();
        const auto blk_id  = blk_idx[I0];
        const auto blk_td  = blk_idx[I1];

        if constexpr(smfmac_instr.is_k_reduction)
        {
            return make_tuple(blk_id, blk_td);
        }
        else
        {
            return make_tuple(0, laneId);
        }
    }

    __host__ __device__ static auto CalculateBThreadOriginDataIndex()
    {
        const auto laneId  = GetLaneId();
        const auto blk_idx = GetBlkIdx();
        const auto blk_id  = blk_idx[I0];
        const auto blk_td  = blk_idx[I1];

        if constexpr(smfmac_instr.is_k_reduction)
        {
            return make_tuple(blk_id, blk_td);
        }
        else
        {
            return make_tuple(0, laneId);
        }
    }

    __device__ static CIndex GetBeginOfThreadBlk(index_t xdlops_i, index_t blk_i)
    {
        const auto blk_idx = GetBlkIdx();
        const auto blk_id  = blk_idx[I0];
        const auto blk_td  = blk_idx[I1];

        index_t n_offset = blk_i * smfmac_instr.n_per_blk + blk_td;
        index_t m_offset = xdlops_i * smfmac_instr.m_per_blk + blk_id * smfmac_instr.group_size;

        return CIndex{m_offset, n_offset};
    }

    __device__ static CIndex4D GetBeginOfThreadBlk4D(index_t /* xdlops_i */, index_t /* blk_i */)
    {
        const auto blk_idx = GetBlkIdx();
        const auto blk_id  = blk_idx[I0];
        const auto blk_td  = blk_idx[I1];

        return CIndex4D{I0, blk_id, I0, blk_td};
    }

    static constexpr auto smfmac =
        SmfmacSelector<base_type, MPerXdlops, NPerXdlops, additional_type>{};

    static constexpr auto smfmac_instr = smfmac.selected_smfmac;

    static constexpr auto KPerXdlops  = smfmac.GetKPerXdlops();
    static constexpr auto K1PerXdlops = smfmac.GetK1PerXdlops();
    static constexpr auto K0PerXdlops = KPerXdlops / K1PerXdlops;

    __host__ __device__ static constexpr auto GetCM0M1M2NThreadBlkLengths()
    {
        return make_tuple(
            Number<smfmac_instr.num_groups_per_blk>{}, I1, Number<smfmac_instr.group_size>{}, I1);
    }
};

} // namespace ck
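As a sanity check on the selector above (a sketch, assuming this header is included; the alias name is made up): for half_t at 16x16 the selector resolves to smfmac_f32_16x16x32f16, and with is_k_reduction == true, num_input_blks == 4 and k_per_blk == 8 a single xdlops step covers K = 4 * 8 = 32.

// Hypothetical compile-time check, not part of the header.
using Sel_f16_16x16 = ck::SmfmacSelector<ck::half_t, 16, 16>;
static_assert(Sel_f16_16x16::GetKPerXdlops() == 32, "16x16 f16 smfmac consumes K = 32 per step");
static_assert(Sel_f16_16x16::GetK1PerXdlops() == 8, "each register block holds 8 K elements");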
include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp
@@ -27,7 +27,7 @@ template <index_t NDimSpatial,
          index_t NPerBlock,
          index_t GemmK1Number,
          index_t K0PerBlock,
-          index_t NumBatchToMerge,
+          index_t NumGroupsToMerge,
          device::ConvolutionBackwardWeightSpecialization ConvBackwardWeightSpecialization>
struct TransformConvBwdWeightToGemmV2
{
@@ -45,7 +45,7 @@ struct TransformConvBwdWeightToGemmV2
        const index_t BatchStride = output_strides[0];
        const index_t WoStride    = output_strides[4];
        const auto KStride        = Number<1>{};
-        return make_naive_tensor_descriptor(make_tuple(N * Ho * Wo, NumBatchToMerge, K),
+        return make_naive_tensor_descriptor(make_tuple(N * Ho * Wo, NumGroupsToMerge, K),
                                            make_tuple(WoStride, BatchStride, KStride));
    }
@@ -65,13 +65,13 @@ struct TransformConvBwdWeightToGemmV2
        if constexpr(ConvBackwardWeightSpecialization ==
                     device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
        {
-            return make_naive_tensor_descriptor(make_tuple(N * Hi * Wi, NumBatchToMerge, C),
+            return make_naive_tensor_descriptor(make_tuple(N * Hi * Wi, NumGroupsToMerge, C),
                                                make_tuple(WiStride, BatchStride, CStride));
        }
        else
        {
            return make_naive_tensor_descriptor(
-                make_tuple(N, Hi, Wi, NumBatchToMerge, C),
+                make_tuple(N, Hi, Wi, NumGroupsToMerge, C),
                make_tuple(NStride, HiStride, WiStride, BatchStride, CStride));
        }
    }
@@ -88,30 +88,30 @@ struct TransformConvBwdWeightToGemmV2
        const auto KStride     = weights_strides[1];
        const auto XStride     = weights_strides[4];
        const auto BatchStride = weights_strides[0];
-        // Add NumBatchToMerge for Batch+M dimension and, 1 as a placeholder
+        // Add NumGroupsToMerge for Batch+M dimension and, 1 as a placeholder
        // for Batch+N dimension
        const auto desc =
-            make_naive_tensor_descriptor(make_tuple(NumBatchToMerge, K, Y * X, 1, C),
+            make_naive_tensor_descriptor(make_tuple(NumGroupsToMerge, K, Y * X, 1, C),
                                         make_tuple(BatchStride, KStride, XStride, BatchStride, CStride));
-        // Pad 1 to NumBatchToMerge
+        // Pad 1 to NumGroupsToMerge
        const auto padded_desc = transform_tensor_descriptor(
            desc,
-            make_tuple(make_pass_through_transform(NumBatchToMerge),
+            make_tuple(make_pass_through_transform(NumGroupsToMerge),
                       make_pass_through_transform(K),
                       make_pass_through_transform(Y * X),
-                       make_pad_transform(1, 0, NumBatchToMerge - 1),
+                       make_pad_transform(1, 0, NumGroupsToMerge - 1),
                       make_pass_through_transform(C)),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
        // We need only the matrices on the diagonal. Xor returns 0 for equal values,
        // so a matrix that is not on the diagonal ends up stored in the padding.
        // To avoid use of modulo after xor we assume the number of merged groups is a power of 2.
-        static_assert(NumBatchToMerge == 1 || NumBatchToMerge == 2 || NumBatchToMerge == 4 ||
-                      NumBatchToMerge == 8 || NumBatchToMerge == 16 || NumBatchToMerge == 32 ||
-                      NumBatchToMerge == 64);
+        static_assert(NumGroupsToMerge == 1 || NumGroupsToMerge == 2 || NumGroupsToMerge == 4 ||
+                      NumGroupsToMerge == 8 || NumGroupsToMerge == 16 || NumGroupsToMerge == 32 ||
+                      NumGroupsToMerge == 64);
        const auto unmerged_padded_desc = transform_tensor_descriptor(
            padded_desc,
-            make_tuple(make_xor_transform(make_tuple(NumBatchToMerge, NumBatchToMerge)),
+            make_tuple(make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
                       make_pass_through_transform(K),
                       make_pass_through_transform(Y * X),
                       make_pass_through_transform(C)),
@@ -120,8 +120,8 @@ struct TransformConvBwdWeightToGemmV2
        // Merge To M, N
        return transform_tensor_descriptor(
            unmerged_padded_desc,
-            make_tuple(make_merge_transform(make_tuple(NumBatchToMerge, K)),
-                       make_merge_transform(make_tuple(Y * X, NumBatchToMerge, C))),
+            make_tuple(make_merge_transform(make_tuple(NumGroupsToMerge, K)),
+                       make_merge_transform(make_tuple(Y * X, NumGroupsToMerge, C))),
            make_tuple(Sequence<0, 1>{}, Sequence<2, 3, 4>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));
    }
@@ -138,7 +138,7 @@ struct TransformConvBwdWeightToGemmV2
        const index_t BatchStride = output_strides[0];
        const index_t WoStride    = output_strides[5];
        const auto KStride        = Number<1>{};
-        return make_naive_tensor_descriptor(make_tuple(N * Do * Ho * Wo, NumBatchToMerge, K),
+        return make_naive_tensor_descriptor(make_tuple(N * Do * Ho * Wo, NumGroupsToMerge, K),
                                            make_tuple(WoStride, BatchStride, KStride));
    }
@@ -160,13 +160,13 @@ struct TransformConvBwdWeightToGemmV2
        if constexpr(ConvBackwardWeightSpecialization ==
                     device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
        {
-            return make_naive_tensor_descriptor(make_tuple(N * Di * Hi * Wi, NumBatchToMerge, C),
+            return make_naive_tensor_descriptor(make_tuple(N * Di * Hi * Wi, NumGroupsToMerge, C),
                                                make_tuple(WiStride, BatchStride, CStride));
        }
        else
        {
            return make_naive_tensor_descriptor(
-                make_tuple(N, Di, Hi, Wi, NumBatchToMerge, C),
+                make_tuple(N, Di, Hi, Wi, NumGroupsToMerge, C),
                make_tuple(NStride, DiStride, HiStride, WiStride, BatchStride, CStride));
        }
    }
@@ -184,29 +184,29 @@ struct TransformConvBwdWeightToGemmV2
        const auto KStride     = weights_strides[1];
        const auto XStride     = weights_strides[5];
        const auto BatchStride = weights_strides[0];
-        // Add NumBatchToMerge for Batch+M dimension and, 1 as a placeholder for Batch+N dimension
+        // Add NumGroupsToMerge for Batch+M dimension and, 1 as a placeholder for Batch+N dimension
        const auto desc =
-            make_naive_tensor_descriptor(make_tuple(NumBatchToMerge, K, Z * Y * X, 1, C),
+            make_naive_tensor_descriptor(make_tuple(NumGroupsToMerge, K, Z * Y * X, 1, C),
                                         make_tuple(BatchStride, KStride, XStride, BatchStride, CStride));
-        // Pad 1 to NumBatchToMerge
+        // Pad 1 to NumGroupsToMerge
        const auto padded_desc = transform_tensor_descriptor(
            desc,
-            make_tuple(make_pass_through_transform(NumBatchToMerge),
+            make_tuple(make_pass_through_transform(NumGroupsToMerge),
                       make_pass_through_transform(K),
                       make_pass_through_transform(Z * Y * X),
-                       make_pad_transform(1, 0, NumBatchToMerge - 1),
+                       make_pad_transform(1, 0, NumGroupsToMerge - 1),
                       make_pass_through_transform(C)),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
        // We need only the matrices on the diagonal. Xor returns 0 for equal values,
        // so a matrix that is not on the diagonal ends up stored in the padding.
        // To avoid use of modulo after xor we assume the number of merged groups is a power of 2.
-        static_assert(NumBatchToMerge == 1 || NumBatchToMerge == 2 || NumBatchToMerge == 4 ||
-                      NumBatchToMerge == 8 || NumBatchToMerge == 16 || NumBatchToMerge == 32 ||
-                      NumBatchToMerge == 64);
+        static_assert(NumGroupsToMerge == 1 || NumGroupsToMerge == 2 || NumGroupsToMerge == 4 ||
+                      NumGroupsToMerge == 8 || NumGroupsToMerge == 16 || NumGroupsToMerge == 32 ||
+                      NumGroupsToMerge == 64);
        const auto unmerged_padded_desc = transform_tensor_descriptor(
            padded_desc,
-            make_tuple(make_xor_transform(make_tuple(NumBatchToMerge, NumBatchToMerge)),
+            make_tuple(make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
                       make_pass_through_transform(K),
                       make_pass_through_transform(Z * Y * X),
                       make_pass_through_transform(C)),
@@ -215,8 +215,8 @@ struct TransformConvBwdWeightToGemmV2
        // Merge To M, N
        return transform_tensor_descriptor(
            unmerged_padded_desc,
-            make_tuple(make_merge_transform(make_tuple(NumBatchToMerge, K)),
-                       make_merge_transform(make_tuple(Z * Y * X, NumBatchToMerge, C))),
+            make_tuple(make_merge_transform(make_tuple(NumGroupsToMerge, K)),
+                       make_merge_transform(make_tuple(Z * Y * X, NumGroupsToMerge, C))),
            make_tuple(Sequence<0, 1>{}, Sequence<2, 3, 4>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));
    }
@@ -262,8 +262,8 @@ struct TransformConvBwdWeightToGemmV2
        const index_t InRightPadW = input_right_pads[1];
        const index_t GemmKTotal = N * Ho * Wo;
-        const index_t GemmM      = K * NumBatchToMerge;
-        const index_t GemmN      = C * X * Y * NumBatchToMerge;
+        const index_t GemmM      = K * NumGroupsToMerge;
+        const index_t GemmN      = C * X * Y * NumGroupsToMerge;
        const auto PadGemmM = MPerBlock - GemmM % MPerBlock;
        const auto PadGemmN = NPerBlock - GemmN % NPerBlock;
@@ -286,7 +286,7 @@ struct TransformConvBwdWeightToGemmV2
                out_grid_desc,
                make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
-                           make_merge_transform(make_tuple(NumBatchToMerge, GemmM / NumBatchToMerge))),
+                           make_merge_transform(make_tuple(NumGroupsToMerge, GemmM / NumGroupsToMerge))),
                make_tuple(Sequence<0>{}, Sequence<1, 2>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));
@@ -302,7 +302,7 @@ struct TransformConvBwdWeightToGemmV2
                in_grid_desc,
                make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
-                           make_merge_transform(make_tuple(NumBatchToMerge, GemmN / NumBatchToMerge))),
+                           make_merge_transform(make_tuple(NumGroupsToMerge, GemmN / NumGroupsToMerge))),
                make_tuple(Sequence<0>{}, Sequence<1, 2>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));
@@ -324,7 +324,7 @@ struct TransformConvBwdWeightToGemmV2
                out_grid_desc,
                make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
-                           make_merge_transform(make_tuple(NumBatchToMerge, GemmM / NumBatchToMerge))),
+                           make_merge_transform(make_tuple(NumGroupsToMerge, GemmM / NumGroupsToMerge))),
                make_tuple(Sequence<0>{}, Sequence<1, 2>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));
@@ -341,7 +341,7 @@ struct TransformConvBwdWeightToGemmV2
            make_tuple(make_pass_through_transform(N),
                       make_pad_transform(Hi, InLeftPadH, InRightPadH),
                       make_pad_transform(Wi, InLeftPadW, InRightPadW),
-                       make_pass_through_transform(NumBatchToMerge),
+                       make_pass_through_transform(NumGroupsToMerge),
                       make_pass_through_transform(C)),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
@@ -354,7 +354,7 @@ struct TransformConvBwdWeightToGemmV2
                       make_pass_through_transform(N),
                       make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
                       make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
-                       make_pass_through_transform(NumBatchToMerge),
+                       make_pass_through_transform(NumGroupsToMerge),
                       make_pass_through_transform(C)),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
@@ -366,7 +366,7 @@ struct TransformConvBwdWeightToGemmV2
        const auto in_gemmktotal_gemmn_grid_desc = transform_tensor_descriptor(
            in_n_y_ho_x_wo_c_grid_desc,
-            make_tuple(make_merge_transform(make_tuple(Y, X, NumBatchToMerge, C)),
+            make_tuple(make_merge_transform(make_tuple(Y, X, NumGroupsToMerge, C)),
                       make_merge_transform(make_tuple(N, Ho, Wo))),
            make_tuple(Sequence<1, 3, 5, 6>{}, Sequence<0, 2, 4>{}),
            make_tuple(Sequence<1>{}, Sequence<0>{}));
@@ -465,8 +465,8 @@ struct TransformConvBwdWeightToGemmV2
        const index_t InRightPadW = input_right_pads[2];
        const index_t GemmKTotal = N * Do * Ho * Wo;
-        const index_t GemmM      = K * NumBatchToMerge;
-        const index_t GemmN      = C * Z * X * Y * NumBatchToMerge;
+        const index_t GemmM      = K * NumGroupsToMerge;
+        const index_t GemmN      = C * Z * X * Y * NumGroupsToMerge;
        const auto PadGemmM = MPerBlock - GemmM % MPerBlock;
        const auto PadGemmN = NPerBlock - GemmN % NPerBlock;
@@ -489,7 +489,7 @@ struct TransformConvBwdWeightToGemmV2
                out_grid_desc,
                make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
-                           make_merge_transform(make_tuple(NumBatchToMerge, GemmM / NumBatchToMerge))),
+                           make_merge_transform(make_tuple(NumGroupsToMerge, GemmM / NumGroupsToMerge))),
                make_tuple(Sequence<0>{}, Sequence<1, 2>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));
@@ -505,7 +505,7 @@ struct TransformConvBwdWeightToGemmV2
                in_grid_desc,
                make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
-                           make_merge_transform(make_tuple(NumBatchToMerge, GemmN / NumBatchToMerge))),
+                           make_merge_transform(make_tuple(NumGroupsToMerge, GemmN / NumGroupsToMerge))),
                make_tuple(Sequence<0>{}, Sequence<1, 2>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));
@@ -527,7 +527,7 @@ struct TransformConvBwdWeightToGemmV2
                out_grid_desc,
                make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
-                           make_merge_transform(make_tuple(NumBatchToMerge, GemmM / NumBatchToMerge))),
+                           make_merge_transform(make_tuple(NumGroupsToMerge, GemmM / NumGroupsToMerge))),
                make_tuple(Sequence<0>{}, Sequence<1, 2>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));
@@ -545,7 +545,7 @@ struct TransformConvBwdWeightToGemmV2
                       make_pad_transform(Di, InLeftPadD, InRightPadD),
                       make_pad_transform(Hi, InLeftPadH, InRightPadH),
                       make_pad_transform(Wi, InLeftPadW, InRightPadW),
-                       make_pass_through_transform(NumBatchToMerge),
+                       make_pass_through_transform(NumGroupsToMerge),
                       make_pass_through_transform(C)),
            make_tuple(Sequence<0>{},
                       Sequence<1>{},
@@ -567,7 +567,7 @@ struct TransformConvBwdWeightToGemmV2
                       make_embed_transform(make_tuple(Z, Do), make_tuple(ConvDilationD, ConvStrideD)),
                       make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
                       make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
-                       make_pass_through_transform(NumBatchToMerge),
+                       make_pass_through_transform(NumGroupsToMerge),
                       make_pass_through_transform(C)),
            make_tuple(Sequence<0>{},
                       Sequence<1>{},
@@ -584,7 +584,7 @@ struct TransformConvBwdWeightToGemmV2
        const auto in_gemmktotal_gemmn_grid_desc = transform_tensor_descriptor(
            in_n_z_do_y_ho_x_wo_c_grid_desc,
-            make_tuple(make_merge_transform(make_tuple(Z, Y, X, NumBatchToMerge, C)),
+            make_tuple(make_merge_transform(make_tuple(Z, Y, X, NumGroupsToMerge, C)),
                       make_merge_transform(make_tuple(N, Do, Ho, Wo))),
            make_tuple(Sequence<1, 3, 5, 7, 8>{}, Sequence<0, 2, 4, 6>{}),
            make_tuple(Sequence<1>{}, Sequence<0>{}));
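The make_xor_transform hunks above rely on a simple integer property: i ^ j == 0 exactly when i == j, so only the diagonal (group, group) pairs land on the unpadded coordinate and everything off the diagonal falls into the padding added by make_pad_transform. A standalone illustration of that property (plain C++, independent of the CK descriptor machinery; NumGroups is a stand-in for NumGroupsToMerge):

// Illustration only: prints which (i, j) pairs xor to zero.
#include <cstdio>

int main()
{
    constexpr int NumGroups = 4; // must be a power of two, as the static_assert above enforces
    for(int i = 0; i < NumGroups; ++i)
        for(int j = 0; j < NumGroups; ++j)
            std::printf("(%d, %d) -> %d%s\n", i, j, i ^ j, (i ^ j) == 0 ? "  <- diagonal" : "");
    return 0;
}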
include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp
@@ -14,27 +14,26 @@
namespace ck {
namespace tensor_operation {

// function to be used on device, emulates std::accumulate
template <typename T, typename ForwardIterator, typename Size>
__host__ __device__ auto mult_accumulate_n(ForwardIterator first, Size count, T init)
{
    for(ForwardIterator x = first; x != first + count; x++)
    {
        init *= *x;
    }
    return init;
}
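mult_accumulate_n is a device-friendly stand-in for std::accumulate with std::multiplies. A small usage sketch (hypothetical, not part of the header; the lengths array and the helper function name are made up for illustration):

// Product of the two spatial extents of a g/n/c/h/w length array, i.e. 7 * 9 = 63.
#include <array>

inline ck::index_t spatial_size_example()
{
    const std::array<ck::index_t, 5> lengths = {1, 8, 64, 7, 9};
    return ck::tensor_operation::mult_accumulate_n(lengths.begin() + 3, 2, ck::index_t{1});
}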
template
<
index_t
NDimSpatial
,
device
::
ConvolutionForwardSpecialization
ConvForwardSpecialization
>
template
<
index_t
NDimSpatial
,
device
::
ConvolutionForwardSpecialization
ConvForwardSpecialization
,
bool
SplitN
=
false
,
typename
ADataType
=
float
,
typename
CDataType
=
float
,
index_t
NumGroupsToMerge
=
1
>
struct
TransformConvFwdToGemm
{
private:
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
long_index_t
calculate_element_space_size_impl
(
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
strides
,
index_t
i
)
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
constexpr
auto
I4
=
Number
<
4
>
{};
static
constexpr
auto
I5
=
Number
<
5
>
{};
template
<
typename
ConvDimsType
>
static
long_index_t
calculate_element_space_size_impl
(
const
ConvDimsType
&
lengths
,
const
ConvDimsType
&
strides
,
index_t
i
)
{
long_index_t
acc
=
1
;
for
(;
i
<
(
NDimSpatial
+
3
);
i
++
)
...
...
@@ -46,11 +45,11 @@ struct TransformConvFwdToGemm
return
acc
;
}
template
<
typename
ADataType
,
typename
CData
Type
>
static
index_t
GetSplitedNSize
(
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>
&
a_g_n_c_wis_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>
&
a_g_n_c_wis_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>
&
c_g_n_k_wos_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>
&
c_g_n_k_wos_strides
)
template
<
typename
ConvDims
Type
>
static
index_t
GetSplitedNSize
(
const
ConvDimsType
&
a_g_n_c_wis_lengths
,
const
ConvDimsType
&
a_g_n_c_wis_strides
,
const
ConvDimsType
&
c_g_n_k_wos_lengths
,
const
ConvDimsType
&
c_g_n_k_wos_strides
)
{
const
long_index_t
a_element_space_size
=
calculate_element_space_size_impl
(
a_g_n_c_wis_lengths
,
a_g_n_c_wis_strides
,
I1
);
...
...
@@ -96,640 +95,421 @@ struct TransformConvFwdToGemm
}
}
// TODO: implement ck::tensor_layout::convolution that describe packed/strided dimemsion as
// properties
template
<
typename
ALayout
,
typename
std
::
enable_if
<
NDimSpatial
==
1
&&
(
is_same_v
<
ALayout
,
tensor_layout
::
convolution
::
G_NW_C
>
||
is_same_v
<
ALayout
,
tensor_layout
::
convolution
::
NWGC
>
||
is_same_v
<
ALayout
,
tensor_layout
::
convolution
::
GNWC
>
),
bool
>::
type
=
false
>
static
auto
MakeADescriptor_M_K
(
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
/* b_g_k_c_xs_strides */
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
c_g_n_k_wos_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
/* c_g_n_k_wos_strides */
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_right_pads
,
const
index_t
N
)
public:
__host__
__device__
constexpr
TransformConvFwdToGemm
()
{}
template
<
typename
ConvDimsType
,
typename
ConvSpatialDimsType
,
index_t
NDim
=
NDimSpatial
,
typename
std
::
enable_if
<
NDim
==
1
,
bool
>
::
type
=
false
>
__host__
__device__
TransformConvFwdToGemm
(
const
ConvDimsType
&
a_g_n_c_wis_lengths
,
const
ConvDimsType
&
a_g_n_c_wis_strides
,
const
ConvDimsType
&
b_g_k_c_xs_lengths
,
const
ConvDimsType
&
b_g_k_c_xs_strides
,
const
ConvDimsType
&
c_g_n_k_wos_lengths
,
const
ConvDimsType
&
c_g_n_k_wos_strides
,
const
ConvSpatialDimsType
&
conv_filter_strides
,
const
ConvSpatialDimsType
&
conv_filter_dilations
,
const
ConvSpatialDimsType
&
input_left_pads
,
const
ConvSpatialDimsType
&
input_right_pads
)
:
Di_
{
I1
},
Hi_
{
I1
},
Wi_
{
a_g_n_c_wis_lengths
[
I3
]},
Do_
{
I1
},
Ho_
{
I1
},
Wo_
{
c_g_n_k_wos_lengths
[
I3
]},
Z_
{
I1
},
Y_
{
I1
},
X_
{
b_g_k_c_xs_lengths
[
I3
]},
K_
{
c_g_n_k_wos_lengths
[
I2
]},
C_
{
b_g_k_c_xs_lengths
[
I2
]},
DiStride_
{
I1
},
HiStride_
{
I1
},
WiStride_
{
a_g_n_c_wis_strides
[
I3
]},
WoStride_
{
c_g_n_k_wos_strides
[
I3
]},
XStride_
{
b_g_k_c_xs_strides
[
I3
]},
CStrideTensorA_
{
a_g_n_c_wis_strides
[
I2
]},
CStrideTensorB_
{
b_g_k_c_xs_strides
[
I2
]},
KStrideTensorB_
{
b_g_k_c_xs_strides
[
I1
]},
KStrideTensorC_
{
c_g_n_k_wos_strides
[
I2
]},
NStrideTensorA_
{
a_g_n_c_wis_strides
[
I1
]},
GStrideTensorA_
{
a_g_n_c_wis_strides
[
I0
]},
GStrideTensorB_
{
b_g_k_c_xs_strides
[
I0
]},
GStrideTensorC_
{
c_g_n_k_wos_strides
[
I0
]},
ConvStrideD_
{
I1
},
ConvStrideH_
{
I1
},
ConvStrideW_
{
conv_filter_strides
[
I0
]},
ConvDilationD_
{
I1
},
ConvDilationH_
{
I1
},
ConvDilationW_
{
conv_filter_dilations
[
I0
]},
InLeftPadD_
{
I0
},
InLeftPadH_
{
I0
},
InLeftPadW_
{
input_left_pads
[
I0
]},
InRightPadD_
{
I0
},
InRightPadH_
{
I0
},
InRightPadW_
{
input_right_pads
[
I0
]},
ZYX_
{
X_
}
{
const
index_t
C
=
a_g_n_c_wis_lengths
[
2
];
const
index_t
Wi
=
a_g_n_c_wis_lengths
[
3
];
const
index_t
Wo
=
c_g_n_k_wos_lengths
[
3
];
const
index_t
ConvStrideW
=
conv_filter_strides
[
0
];
if
constexpr
(
ConvForwardSpecialization
==
device
::
ConvolutionForwardSpecialization
::
Filter1x1Stride1Pad0
)
{
const
index_t
NHoWo
=
N
*
ck
::
accumulate_n
<
index_t
>
(
c_g_n_k_wos_lengths
.
begin
()
+
3
,
NDimSpatial
,
1
,
std
::
multiplies
<>
());
static_assert
(
is_same_v
<
ConvSpatialDimsType
,
std
::
array
<
index_t
,
NDimSpatial
>>
||
is_same_v
<
ConvSpatialDimsType
,
ck
::
Array
<
index_t
,
NDimSpatial
>>
);
static_assert
(
is_same_v
<
ConvDimsType
,
std
::
array
<
index_t
,
NDimSpatial
+
I3
>>
||
is_same_v
<
ConvDimsType
,
ck
::
Array
<
index_t
,
NDimSpatial
+
I3
>>
);
// This is different
const
index_t
WiStride
=
a_g_n_c_wis_strides
[
2
+
NDimSpatial
];
const
auto
CStride
=
I1
;
const
auto
in_gemmm_gemmk_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
NHoWo
,
C
),
make_tuple
(
WiStride
,
CStride
));
return
in_gemmm_gemmk_desc
;
}
else
if
constexpr
(
ConvForwardSpecialization
==
device
::
ConvolutionForwardSpecialization
::
Filter1x1Pad0
)
if
constexpr
(
SplitN
)
{
// This is different
const
index_t
NStride
=
a_g_n_c_wis_strides
[
1
];
const
index_t
WiStride
=
a_g_n_c_wis_strides
[
3
];
const
auto
CStride
=
I1
;
const
auto
in_n_wi_c_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
N
,
Wi
,
C
),
make_tuple
(
NStride
,
WiStride
,
CStride
));
const
auto
in_n_wo_c_desc
=
transform_tensor_descriptor
(
in_n_wi_c_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_embed_transform
(
make_tuple
(
Wo
),
make_tuple
(
ConvStrideW
)),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
const
auto
in_gemmm_gemmk_desc
=
transform_tensor_descriptor
(
in_n_wo_c_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
N
,
Wo
)),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
,
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
return
in_gemmm_gemmk_desc
;
N_
=
GetSplitedNSize
(
a_g_n_c_wis_lengths
,
a_g_n_c_wis_strides
,
c_g_n_k_wos_lengths
,
c_g_n_k_wos_strides
);
}
else
{
const
index_t
X
=
b_g_k_c_xs_lengths
[
3
];
const
index_t
ConvDilationW
=
conv_filter_dilations
[
0
];
const
index_t
InLeftPadW
=
input_left_pads
[
0
];
const
index_t
InRightPadW
=
input_right_pads
[
0
];
// This is different
const
index_t
NStride
=
a_g_n_c_wis_strides
[
1
];
const
index_t
WiStride
=
a_g_n_c_wis_strides
[
3
];
const
auto
CStride
=
I1
;
const
auto
in_n_wi_c_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
N
,
Wi
,
C
),
make_tuple
(
NStride
,
WiStride
,
CStride
));
const
auto
in_n_wip_c_desc
=
transform_tensor_descriptor
(
in_n_wi_c_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_pad_transform
(
Wi
,
InLeftPadW
,
InRightPadW
),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
const
auto
in_n_x_wo_c_desc
=
transform_tensor_descriptor
(
in_n_wip_c_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_embed_transform
(
make_tuple
(
X
,
Wo
),
make_tuple
(
ConvDilationW
,
ConvStrideW
)),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
>
{}));
const
auto
in_gemmm_gemmk_desc
=
transform_tensor_descriptor
(
in_n_x_wo_c_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
N
,
Wo
)),
make_merge_transform
(
make_tuple
(
X
,
C
))),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
,
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
return
in_gemmm_gemmk_desc
;
N_
=
c_g_n_k_wos_lengths
[
I1
];
}
NDoHoWo_
=
N_
*
Wo_
;
}
template
<
typename
ALayout
,
typename
std
::
enable_if
<
NDimSpatial
==
2
&&
(
is_same_v
<
ALayout
,
tensor_layout
::
convolution
::
G_NHW_C
>
||
is_same_v
<
ALayout
,
tensor_layout
::
convolution
::
NHWGC
>
||
is_same_v
<
ALayout
,
tensor_layout
::
convolution
::
GNHWC
>
),
bool
>::
type
=
false
>
static
auto
MakeADescriptor_M_K
(
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
/* b_g_k_c_xs_strides */
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
c_g_n_k_wos_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
/* c_g_n_k_wos_strides */
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_right_pads
,
const
index_t
N
)
template
<
typename
ConvDimsType
,
typename
ConvSpatialDimsType
,
index_t
NDim
=
NDimSpatial
,
typename
std
::
enable_if
<
NDim
==
2
,
bool
>
::
type
=
false
>
__host__
__device__
TransformConvFwdToGemm
(
const
ConvDimsType
&
a_g_n_c_wis_lengths
,
const
ConvDimsType
&
a_g_n_c_wis_strides
,
const
ConvDimsType
&
b_g_k_c_xs_lengths
,
const
ConvDimsType
&
b_g_k_c_xs_strides
,
const
ConvDimsType
&
c_g_n_k_wos_lengths
,
const
ConvDimsType
&
c_g_n_k_wos_strides
,
const
ConvSpatialDimsType
&
conv_filter_strides
,
const
ConvSpatialDimsType
&
conv_filter_dilations
,
const
ConvSpatialDimsType
&
input_left_pads
,
const
ConvSpatialDimsType
&
input_right_pads
)
:
Di_
{
I1
},
Hi_
{
a_g_n_c_wis_lengths
[
I3
]},
Wi_
{
a_g_n_c_wis_lengths
[
I4
]},
Do_
{
I1
},
Ho_
{
c_g_n_k_wos_lengths
[
I3
]},
Wo_
{
c_g_n_k_wos_lengths
[
I4
]},
Z_
{
I1
},
Y_
{
b_g_k_c_xs_lengths
[
I3
]},
X_
{
b_g_k_c_xs_lengths
[
I4
]},
K_
{
c_g_n_k_wos_lengths
[
I2
]},
C_
{
b_g_k_c_xs_lengths
[
I2
]},
DiStride_
{
I1
},
HiStride_
{
a_g_n_c_wis_strides
[
I3
]},
WiStride_
{
a_g_n_c_wis_strides
[
I4
]},
WoStride_
{
c_g_n_k_wos_strides
[
I4
]},
XStride_
{
b_g_k_c_xs_strides
[
I4
]},
CStrideTensorA_
{
a_g_n_c_wis_strides
[
I2
]},
CStrideTensorB_
{
b_g_k_c_xs_strides
[
I2
]},
KStrideTensorB_
{
b_g_k_c_xs_strides
[
I1
]},
KStrideTensorC_
{
c_g_n_k_wos_strides
[
I2
]},
NStrideTensorA_
{
a_g_n_c_wis_strides
[
I1
]},
GStrideTensorA_
{
a_g_n_c_wis_strides
[
I0
]},
GStrideTensorB_
{
b_g_k_c_xs_strides
[
I0
]},
GStrideTensorC_
{
c_g_n_k_wos_strides
[
I0
]},
ConvStrideD_
{
I1
},
ConvStrideH_
{
conv_filter_strides
[
I0
]},
ConvStrideW_
{
conv_filter_strides
[
I1
]},
ConvDilationD_
{
I1
},
ConvDilationH_
{
conv_filter_dilations
[
I0
]},
ConvDilationW_
{
conv_filter_dilations
[
I1
]},
InLeftPadD_
{
I0
},
InLeftPadH_
{
input_left_pads
[
I0
]},
InLeftPadW_
{
input_left_pads
[
I1
]},
InRightPadD_
{
I0
},
InRightPadH_
{
input_right_pads
[
I0
]},
InRightPadW_
{
input_right_pads
[
I1
]},
ZYX_
{
Y_
*
X_
}
{
const
index_t
C
=
a_g_n_c_wis_lengths
[
2
];
const
index_t
Hi
=
a_g_n_c_wis_lengths
[
3
];
const
index_t
Wi
=
a_g_n_c_wis_lengths
[
4
];
const
index_t
Ho
=
c_g_n_k_wos_lengths
[
3
];
const
index_t
Wo
=
c_g_n_k_wos_lengths
[
4
];
const
index_t
ConvStrideH
=
conv_filter_strides
[
0
];
const
index_t
ConvStrideW
=
conv_filter_strides
[
1
];
if
constexpr
(
ConvForwardSpecialization
==
device
::
ConvolutionForwardSpecialization
::
Filter1x1Stride1Pad0
)
{
const
index_t
NHoWo
=
N
*
ck
::
accumulate_n
<
index_t
>
(
c_g_n_k_wos_lengths
.
begin
()
+
3
,
NDimSpatial
,
1
,
std
::
multiplies
<>
());
// This is different
const
index_t
WiStride
=
a_g_n_c_wis_strides
[
2
+
NDimSpatial
];
const
auto
CStride
=
I1
;
const
auto
in_gemmm_gemmk_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
NHoWo
,
C
),
make_tuple
(
WiStride
,
CStride
));
static_assert
(
is_same_v
<
ConvSpatialDimsType
,
std
::
array
<
index_t
,
NDimSpatial
>>
||
is_same_v
<
ConvSpatialDimsType
,
ck
::
Array
<
index_t
,
NDimSpatial
>>
);
static_assert
(
is_same_v
<
ConvDimsType
,
std
::
array
<
index_t
,
NDimSpatial
+
I3
>>
||
is_same_v
<
ConvDimsType
,
ck
::
Array
<
index_t
,
NDimSpatial
+
I3
>>
);
return
in_gemmm_gemmk_desc
;
}
else
if
constexpr
(
ConvForwardSpecialization
==
device
::
ConvolutionForwardSpecialization
::
Filter1x1Pad0
)
if
constexpr
(
SplitN
)
{
// This is different
const
index_t
NStride
=
a_g_n_c_wis_strides
[
1
];
const
index_t
HiStride
=
a_g_n_c_wis_strides
[
3
];
const
index_t
WiStride
=
a_g_n_c_wis_strides
[
4
];
const
auto
CStride
=
I1
;
const
auto
in_n_hi_wi_c_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
N
,
Hi
,
Wi
,
C
),
make_tuple
(
NStride
,
HiStride
,
WiStride
,
CStride
));
const
auto
in_n_ho_wo_c_desc
=
transform_tensor_descriptor
(
in_n_hi_wi_c_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_embed_transform
(
make_tuple
(
Ho
),
make_tuple
(
ConvStrideH
)),
make_embed_transform
(
make_tuple
(
Wo
),
make_tuple
(
ConvStrideW
)),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
const
auto
in_gemmm_gemmk_desc
=
transform_tensor_descriptor
(
in_n_ho_wo_c_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
N
,
Ho
,
Wo
)),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
return
in_gemmm_gemmk_desc
;
N_
=
GetSplitedNSize
(
a_g_n_c_wis_lengths
,
a_g_n_c_wis_strides
,
c_g_n_k_wos_lengths
,
c_g_n_k_wos_strides
);
}
else
{
const
index_t
Y
=
b_g_k_c_xs_lengths
[
3
];
const
index_t
X
=
b_g_k_c_xs_lengths
[
4
];
const
index_t
ConvDilationH
=
conv_filter_dilations
[
0
];
const
index_t
ConvDilationW
=
conv_filter_dilations
[
1
];
const
index_t
InLeftPadH
=
input_left_pads
[
0
];
const
index_t
InLeftPadW
=
input_left_pads
[
1
];
const
index_t
InRightPadH
=
input_right_pads
[
0
];
const
index_t
InRightPadW
=
input_right_pads
[
1
];
// This is different
const
index_t
NStride
=
a_g_n_c_wis_strides
[
1
];
const
index_t
HiStride
=
a_g_n_c_wis_strides
[
3
];
const
index_t
WiStride
=
a_g_n_c_wis_strides
[
4
];
const
auto
CStride
=
I1
;
const
auto
in_n_hi_wi_c_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
N
,
Hi
,
Wi
,
C
),
make_tuple
(
NStride
,
HiStride
,
WiStride
,
CStride
));
const
auto
in_n_hip_wip_c_desc
=
transform_tensor_descriptor
(
in_n_hi_wi_c_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_pad_transform
(
Hi
,
InLeftPadH
,
InRightPadH
),
make_pad_transform
(
Wi
,
InLeftPadW
,
InRightPadW
),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
const
auto
in_n_y_ho_x_wo_c_desc
=
transform_tensor_descriptor
(
in_n_hip_wip_c_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_embed_transform
(
make_tuple
(
Y
,
Ho
),
make_tuple
(
ConvDilationH
,
ConvStrideH
)),
make_embed_transform
(
make_tuple
(
X
,
Wo
),
make_tuple
(
ConvDilationW
,
ConvStrideW
)),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
,
4
>
{},
Sequence
<
5
>
{}));
const
auto
in_gemmm_gemmk_desc
=
transform_tensor_descriptor
(
in_n_y_ho_x_wo_c_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
N
,
Ho
,
Wo
)),
make_merge_transform
(
make_tuple
(
Y
,
X
,
C
))),
make_tuple
(
Sequence
<
0
,
2
,
4
>
{},
Sequence
<
1
,
3
,
5
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
return
in_gemmm_gemmk_desc
;
N_
=
c_g_n_k_wos_lengths
[
I1
];
}
NDoHoWo_
=
N_
*
Ho_
*
Wo_
;
}
template
<
typename
ALayout
,
typename
std
::
enable_if
<
NDimSpatial
==
3
&&
(
is_same_v
<
ALayout
,
tensor_layout
::
convolution
::
G_NDHW_C
>
||
is_same_v
<
ALayout
,
tensor_layout
::
convolution
::
NDHWGC
>
||
is_same_v
<
ALayout
,
tensor_layout
::
convolution
::
GNDHWC
>
),
bool
>::
type
=
false
>
static
auto
MakeADescriptor_M_K
(
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
/* b_g_k_c_xs_strides */
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
c_g_n_k_wos_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
/* c_g_n_k_wos_strides*/
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_right_pads
,
const
index_t
N
)
template
<
typename
ConvDimsType
,
typename
ConvSpatialDimsType
,
index_t
NDim
=
NDimSpatial
,
typename
std
::
enable_if
<
NDim
==
3
,
bool
>
::
type
=
false
>
__host__
__device__
TransformConvFwdToGemm
(
const
ConvDimsType
&
a_g_n_c_wis_lengths
,
const
ConvDimsType
&
a_g_n_c_wis_strides
,
const
ConvDimsType
&
b_g_k_c_xs_lengths
,
const
ConvDimsType
&
b_g_k_c_xs_strides
,
const
ConvDimsType
&
c_g_n_k_wos_lengths
,
const
ConvDimsType
&
c_g_n_k_wos_strides
,
const
ConvSpatialDimsType
&
conv_filter_strides
,
const
ConvSpatialDimsType
&
conv_filter_dilations
,
const
ConvSpatialDimsType
&
input_left_pads
,
const
ConvSpatialDimsType
&
input_right_pads
)
:
Di_
{
a_g_n_c_wis_lengths
[
I3
]},
Hi_
{
a_g_n_c_wis_lengths
[
I4
]},
Wi_
{
a_g_n_c_wis_lengths
[
I5
]},
Do_
{
c_g_n_k_wos_lengths
[
I3
]},
Ho_
{
c_g_n_k_wos_lengths
[
I4
]},
Wo_
{
c_g_n_k_wos_lengths
[
I5
]},
Z_
{
b_g_k_c_xs_lengths
[
I3
]},
Y_
{
b_g_k_c_xs_lengths
[
I4
]},
X_
{
b_g_k_c_xs_lengths
[
I5
]},
K_
{
c_g_n_k_wos_lengths
[
I2
]},
C_
{
b_g_k_c_xs_lengths
[
I2
]},
DiStride_
{
a_g_n_c_wis_strides
[
I3
]},
HiStride_
{
a_g_n_c_wis_strides
[
I4
]},
WiStride_
{
a_g_n_c_wis_strides
[
I5
]},
WoStride_
{
c_g_n_k_wos_strides
[
I5
]},
XStride_
{
b_g_k_c_xs_strides
[
I5
]},
CStrideTensorA_
{
a_g_n_c_wis_strides
[
I2
]},
CStrideTensorB_
{
b_g_k_c_xs_strides
[
I2
]},
KStrideTensorB_
{
b_g_k_c_xs_strides
[
I1
]},
KStrideTensorC_
{
c_g_n_k_wos_strides
[
I2
]},
NStrideTensorA_
{
a_g_n_c_wis_strides
[
I1
]},
GStrideTensorA_
{
a_g_n_c_wis_strides
[
I0
]},
GStrideTensorB_
{
b_g_k_c_xs_strides
[
I0
]},
GStrideTensorC_
{
c_g_n_k_wos_strides
[
I0
]},
ConvStrideD_
{
conv_filter_strides
[
I0
]},
ConvStrideH_
{
conv_filter_strides
[
I1
]},
ConvStrideW_
{
conv_filter_strides
[
I2
]},
ConvDilationD_
{
conv_filter_dilations
[
I0
]},
ConvDilationH_
{
conv_filter_dilations
[
I1
]},
ConvDilationW_
{
conv_filter_dilations
[
I2
]},
InLeftPadD_
{
input_left_pads
[
I0
]},
      InLeftPadH_{input_left_pads[I1]},
      InLeftPadW_{input_left_pads[I2]},
      InRightPadD_{input_right_pads[I0]},
      InRightPadH_{input_right_pads[I1]},
      InRightPadW_{input_right_pads[I2]},
      ZYX_{Z_ * Y_ * X_}
{
    const index_t C  = a_g_n_c_wis_lengths[2];
    const index_t Di = a_g_n_c_wis_lengths[3];
    const index_t Hi = a_g_n_c_wis_lengths[4];
    const index_t Wi = a_g_n_c_wis_lengths[5];
    static_assert(is_same_v<ConvSpatialDimsType, std::array<index_t, NDimSpatial>> ||
                  is_same_v<ConvSpatialDimsType, ck::Array<index_t, NDimSpatial>>);
    static_assert(is_same_v<ConvDimsType, std::array<index_t, NDimSpatial + I3>> ||
                  is_same_v<ConvDimsType, ck::Array<index_t, NDimSpatial + I3>>);
    const index_t Do = c_g_n_k_wos_lengths[3];
    const index_t Ho = c_g_n_k_wos_lengths[4];
    const index_t Wo = c_g_n_k_wos_lengths[5];
    const index_t ConvStrideD = conv_filter_strides[0];
    const index_t ConvStrideH = conv_filter_strides[1];
    const index_t ConvStrideW = conv_filter_strides[2];
    if constexpr(ConvForwardSpecialization ==
                 device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
    {
        const index_t NDoHoWo = N * ck::accumulate_n<index_t>(
            c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());
        // This is different
        const index_t WiStride = a_g_n_c_wis_strides[2 + NDimSpatial];
        const auto CStride     = I1;
        const auto in_gemmm_gemmk_desc =
            make_naive_tensor_descriptor(make_tuple(NDoHoWo, C), make_tuple(WiStride, CStride));
        return in_gemmm_gemmk_desc;
    }
    else if constexpr(ConvForwardSpecialization ==
                      device::ConvolutionForwardSpecialization::Filter1x1Pad0)
    if constexpr(SplitN)
    {
        // This is different
        const index_t NStride  = a_g_n_c_wis_strides[1];
        const index_t DiStride = a_g_n_c_wis_strides[3];
        const index_t HiStride = a_g_n_c_wis_strides[4];
        const index_t WiStride = a_g_n_c_wis_strides[5];
        const auto CStride     = I1;
        const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor(
            make_tuple(N, Di, Hi, Wi, C),
            make_tuple(NStride, DiStride, HiStride, WiStride, CStride));
        const auto in_n_do_ho_wo_c_desc = transform_tensor_descriptor(
            in_n_di_hi_wi_c_desc,
            make_tuple(make_pass_through_transform(N),
                       make_embed_transform(make_tuple(Do), make_tuple(ConvStrideD)),
                       make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)),
                       make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
                       make_pass_through_transform(C)),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
        const auto in_gemmm_gemmk_desc = transform_tensor_descriptor(
            in_n_do_ho_wo_c_desc,
            make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)),
                       make_pass_through_transform(C)),
            make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));
        return in_gemmm_gemmk_desc;
        N_ = GetSplitedNSize(
            a_g_n_c_wis_lengths, a_g_n_c_wis_strides, c_g_n_k_wos_lengths, c_g_n_k_wos_strides);
    }
    else
    {
        const index_t Z = b_g_k_c_xs_lengths[3];
        const index_t Y = b_g_k_c_xs_lengths[4];
        const index_t X = b_g_k_c_xs_lengths[5];
        const index_t ConvDilationD = conv_filter_dilations[0];
        const index_t ConvDilationH = conv_filter_dilations[1];
        const index_t ConvDilationW = conv_filter_dilations[2];
        const index_t InLeftPadD  = input_left_pads[0];
        const index_t InLeftPadH  = input_left_pads[1];
        const index_t InLeftPadW  = input_left_pads[2];
        const index_t InRightPadD = input_right_pads[0];
        const index_t InRightPadH = input_right_pads[1];
        const index_t InRightPadW = input_right_pads[2];
        // This is different
        const index_t NStride  = a_g_n_c_wis_strides[1];
        const index_t DiStride = a_g_n_c_wis_strides[3];
        const index_t HiStride = a_g_n_c_wis_strides[4];
        const index_t WiStride = a_g_n_c_wis_strides[5];
        const auto CStride     = I1;
        const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor(
            make_tuple(N, Di, Hi, Wi, C),
            make_tuple(NStride, DiStride, HiStride, WiStride, CStride));
        const auto in_n_hip_wip_c_desc = transform_tensor_descriptor(
            in_n_di_hi_wi_c_desc,
            make_tuple(make_pass_through_transform(N),
                       make_pad_transform(Di, InLeftPadD, InRightPadD),
                       make_pad_transform(Hi, InLeftPadH, InRightPadH),
                       make_pad_transform(Wi, InLeftPadW, InRightPadW),
                       make_pass_through_transform(C)),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
        const auto in_n_z_do_y_ho_x_wo_c_desc = transform_tensor_descriptor(
            in_n_hip_wip_c_desc,
            make_tuple(make_pass_through_transform(N),
                       make_embed_transform(make_tuple(Z, Do), make_tuple(ConvDilationD, ConvStrideD)),
                       make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
                       make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
                       make_pass_through_transform(C)),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
            make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5, 6>{}, Sequence<7>{}));
        const auto in_gemmm_gemmk_desc = transform_tensor_descriptor(
            in_n_z_do_y_ho_x_wo_c_desc,
            make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)),
                       make_merge_transform(make_tuple(Z, Y, X, C))),
            make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));
        return in_gemmm_gemmk_desc;
        N_ = c_g_n_k_wos_lengths[I1];
    }
    NDoHoWo_ = N_ * Do_ * Ho_ * Wo_;
}
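Editor's note: the pad -> embed -> merge chain above is the implicit im2col view of the convolution input. A minimal, self-contained sketch of the index arithmetic it encodes (2-D case for brevity; hypothetical names, not part of this commit, and not using the CK descriptor API):
#include <vector>
// Element (m, k) of the implicit GEMM A matrix, where
//   m = (n * Ho + ho) * Wo + wo   and   k = (y * X + x) * C + c.
// Reads that fall into the pad region return zero, mirroring the pad transform.
float im2col_a(const std::vector<float>& in, // NHWC input
               int Hi, int Wi, int C, int Ho, int Wo, int Y, int X,
               int stride_h, int stride_w, int dil_h, int dil_w,
               int pad_h, int pad_w, long m, long k)
{
    const int c  = k % C;
    const int x  = (k / C) % X;
    const int y  = k / (C * X);
    const int wo = m % Wo;
    const int ho = (m / Wo) % Ho;
    const int n  = m / (static_cast<long>(Wo) * Ho);
    const int hi = ho * stride_h + y * dil_h - pad_h; // embed + pad transforms
    const int wi = wo * stride_w + x * dil_w - pad_w;
    if(hi < 0 || hi >= Hi || wi < 0 || wi >= Wi)
        return 0.f; // coordinate lands in the padding
    return in[((static_cast<long>(n) * Hi + hi) * Wi + wi) * C + c];
}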
template <typename BLayout,
          typename std::enable_if<is_same_v<BLayout, tensor_layout::convolution::GKXC> ||
                                      is_same_v<BLayout, tensor_layout::convolution::GKYXC> ||
                                      is_same_v<BLayout, tensor_layout::convolution::GKZYXC>,
                                  bool>::type = false>
static auto MakeBDescriptor_N_K(const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
                                const std::array<index_t, NDimSpatial + 3>& /* b_g_k_c_xs_strides */)
{
    const index_t K  = b_g_k_c_xs_lengths[1];
    const index_t C  = b_g_k_c_xs_lengths[2];
    const index_t YX = ck::accumulate_n<index_t>(
        b_g_k_c_xs_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());
    const auto wei_gemmn_gemmk_desc =
        make_naive_tensor_descriptor_packed(make_tuple(K, YX * C));
    return wei_gemmn_gemmk_desc;
}
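Editor's note: for the packed GKYXC-style layouts the weight GEMM matrix is just the filter tensor of one group viewed as K rows of Y*X*C contiguous elements. A small sketch of the equivalent offset computation (hypothetical helper, illustration only):
#include <cassert>
// Offset of weight element (k, y, x, c) inside one group of a packed K-Y-X-C
// block, and the matching (row, col) in the [K, Y*X*C] GEMM view.
long wei_offset(int K, int Y, int X, int C, int k, int y, int x, int c)
{
    const long col = (static_cast<long>(y) * X + x) * C + c; // GEMM K index
    const long row = k;                                      // GEMM N index
    assert(row < K && col < static_cast<long>(Y) * X * C);
    return row * (static_cast<long>(Y) * X * C) + col;       // packed layout
}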
template <typename BLayout,
          typename std::enable_if<is_same_v<BLayout, tensor_layout::convolution::G_K_X_C> ||
                                      is_same_v<BLayout, tensor_layout::convolution::G_K_YX_C> ||
                                      is_same_v<BLayout, tensor_layout::convolution::G_K_ZYX_C> ||
                                      is_same_v<BLayout, tensor_layout::convolution::KXGC> ||
                                      is_same_v<BLayout, tensor_layout::convolution::KYXGC> ||
                                      is_same_v<BLayout, tensor_layout::convolution::KZYXGC>,
                                  bool>::type = false>
static auto MakeBDescriptor_N_K(const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
                                const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides)
{
    const index_t K  = b_g_k_c_xs_lengths[1];
    const index_t C  = b_g_k_c_xs_lengths[2];
    const index_t YX = ck::accumulate_n<index_t>(
        b_g_k_c_xs_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());
    const index_t KStride = b_g_k_c_xs_strides[1];
    const index_t XStride = b_g_k_c_xs_strides[2 + NDimSpatial];
    const auto CStride    = I1;
    const auto wei_k_yx_c_desc = make_naive_tensor_descriptor(
        make_tuple(K, YX, C), make_tuple(KStride, XStride, CStride));
    const auto wei_gemmn_gemmk_desc = transform_tensor_descriptor(
        wei_k_yx_c_desc,
        make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(YX, C))),
        make_tuple(Sequence<0>{}, Sequence<1, 2>{}),
        make_tuple(Sequence<0>{}, Sequence<1>{}));
    return wei_gemmn_gemmk_desc;
}
template <typename CLayout,
          typename std::enable_if<is_same_v<CLayout, tensor_layout::convolution::GNWK> ||
                                      is_same_v<CLayout, tensor_layout::convolution::GNHWK> ||
                                      is_same_v<CLayout, tensor_layout::convolution::GNDHWK>,
                                  bool>::type = false>
static auto MakeCDescriptor_M_N(const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
                                const std::array<index_t, NDimSpatial + 3>& /* c_g_n_k_wos_strides */,
                                const index_t N)
{
    const index_t K     = c_g_n_k_wos_lengths[2];
    const index_t NHoWo = N * ck::accumulate_n<index_t>(
        c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());
    const auto out_gemmm_gemmn_desc =
        make_naive_tensor_descriptor_packed(make_tuple(NHoWo, K));
    return out_gemmm_gemmn_desc;
}
template <typename CLayout,
          typename std::enable_if<is_same_v<CLayout, tensor_layout::convolution::G_NW_K> ||
                                      is_same_v<CLayout, tensor_layout::convolution::G_NHW_K> ||
                                      is_same_v<CLayout, tensor_layout::convolution::G_NDHW_K> ||
                                      is_same_v<CLayout, tensor_layout::convolution::NWGK> ||
                                      is_same_v<CLayout, tensor_layout::convolution::NHWGK> ||
                                      is_same_v<CLayout, tensor_layout::convolution::NDHWGK>,
                                  bool>::type = false>
static auto MakeCDescriptor_M_N(const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
                                const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_strides,
                                const index_t N)
{
    const index_t K       = c_g_n_k_wos_lengths[2];
    const auto KStride    = I1;
    const index_t WoStride = c_g_n_k_wos_strides[NDimSpatial + 2];
    const index_t NHoWo   = N * ck::accumulate_n<index_t>(
        c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());
    const auto out_gemmm_gemmn_desc =
        make_naive_tensor_descriptor(make_tuple(NHoWo, K), make_tuple(WoStride, KStride));
    return out_gemmm_gemmn_desc;
}
// for output bias
template <typename CLayout,
          typename std::enable_if<is_same_v<CLayout, tensor_layout::convolution::G_K>,
                                  bool>::type = false>
static auto MakeCDescriptor_M_N(const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
                                const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_strides,
                                const index_t N)
{
    const index_t K       = c_g_n_k_wos_lengths[2];
    const index_t KStride = c_g_n_k_wos_strides[2];
    const index_t NHoWo   = N * ck::accumulate_n<index_t>(
        c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());
    const auto out_gemmm_gemmn_desc =
        make_naive_tensor_descriptor(make_tuple(NHoWo, K), make_tuple(I0, KStride));
    return out_gemmm_gemmn_desc;
}
// Overloaded functions for hipRTC purposes
// TODO: implement ck::tensor_layout::convolution that describes packed/strided dimensions as
// properties
template <typename ALayout,
          typename std::enable_if<NDimSpatial == 1 &&
                                      (is_same_v<ALayout, tensor_layout::convolution::G_NW_C> ||
                                       is_same_v<ALayout, tensor_layout::convolution::NWGC> ||
                                       is_same_v<ALayout, tensor_layout::convolution::GNWC>),
                                  bool>::type = false>
__host__ __device__ static auto
MakeADescriptor_M_K(const ck::Array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
                    const ck::Array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
                    const ck::Array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
                    const ck::Array<index_t, NDimSpatial + 3>& /* b_g_k_c_xs_strides */,
                    const ck::Array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
                    const ck::Array<index_t, NDimSpatial + 3>& /* c_g_n_k_wos_strides */,
                    const ck::Array<index_t, NDimSpatial>& conv_filter_strides,
                    const ck::Array<index_t, NDimSpatial>& conv_filter_dilations,
                    const ck::Array<index_t, NDimSpatial>& input_left_pads,
                    const ck::Array<index_t, NDimSpatial>& input_right_pads)
__host__ __device__ auto MakeADescriptor_M_K() const
{
    const index_t N  = a_g_n_c_wis_lengths[1];
    const index_t C  = a_g_n_c_wis_lengths[2];
    const index_t Wi = a_g_n_c_wis_lengths[3];
    const index_t Wo = c_g_n_k_wos_lengths[3];
    const index_t ConvStrideW = conv_filter_strides[0];
    if constexpr(ConvForwardSpecialization ==
                 device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
    {
        const index_t NHoWo = N * ck::accumulate_n<index_t>(
            c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());
        // This is different
        const index_t WiStride = a_g_n_c_wis_strides[2 + NDimSpatial];
        const auto CStride     = I1;
        const auto in_gemmm_gemmk_desc =
            make_naive_tensor_descriptor(make_tuple(NHoWo, C), make_tuple(WiStride, CStride));
        if constexpr(NumGroupsToMerge == 1)
        {
            return make_naive_tensor_descriptor(make_tuple(NDoHoWo_, C_),
                                                make_tuple(WiStride_, CStrideTensorA_));
        }
        else
        {
            const auto in_gemmm_groups_gemmk_desc = make_naive_tensor_descriptor(
                make_tuple(NDoHoWo_, NumGroupsToMerge, C_),
                make_tuple(WiStride_, GStrideTensorA_, CStrideTensorA_));
            return transform_tensor_descriptor(
                in_gemmm_groups_gemmk_desc,
                make_tuple(make_merge_transform(make_tuple(NDoHoWo_, NumGroupsToMerge)),
                           make_pass_through_transform(C_)),
                make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));
        }
    }
    else if constexpr(ConvForwardSpecialization ==
                      device::ConvolutionForwardSpecialization::Filter3x3)
    {
        if constexpr(NumGroupsToMerge == 1)
        {
            return in_gemmm_gemmk_desc;
            const auto in_n_wi_c_desc = make_naive_tensor_descriptor(
                make_tuple(N_, Wi_), make_tuple(NStrideTensorA_, WiStride_));
            const auto in_n_wip_c_desc = transform_tensor_descriptor(
                in_n_wi_c_desc,
                make_tuple(make_pass_through_transform(N_),
                           make_pad_transform(Wi_, InLeftPadW_, InRightPadW_)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));
            const auto in_n_x_wo_c_desc = transform_tensor_descriptor(
                in_n_wip_c_desc,
                make_tuple(make_pass_through_transform(N_),
                           make_embed_transform(make_tuple(Number<3>{}, Wo_),
                                                make_tuple(ConvDilationW_, ConvStrideW_))),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0>{}, Sequence<1, 2>{}));
            return transform_tensor_descriptor(
                in_n_x_wo_c_desc,
                make_tuple(make_merge_transform(make_tuple(N_, Wo_)),
                           make_pass_through_transform(Number<3>{})),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));
        }
        else
        {
            const auto in_n_wi_c_desc = make_naive_tensor_descriptor(
                make_tuple(N_, Wi_, NumGroupsToMerge),
                make_tuple(NStrideTensorA_, WiStride_, GStrideTensorA_));
            const auto in_n_wip_c_desc = transform_tensor_descriptor(
                in_n_wi_c_desc,
                make_tuple(make_pass_through_transform(N_),
                           make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
                           make_pass_through_transform(NumGroupsToMerge)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
            const auto in_n_x_wo_c_desc = transform_tensor_descriptor(
                in_n_wip_c_desc,
                make_tuple(make_pass_through_transform(N_),
                           make_embed_transform(make_tuple(Number<3>{}, Wo_),
                                                make_tuple(ConvDilationW_, ConvStrideW_)),
                           make_pass_through_transform(NumGroupsToMerge)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
                make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
            return transform_tensor_descriptor(
                in_n_x_wo_c_desc,
                make_tuple(make_merge_transform(make_tuple(N_, Wo_, NumGroupsToMerge)),
                           make_pass_through_transform(Number<3>{})),
                make_tuple(Sequence<0, 2, 3>{}, Sequence<1>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));
        }
    }
    else if constexpr(ConvForwardSpecialization ==
                      device::ConvolutionForwardSpecialization::Filter1x1Pad0)
    {
        // This is different
        const index_t NStride  = a_g_n_c_wis_strides[1];
        const index_t WiStride = a_g_n_c_wis_strides[3];
        const auto CStride     = I1;
        const auto in_n_wi_c_desc = make_naive_tensor_descriptor(
            make_tuple(N, Wi, C), make_tuple(NStride, WiStride, CStride));
        const auto in_n_wo_c_desc = transform_tensor_descriptor(
            in_n_wi_c_desc,
            make_tuple(make_pass_through_transform(N),
                       make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
                       make_pass_through_transform(C)),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
        const auto in_gemmm_gemmk_desc = transform_tensor_descriptor(
            in_n_wo_c_desc,
            make_tuple(make_merge_transform(make_tuple(N, Wo)), make_pass_through_transform(C)),
            make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));
        return in_gemmm_gemmk_desc;
        if constexpr(NumGroupsToMerge == 1)
        {
            const auto in_n_wi_c_desc = make_naive_tensor_descriptor(
                make_tuple(N_, Wi_, C_),
                make_tuple(NStrideTensorA_, WiStride_, CStrideTensorA_));
            const auto in_n_wo_c_desc = transform_tensor_descriptor(
                in_n_wi_c_desc,
                make_tuple(make_pass_through_transform(N_),
                           make_embed_transform(make_tuple(Wo_), make_tuple(ConvStrideW_)),
                           make_pass_through_transform(C_)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
            return transform_tensor_descriptor(
                in_n_wo_c_desc,
                make_tuple(make_merge_transform(make_tuple(N_, Wo_)),
                           make_pass_through_transform(C_)),
                make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));
        }
        else
        {
            const auto in_n_wi_c_desc = make_naive_tensor_descriptor(
                make_tuple(N_, Wi_, NumGroupsToMerge, C_),
                make_tuple(NStrideTensorA_, WiStride_, GStrideTensorA_, CStrideTensorA_));
            const auto in_n_wo_c_desc = transform_tensor_descriptor(
                in_n_wi_c_desc,
                make_tuple(make_pass_through_transform(N_),
                           make_embed_transform(make_tuple(Wo_), make_tuple(ConvStrideW_)),
                           make_pass_through_transform(NumGroupsToMerge),
                           make_pass_through_transform(C_)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
            return transform_tensor_descriptor(
                in_n_wo_c_desc,
                make_tuple(make_merge_transform(make_tuple(N_, Wo_, NumGroupsToMerge)),
                           make_pass_through_transform(C_)),
                make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));
        }
    }
    else
    {
        const index_t X = b_g_k_c_xs_lengths[3];
        const index_t ConvDilationW = conv_filter_dilations[0];
        const index_t InLeftPadW    = input_left_pads[0];
        const index_t InRightPadW   = input_right_pads[0];
        // This is different
        const index_t NStride  = a_g_n_c_wis_strides[1];
        const index_t WiStride = a_g_n_c_wis_strides[3];
        const auto CStride     = I1;
        const auto in_n_wi_c_desc = make_naive_tensor_descriptor(
            make_tuple(N, Wi, C), make_tuple(NStride, WiStride, CStride));
        const auto in_n_wip_c_desc = transform_tensor_descriptor(
            in_n_wi_c_desc,
            make_tuple(make_pass_through_transform(N),
                       make_pad_transform(Wi, InLeftPadW, InRightPadW),
                       make_pass_through_transform(C)),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
        const auto in_n_x_wo_c_desc = transform_tensor_descriptor(
            in_n_wip_c_desc,
            make_tuple(make_pass_through_transform(N),
                       make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
                       make_pass_through_transform(C)),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
            make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
        const auto in_gemmm_gemmk_desc = transform_tensor_descriptor(
            in_n_x_wo_c_desc,
            make_tuple(make_merge_transform(make_tuple(N, Wo)),
                       make_merge_transform(make_tuple(X, C))),
            make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));
        return in_gemmm_gemmk_desc;
        if constexpr(NumGroupsToMerge == 1)
        {
            const auto in_n_wi_c_desc = make_naive_tensor_descriptor(
                make_tuple(N_, Wi_, C_),
                make_tuple(NStrideTensorA_, WiStride_, CStrideTensorA_));
            const auto in_n_wip_c_desc = transform_tensor_descriptor(
                in_n_wi_c_desc,
                make_tuple(make_pass_through_transform(N_),
                           make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
                           make_pass_through_transform(C_)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
            const auto in_n_x_wo_c_desc = transform_tensor_descriptor(
                in_n_wip_c_desc,
                make_tuple(make_pass_through_transform(N_),
                           make_embed_transform(make_tuple(X_, Wo_),
                                                make_tuple(ConvDilationW_, ConvStrideW_)),
                           make_pass_through_transform(C_)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
                make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
            return transform_tensor_descriptor(
                in_n_x_wo_c_desc,
                make_tuple(make_merge_transform(make_tuple(N_, Wo_)),
                           make_merge_transform(make_tuple(X_, C_))),
                make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));
        }
        else
        {
            const auto in_n_wi_c_desc = make_naive_tensor_descriptor(
                make_tuple(N_, Wi_, NumGroupsToMerge, C_),
                make_tuple(NStrideTensorA_, WiStride_, GStrideTensorA_, CStrideTensorA_));
            const auto in_n_wip_c_desc = transform_tensor_descriptor(
                in_n_wi_c_desc,
                make_tuple(make_pass_through_transform(N_),
                           make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
                           make_pass_through_transform(NumGroupsToMerge),
                           make_pass_through_transform(C_)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
            const auto in_n_x_wo_c_desc = transform_tensor_descriptor(
                in_n_wip_c_desc,
                make_tuple(make_pass_through_transform(N_),
                           make_embed_transform(make_tuple(X_, Wo_),
                                                make_tuple(ConvDilationW_, ConvStrideW_)),
                           make_pass_through_transform(NumGroupsToMerge),
                           make_pass_through_transform(C_)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}, Sequence<4>{}));
            return transform_tensor_descriptor(
                in_n_x_wo_c_desc,
                make_tuple(make_merge_transform(make_tuple(N_, Wo_, NumGroupsToMerge)),
                           make_merge_transform(make_tuple(X_, C_))),
                make_tuple(Sequence<0, 2, 3>{}, Sequence<1, 4>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));
        }
    }
}
...
...
@@ -739,126 +519,229 @@ struct TransformConvFwdToGemm
                                       is_same_v<ALayout, tensor_layout::convolution::NHWGC> ||
                                       is_same_v<ALayout, tensor_layout::convolution::GNHWC>),
                                  bool>::type = false>
__host__ __device__ static auto
MakeADescriptor_M_K(const ck::Array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
                    const ck::Array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
                    const ck::Array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
                    const ck::Array<index_t, NDimSpatial + 3>& /* b_g_k_c_xs_strides */,
                    const ck::Array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
                    const ck::Array<index_t, NDimSpatial + 3>& /* c_g_n_k_wos_strides */,
                    const ck::Array<index_t, NDimSpatial>& conv_filter_strides,
                    const ck::Array<index_t, NDimSpatial>& conv_filter_dilations,
                    const ck::Array<index_t, NDimSpatial>& input_left_pads,
                    const ck::Array<index_t, NDimSpatial>& input_right_pads)
{
    const index_t N  = a_g_n_c_wis_lengths[1];
    const index_t C  = a_g_n_c_wis_lengths[2];
    const index_t Hi = a_g_n_c_wis_lengths[3];
    const index_t Wi = a_g_n_c_wis_lengths[4];
    const index_t Ho = c_g_n_k_wos_lengths[3];
    const index_t Wo = c_g_n_k_wos_lengths[4];
    const index_t ConvStrideH = conv_filter_strides[0];
    const index_t ConvStrideW = conv_filter_strides[1];
__host__ __device__ auto MakeADescriptor_M_K() const
{
    if constexpr(ConvForwardSpecialization ==
                 device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
    {
        const index_t NHoWo = N * ck::accumulate_n<index_t>(
            c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());
        // This is different
        const index_t WiStride = a_g_n_c_wis_strides[2 + NDimSpatial];
        const auto CStride     = I1;
        const auto in_gemmm_gemmk_desc =
            make_naive_tensor_descriptor(make_tuple(NHoWo, C), make_tuple(WiStride, CStride));
        return in_gemmm_gemmk_desc;
        if constexpr(NumGroupsToMerge == 1)
        {
            return make_naive_tensor_descriptor(make_tuple(NDoHoWo_, C_),
                                                make_tuple(WiStride_, CStrideTensorA_));
        }
        else
        {
            const auto in_gemmm_groups_gemmk_desc = make_naive_tensor_descriptor(
                make_tuple(NDoHoWo_, NumGroupsToMerge, C_),
                make_tuple(WiStride_, GStrideTensorA_, CStrideTensorA_));
            return transform_tensor_descriptor(
                in_gemmm_groups_gemmk_desc,
                make_tuple(make_merge_transform(make_tuple(NDoHoWo_, NumGroupsToMerge)),
                           make_pass_through_transform(C_)),
                make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));
        }
    }
    else if constexpr(ConvForwardSpecialization ==
                      device::ConvolutionForwardSpecialization::Filter3x3)
    {
        if constexpr(NumGroupsToMerge == 1)
        {
            const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor(
                make_tuple(N_, Hi_, Wi_), make_tuple(NStrideTensorA_, HiStride_, WiStride_));
            const auto in_n_hip_wip_c_desc = transform_tensor_descriptor(
                in_n_hi_wi_c_desc,
                make_tuple(make_pass_through_transform(N_),
                           make_pad_transform(Hi_, InLeftPadH_, InRightPadH_),
                           make_pad_transform(Wi_, InLeftPadW_, InRightPadW_)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
            const auto in_n_y_ho_x_wo_c_desc = transform_tensor_descriptor(
                in_n_hip_wip_c_desc,
                make_tuple(make_pass_through_transform(N_),
                           make_embed_transform(make_tuple(Number<3>{}, Ho_),
                                                make_tuple(ConvDilationH_, ConvStrideH_)),
                           make_embed_transform(make_tuple(Number<3>{}, Wo_),
                                                make_tuple(ConvDilationW_, ConvStrideW_))),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
                make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}));
            return transform_tensor_descriptor(
                in_n_y_ho_x_wo_c_desc,
                make_tuple(make_merge_transform(make_tuple(N_, Ho_, Wo_)),
                           make_merge_transform(make_tuple(Number<3>{}, Number<3>{}))),
                make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));
        }
        else
        {
            const auto in_n_hi_wi_groups_c_desc = make_naive_tensor_descriptor(
                make_tuple(N_, Hi_, Wi_, NumGroupsToMerge),
                make_tuple(NStrideTensorA_, HiStride_, WiStride_, GStrideTensorA_));
            const auto in_n_hip_wip_groups_c_desc = transform_tensor_descriptor(
                in_n_hi_wi_groups_c_desc,
                make_tuple(make_pass_through_transform(N_),
                           make_pad_transform(Hi_, InLeftPadH_, InRightPadH_),
                           make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
                           make_pass_through_transform(NumGroupsToMerge)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
            const auto in_n_y_ho_x_wo_groups_c_desc = transform_tensor_descriptor(
                in_n_hip_wip_groups_c_desc,
                make_tuple(make_pass_through_transform(N_),
                           make_embed_transform(make_tuple(Number<3>{}, Ho_),
                                                make_tuple(ConvDilationH_, ConvStrideH_)),
                           make_embed_transform(make_tuple(Number<3>{}, Wo_),
                                                make_tuple(ConvDilationW_, ConvStrideW_)),
                           make_pass_through_transform(NumGroupsToMerge)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
            return transform_tensor_descriptor(
                in_n_y_ho_x_wo_groups_c_desc,
                make_tuple(make_merge_transform(make_tuple(N_, Ho_, Wo_, NumGroupsToMerge)),
                           make_merge_transform(make_tuple(Number<3>{}, Number<3>{}))),
                make_tuple(Sequence<0, 2, 4, 5>{}, Sequence<1, 3>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));
        }
    }
    else if constexpr(ConvForwardSpecialization ==
                      device::ConvolutionForwardSpecialization::Filter1x1Pad0)
    {
        // This is different
        const index_t NStride  = a_g_n_c_wis_strides[1];
        const index_t HiStride = a_g_n_c_wis_strides[3];
        const index_t WiStride = a_g_n_c_wis_strides[4];
        const auto CStride     = I1;
        const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor(
            make_tuple(N, Hi, Wi, C), make_tuple(NStride, HiStride, WiStride, CStride));
        const auto in_n_ho_wo_c_desc = transform_tensor_descriptor(
            in_n_hi_wi_c_desc,
            make_tuple(make_pass_through_transform(N),
                       make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)),
                       make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
                       make_pass_through_transform(C)),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
        const auto in_gemmm_gemmk_desc = transform_tensor_descriptor(
            in_n_ho_wo_c_desc,
            make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)),
                       make_pass_through_transform(C)),
            make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));
        return in_gemmm_gemmk_desc;
        if constexpr(NumGroupsToMerge == 1)
        {
            const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor(
                make_tuple(N_, Hi_, Wi_, C_),
                make_tuple(NStrideTensorA_, HiStride_, WiStride_, CStrideTensorA_));
            const auto in_n_ho_wo_c_desc = transform_tensor_descriptor(
                in_n_hi_wi_c_desc,
                make_tuple(make_pass_through_transform(N_),
                           make_embed_transform(make_tuple(Ho_), make_tuple(ConvStrideH_)),
                           make_embed_transform(make_tuple(Wo_), make_tuple(ConvStrideW_)),
                           make_pass_through_transform(C_)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
            return transform_tensor_descriptor(
                in_n_ho_wo_c_desc,
                make_tuple(make_merge_transform(make_tuple(N_, Ho_, Wo_)),
                           make_pass_through_transform(C_)),
                make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));
        }
        else
        {
            const auto in_n_hi_wi_groups_c_desc = make_naive_tensor_descriptor(
                make_tuple(N_, Hi_, Wi_, NumGroupsToMerge, C_),
                make_tuple(NStrideTensorA_, HiStride_, WiStride_, GStrideTensorA_, CStrideTensorA_));
            const auto in_n_ho_wo_groups_c_desc = transform_tensor_descriptor(
                in_n_hi_wi_groups_c_desc,
                make_tuple(make_pass_through_transform(N_),
                           make_embed_transform(make_tuple(Ho_), make_tuple(ConvStrideH_)),
                           make_embed_transform(make_tuple(Wo_), make_tuple(ConvStrideW_)),
                           make_pass_through_transform(NumGroupsToMerge),
                           make_pass_through_transform(C_)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
            return transform_tensor_descriptor(
                in_n_ho_wo_groups_c_desc,
                make_tuple(make_merge_transform(make_tuple(N_, Ho_, Wo_, NumGroupsToMerge)),
                           make_pass_through_transform(C_)),
                make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));
        }
    }
    else
    {
        const index_t Y = b_g_k_c_xs_lengths[3];
        const index_t X = b_g_k_c_xs_lengths[4];
        const index_t ConvDilationH = conv_filter_dilations[0];
        const index_t ConvDilationW = conv_filter_dilations[1];
        const index_t InLeftPadH  = input_left_pads[0];
        const index_t InLeftPadW  = input_left_pads[1];
        const index_t InRightPadH = input_right_pads[0];
        const index_t InRightPadW = input_right_pads[1];
        // This is different
        const index_t NStride  = a_g_n_c_wis_strides[1];
        const index_t HiStride = a_g_n_c_wis_strides[3];
        const index_t WiStride = a_g_n_c_wis_strides[4];
        const auto CStride     = I1;
        const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor(
            make_tuple(N, Hi, Wi, C), make_tuple(NStride, HiStride, WiStride, CStride));
        const auto in_n_hip_wip_c_desc = transform_tensor_descriptor(
            in_n_hi_wi_c_desc,
            make_tuple(make_pass_through_transform(N),
                       make_pad_transform(Hi, InLeftPadH, InRightPadH),
                       make_pad_transform(Wi, InLeftPadW, InRightPadW),
                       make_pass_through_transform(C)),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
        const auto in_n_y_ho_x_wo_c_desc = transform_tensor_descriptor(
            in_n_hip_wip_c_desc,
            make_tuple(make_pass_through_transform(N),
                       make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
                       make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
                       make_pass_through_transform(C)),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
        const auto in_gemmm_gemmk_desc = transform_tensor_descriptor(
            in_n_y_ho_x_wo_c_desc,
            make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)),
                       make_merge_transform(make_tuple(Y, X, C))),
            make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));
        if constexpr(NumGroupsToMerge == 1)
        {
            const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor(
                make_tuple(N_, Hi_, Wi_, C_),
                make_tuple(NStrideTensorA_, HiStride_, WiStride_, CStrideTensorA_));
            const auto in_n_hip_wip_c_desc = transform_tensor_descriptor(
                in_n_hi_wi_c_desc,
                make_tuple(make_pass_through_transform(N_),
                           make_pad_transform(Hi_, InLeftPadH_, InRightPadH_),
                           make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
                           make_pass_through_transform(C_)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
            const auto in_n_y_ho_x_wo_c_desc = transform_tensor_descriptor(
                in_n_hip_wip_c_desc,
                make_tuple(make_pass_through_transform(N_),
                           make_embed_transform(make_tuple(Y_, Ho_),
                                                make_tuple(ConvDilationH_, ConvStrideH_)),
                           make_embed_transform(make_tuple(X_, Wo_),
                                                make_tuple(ConvDilationW_, ConvStrideW_)),
                           make_pass_through_transform(C_)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
            return transform_tensor_descriptor(
                in_n_y_ho_x_wo_c_desc,
                make_tuple(make_merge_transform(make_tuple(N_, Ho_, Wo_)),
                           make_merge_transform(make_tuple(Y_, X_, C_))),
                make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));
        }
        else
        {
            return in_gemmm_gemmk_desc;
            const auto in_n_hi_wi_groups_c_desc = make_naive_tensor_descriptor(
                make_tuple(N_, Hi_, Wi_, NumGroupsToMerge, C_),
                make_tuple(NStrideTensorA_, HiStride_, WiStride_, GStrideTensorA_, CStrideTensorA_));
            const auto in_n_hip_wip_groups_c_desc = transform_tensor_descriptor(
                in_n_hi_wi_groups_c_desc,
                make_tuple(make_pass_through_transform(N_),
                           make_pad_transform(Hi_, InLeftPadH_, InRightPadH_),
                           make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
                           make_pass_through_transform(NumGroupsToMerge),
                           make_pass_through_transform(C_)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
            const auto in_n_y_ho_x_wo_groups_c_desc = transform_tensor_descriptor(
                in_n_hip_wip_groups_c_desc,
                make_tuple(make_pass_through_transform(N_),
                           make_embed_transform(make_tuple(Y_, Ho_),
                                                make_tuple(ConvDilationH_, ConvStrideH_)),
                           make_embed_transform(make_tuple(X_, Wo_),
                                                make_tuple(ConvDilationW_, ConvStrideW_)),
                           make_pass_through_transform(NumGroupsToMerge),
                           make_pass_through_transform(C_)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
                make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}, Sequence<6>{}));
            return transform_tensor_descriptor(
                in_n_y_ho_x_wo_groups_c_desc,
                make_tuple(make_merge_transform(make_tuple(N_, Ho_, Wo_, NumGroupsToMerge)),
                           make_merge_transform(make_tuple(Y_, X_, C_))),
                make_tuple(Sequence<0, 2, 4, 5>{}, Sequence<1, 3, 6>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));
        }
    }
}
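Editor's note: the Ho/Wo entries of c_g_n_k_wos_lengths used above are assumed to already satisfy the usual forward-convolution output-size relation implied by the pad and embed transforms. A minimal sketch of that relation (illustration only, not part of this commit):
// Conventional forward-convolution output length.
int conv_out_len(int in_len, int filter_len, int stride, int dilation,
                 int left_pad, int right_pad)
{
    const int eff_filter = dilation * (filter_len - 1) + 1;
    return (in_len + left_pad + right_pad - eff_filter) / stride + 1;
}
// e.g. Hi = 224, Y = 3, stride = 2, dilation = 1, pads = 1/1  ->  Ho = 112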
...
...
@@ -868,149 +751,293 @@ struct TransformConvFwdToGemm
                                       is_same_v<ALayout, tensor_layout::convolution::NDHWGC> ||
                                       is_same_v<ALayout, tensor_layout::convolution::GNDHWC>),
                                  bool>::type = false>
static auto
MakeADescriptor_M_K(const ck::Array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
                    const ck::Array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
                    const ck::Array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
                    const ck::Array<index_t, NDimSpatial + 3>& /* b_g_k_c_xs_strides */,
                    const ck::Array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
                    const ck::Array<index_t, NDimSpatial + 3>& /* c_g_n_k_wos_strides */,
                    const ck::Array<index_t, NDimSpatial>& conv_filter_strides,
                    const ck::Array<index_t, NDimSpatial>& conv_filter_dilations,
                    const ck::Array<index_t, NDimSpatial>& input_left_pads,
                    const ck::Array<index_t, NDimSpatial>& input_right_pads)
{
    const index_t N  = a_g_n_c_wis_lengths[1];
    const index_t C  = a_g_n_c_wis_lengths[2];
    const index_t Di = a_g_n_c_wis_lengths[3];
    const index_t Hi = a_g_n_c_wis_lengths[4];
    const index_t Wi = a_g_n_c_wis_lengths[5];
    const index_t Do = c_g_n_k_wos_lengths[3];
    const index_t Ho = c_g_n_k_wos_lengths[4];
    const index_t Wo = c_g_n_k_wos_lengths[5];
    const index_t ConvStrideD = conv_filter_strides[0];
    const index_t ConvStrideH = conv_filter_strides[1];
    const index_t ConvStrideW = conv_filter_strides[2];
__host__ __device__ auto MakeADescriptor_M_K() const
{
    if constexpr(ConvForwardSpecialization ==
                 device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
    {
        const index_t NDoHoWo = N * ck::accumulate_n<index_t>(
            c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());
        // This is different
        const index_t WiStride = a_g_n_c_wis_strides[2 + NDimSpatial];
        const auto CStride     = I1;
        const auto in_gemmm_gemmk_desc =
            make_naive_tensor_descriptor(make_tuple(NDoHoWo, C), make_tuple(WiStride, CStride));
        return in_gemmm_gemmk_desc;
        if constexpr(NumGroupsToMerge == 1)
        {
            return make_naive_tensor_descriptor(make_tuple(NDoHoWo_, C_),
                                                make_tuple(WiStride_, CStrideTensorA_));
        }
        else
        {
            const auto in_gemmm_groups_gemmk_desc = make_naive_tensor_descriptor(
                make_tuple(NDoHoWo_, NumGroupsToMerge, C_),
                make_tuple(WiStride_, GStrideTensorA_, CStrideTensorA_));
            return transform_tensor_descriptor(
                in_gemmm_groups_gemmk_desc,
                make_tuple(make_merge_transform(make_tuple(NDoHoWo_, NumGroupsToMerge)),
                           make_pass_through_transform(C_)),
                make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));
        }
    }
    else if constexpr(ConvForwardSpecialization ==
                      device::ConvolutionForwardSpecialization::Filter3x3)
    {
        if constexpr(NumGroupsToMerge == 1)
        {
            const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor(
                make_tuple(N_, Di_, Hi_, Wi_),
                make_tuple(NStrideTensorA_, DiStride_, HiStride_, WiStride_));
            const auto in_n_hip_wip_c_desc = transform_tensor_descriptor(
                in_n_di_hi_wi_c_desc,
                make_tuple(make_pass_through_transform(N_),
                           make_pad_transform(Di_, InLeftPadD_, InRightPadD_),
                           make_pad_transform(Hi_, InLeftPadH_, InRightPadH_),
                           make_pad_transform(Wi_, InLeftPadW_, InRightPadW_)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
            const auto in_n_z_do_y_ho_x_wo_c_desc = transform_tensor_descriptor(
                in_n_hip_wip_c_desc,
                make_tuple(make_pass_through_transform(N_),
                           make_embed_transform(make_tuple(Number<3>{}, Do_),
                                                make_tuple(ConvDilationD_, ConvStrideD_)),
                           make_embed_transform(make_tuple(Number<3>{}, Ho_),
                                                make_tuple(ConvDilationH_, ConvStrideH_)),
                           make_embed_transform(make_tuple(Number<3>{}, Wo_),
                                                make_tuple(ConvDilationW_, ConvStrideW_))),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5, 6>{}));
            return transform_tensor_descriptor(
                in_n_z_do_y_ho_x_wo_c_desc,
                make_tuple(make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_)),
                           make_merge_transform(make_tuple(Number<3>{}, Number<3>{}, Number<3>{}))),
                make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));
        }
        else
        {
            const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor(
                make_tuple(N_, Di_, Hi_, Wi_, NumGroupsToMerge),
                make_tuple(NStrideTensorA_, DiStride_, HiStride_, WiStride_, GStrideTensorA_));
            const auto in_n_hip_wip_c_desc = transform_tensor_descriptor(
                in_n_di_hi_wi_c_desc,
                make_tuple(make_pass_through_transform(N_),
                           make_pad_transform(Di_, InLeftPadD_, InRightPadD_),
                           make_pad_transform(Hi_, InLeftPadH_, InRightPadH_),
                           make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
                           make_pass_through_transform(NumGroupsToMerge)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
            const auto in_n_z_do_y_ho_x_wo_c_desc = transform_tensor_descriptor(
                in_n_hip_wip_c_desc,
                make_tuple(make_pass_through_transform(N_),
                           make_embed_transform(make_tuple(Number<3>{}, Do_),
                                                make_tuple(ConvDilationD_, ConvStrideD_)),
                           make_embed_transform(make_tuple(Number<3>{}, Ho_),
                                                make_tuple(ConvDilationH_, ConvStrideH_)),
                           make_embed_transform(make_tuple(Number<3>{}, Wo_),
                                                make_tuple(ConvDilationW_, ConvStrideW_)),
                           make_pass_through_transform(NumGroupsToMerge)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
                make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5, 6>{}, Sequence<7>{}));
            return transform_tensor_descriptor(
                in_n_z_do_y_ho_x_wo_c_desc,
                make_tuple(make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_, NumGroupsToMerge)),
                           make_merge_transform(make_tuple(Number<3>{}, Number<3>{}, Number<3>{}))),
                make_tuple(Sequence<0, 2, 4, 6, 7>{}, Sequence<1, 3, 5>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));
        }
    }
    else if constexpr(ConvForwardSpecialization ==
                      device::ConvolutionForwardSpecialization::Filter1x1Pad0)
    {
        // This is different
        const index_t NStride  = a_g_n_c_wis_strides[1];
        const index_t DiStride = a_g_n_c_wis_strides[3];
        const index_t HiStride = a_g_n_c_wis_strides[4];
        const index_t WiStride = a_g_n_c_wis_strides[5];
        const auto CStride     = I1;
        const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor(
            make_tuple(N, Di, Hi, Wi, C),
            make_tuple(NStride, DiStride, HiStride, WiStride, CStride));
        const auto in_n_do_ho_wo_c_desc = transform_tensor_descriptor(
            in_n_di_hi_wi_c_desc,
            make_tuple(make_pass_through_transform(N),
                       make_embed_transform(make_tuple(Do), make_tuple(ConvStrideD)),
                       make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)),
                       make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
                       make_pass_through_transform(C)),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
        const auto in_gemmm_gemmk_desc = transform_tensor_descriptor(
            in_n_do_ho_wo_c_desc,
            make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)),
                       make_pass_through_transform(C)),
            make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));
        return in_gemmm_gemmk_desc;
        if constexpr(NumGroupsToMerge == 1)
        {
            const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor(
                make_tuple(N_, Di_, Hi_, Wi_, C_),
                make_tuple(NStrideTensorA_, DiStride_, HiStride_, WiStride_, CStrideTensorA_));
            const auto in_n_do_ho_wo_c_desc = transform_tensor_descriptor(
                in_n_di_hi_wi_c_desc,
                make_tuple(make_pass_through_transform(N_),
                           make_embed_transform(make_tuple(Do_), make_tuple(ConvStrideD_)),
                           make_embed_transform(make_tuple(Ho_), make_tuple(ConvStrideH_)),
                           make_embed_transform(make_tuple(Wo_), make_tuple(ConvStrideW_)),
                           make_pass_through_transform(C_)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
            return transform_tensor_descriptor(
                in_n_do_ho_wo_c_desc,
                make_tuple(make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_)),
                           make_pass_through_transform(C_)),
                make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));
        }
        else
        {
            const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor(
                make_tuple(N_, Di_, Hi_, Wi_, NumGroupsToMerge, C_),
                make_tuple(NStrideTensorA_, DiStride_, HiStride_, WiStride_, GStrideTensorA_, CStrideTensorA_));
            const auto in_n_do_ho_wo_c_desc = transform_tensor_descriptor(
                in_n_di_hi_wi_c_desc,
                make_tuple(make_pass_through_transform(N_),
                           make_embed_transform(make_tuple(Do_), make_tuple(ConvStrideD_)),
                           make_embed_transform(make_tuple(Ho_), make_tuple(ConvStrideH_)),
                           make_embed_transform(make_tuple(Wo_), make_tuple(ConvStrideW_)),
                           make_pass_through_transform(NumGroupsToMerge),
                           make_pass_through_transform(C_)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}, Sequence<5>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}, Sequence<5>{}));
            return transform_tensor_descriptor(
                in_n_do_ho_wo_c_desc,
                make_tuple(make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_, NumGroupsToMerge)),
                           make_pass_through_transform(C_)),
                make_tuple(Sequence<0, 1, 2, 3, 4>{}, Sequence<5>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));
        }
    }
    else
    {
        const index_t Z = b_g_k_c_xs_lengths[3];
        const index_t Y = b_g_k_c_xs_lengths[4];
        const index_t X = b_g_k_c_xs_lengths[5];
        const index_t ConvDilationD = conv_filter_dilations[0];
        const index_t ConvDilationH = conv_filter_dilations[1];
        const index_t ConvDilationW = conv_filter_dilations[2];
        const index_t InLeftPadD  = input_left_pads[0];
        const index_t InLeftPadH  = input_left_pads[1];
        const index_t InLeftPadW  = input_left_pads[2];
        const index_t InRightPadD = input_right_pads[0];
        const index_t InRightPadH = input_right_pads[1];
        const index_t InRightPadW = input_right_pads[2];
        // This is different
        const index_t NStride  = a_g_n_c_wis_strides[1];
        const index_t DiStride = a_g_n_c_wis_strides[3];
        const index_t HiStride = a_g_n_c_wis_strides[4];
        const index_t WiStride = a_g_n_c_wis_strides[5];
        const auto CStride     = I1;
        const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor(
            make_tuple(N, Di, Hi, Wi, C),
            make_tuple(NStride, DiStride, HiStride, WiStride, CStride));
        const auto in_n_hip_wip_c_desc = transform_tensor_descriptor(
            in_n_di_hi_wi_c_desc,
            make_tuple(make_pass_through_transform(N),
                       make_pad_transform(Di, InLeftPadD, InRightPadD),
                       make_pad_transform(Hi, InLeftPadH, InRightPadH),
                       make_pad_transform(Wi, InLeftPadW, InRightPadW),
                       make_pass_through_transform(C)),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
        const auto in_n_z_do_y_ho_x_wo_c_desc = transform_tensor_descriptor(
            in_n_hip_wip_c_desc,
            make_tuple(make_pass_through_transform(N),
                       make_embed_transform(make_tuple(Z, Do), make_tuple(ConvDilationD, ConvStrideD)),
                       make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
                       make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
                       make_pass_through_transform(C)),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
            make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5, 6>{}, Sequence<7>{}));
        const auto in_gemmm_gemmk_desc = transform_tensor_descriptor(
            in_n_z_do_y_ho_x_wo_c_desc,
            make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)),
                       make_merge_transform(make_tuple(Z, Y, X, C))),
            make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));
        return in_gemmm_gemmk_desc;
        if constexpr(NumGroupsToMerge == 1)
        {
            const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor(
                make_tuple(N_, Di_, Hi_, Wi_, C_),
                make_tuple(NStrideTensorA_, DiStride_, HiStride_, WiStride_, CStrideTensorA_));
            const auto in_n_hip_wip_c_desc = transform_tensor_descriptor(
                in_n_di_hi_wi_c_desc,
                make_tuple(make_pass_through_transform(N_),
                           make_pad_transform(Di_, InLeftPadD_, InRightPadD_),
                           make_pad_transform(Hi_, InLeftPadH_, InRightPadH_),
                           make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
                           make_pass_through_transform(C_)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
            const auto in_n_z_do_y_ho_x_wo_c_desc = transform_tensor_descriptor(
                in_n_hip_wip_c_desc,
                make_tuple(make_pass_through_transform(N_),
                           make_embed_transform(make_tuple(Z_, Do_),
                                                make_tuple(ConvDilationD_, ConvStrideD_)),
                           make_embed_transform(make_tuple(Y_, Ho_),
                                                make_tuple(ConvDilationH_, ConvStrideH_)),
                           make_embed_transform(make_tuple(X_, Wo_),
                                                make_tuple(ConvDilationW_, ConvStrideW_)),
                           make_pass_through_transform(C_)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
                make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5, 6>{}, Sequence<7>{}));
            return transform_tensor_descriptor(
                in_n_z_do_y_ho_x_wo_c_desc,
                make_tuple(make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_)),
                           make_merge_transform(make_tuple(Z_, Y_, X_, C_))),
                make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));
        }
        else
        {
            const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor(
                make_tuple(N_, Di_, Hi_, Wi_, NumGroupsToMerge, C_),
                make_tuple(NStrideTensorA_, DiStride_, HiStride_, WiStride_, GStrideTensorA_, CStrideTensorA_));
            const auto in_n_hip_wip_c_desc = transform_tensor_descriptor(
                in_n_di_hi_wi_c_desc,
                make_tuple(make_pass_through_transform(N_),
                           make_pad_transform(Di_, InLeftPadD_, InRightPadD_),
                           make_pad_transform(Hi_, InLeftPadH_, InRightPadH_),
                           make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
                           make_pass_through_transform(NumGroupsToMerge),
                           make_pass_through_transform(C_)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}, Sequence<5>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}, Sequence<5>{}));
            const auto in_n_z_do_y_ho_x_wo_c_desc = transform_tensor_descriptor(
                in_n_hip_wip_c_desc,
                make_tuple(make_pass_through_transform(N_),
                           make_embed_transform(make_tuple(Z_, Do_),
                                                make_tuple(ConvDilationD_, ConvStrideD_)),
                           make_embed_transform(make_tuple(Y_, Ho_),
                                                make_tuple(ConvDilationH_, ConvStrideH_)),
                           make_embed_transform(make_tuple(X_, Wo_),
                                                make_tuple(ConvDilationW_, ConvStrideW_)),
                           make_pass_through_transform(NumGroupsToMerge),
                           make_pass_through_transform(C_)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}, Sequence<5>{}),
                make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5, 6>{}, Sequence<7>{}, Sequence<8>{}));
            return transform_tensor_descriptor(
                in_n_z_do_y_ho_x_wo_c_desc,
                make_tuple(make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_, NumGroupsToMerge)),
                           make_merge_transform(make_tuple(Z_, Y_, X_, C_))),
                make_tuple(Sequence<0, 2, 4, 6, 7>{}, Sequence<1, 3, 5, 8>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));
        }
    }
}
...
...
@@ -1019,20 +1046,53 @@ struct TransformConvFwdToGemm
                                      is_same_v<BLayout, tensor_layout::convolution::GKYXC> ||
                                      is_same_v<BLayout, tensor_layout::convolution::GKZYXC>,
                                  bool>::type = false>
__host__ __device__ static auto
MakeBDescriptor_N_K(const ck::Array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
                    const ck::Array<index_t, NDimSpatial + 3>& /* b_g_k_c_xs_strides */)
__host__ __device__ auto MakeBDescriptor_N_K() const
{
    const index_t K  = b_g_k_c_xs_lengths[1];
    const index_t C  = b_g_k_c_xs_lengths[2];
    const index_t YX = mult_accumulate_n<index_t>(b_g_k_c_xs_lengths.begin() + 3, NDimSpatial, 1);
    if constexpr(ConvForwardSpecialization ==
                 device::ConvolutionForwardSpecialization::Filter3x3)
    {
        using FilterSizeNumType =
            std::conditional_t<NDimSpatial == 1,
                               Number<3>,
                               std::conditional_t<NDimSpatial == 2, Number<9>, Number<27>>>;
        const auto wei_gemmn_gemmk_desc =
            make_naive_tensor_descriptor_packed(make_tuple(K, YX * C));
        if constexpr(NumGroupsToMerge == 1)
        {
            return make_naive_tensor_descriptor_packed(make_tuple(K_, FilterSizeNumType{}));
        }
        else
        {
            return wei_gemmn_gemmk_desc;
            const auto wei_gemmn_groups_gemmk_desc = make_naive_tensor_descriptor(
                make_tuple(K_, NumGroupsToMerge, FilterSizeNumType{}),
                make_tuple(KStrideTensorB_, GStrideTensorB_, CStrideTensorB_));
            return transform_tensor_descriptor(
                wei_gemmn_groups_gemmk_desc,
                make_tuple(make_merge_transform(make_tuple(K_, NumGroupsToMerge)),
                           make_pass_through_transform(FilterSizeNumType{})),
                make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));
        }
    }
    else
    {
        if constexpr(NumGroupsToMerge == 1)
        {
            return make_naive_tensor_descriptor_packed(make_tuple(K_, ZYX_ * C_));
        }
        else
        {
            const auto wei_gemmn_groups_gemmk_desc = make_naive_tensor_descriptor(
                make_tuple(K_, NumGroupsToMerge, ZYX_ * C_),
                make_tuple(KStrideTensorB_, GStrideTensorB_, CStrideTensorB_));
            return transform_tensor_descriptor(
                wei_gemmn_groups_gemmk_desc,
                make_tuple(make_merge_transform(make_tuple(K_, NumGroupsToMerge)),
                           make_pass_through_transform(ZYX_ * C_)),
                make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));
        }
    }
}
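Editor's note: the Filter3x3 path replaces the runtime YX*C extent with a compile-time constant (3, 9, or 27 taps for 1-D, 2-D, or 3-D filters), which lets the compiler fold the GEMM-K extent. A minimal sketch of that selection logic (illustration only; the actual descriptor uses Number<> above):
// Compile-time 3^NDimSpatial filter size, mirroring the conditional_t chain above.
template <int NDimSpatial>
constexpr int filter_size_3x3()
{
    static_assert(NDimSpatial >= 1 && NDimSpatial <= 3, "only 1-D/2-D/3-D supported");
    return NDimSpatial == 1 ? 3 : (NDimSpatial == 2 ? 9 : 27);
}
static_assert(filter_size_3x3<2>() == 9);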
template <
...
...
@@ -1044,26 +1104,14 @@ struct TransformConvFwdToGemm
                                      is_same_v<BLayout, tensor_layout::convolution::KYXGC> ||
                                      is_same_v<BLayout, tensor_layout::convolution::KZYXGC>,
                                  bool>::type = false>
__host__ __device__ static auto
MakeBDescriptor_N_K(const ck::Array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
                    const ck::Array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides)
__host__ __device__ auto MakeBDescriptor_N_K() const
{
    const index_t K  = b_g_k_c_xs_lengths[1];
    const index_t C  = b_g_k_c_xs_lengths[2];
    const index_t YX = mult_accumulate_n<index_t>(b_g_k_c_xs_lengths.begin() + 3, NDimSpatial, 1);
    const index_t KStride = b_g_k_c_xs_strides[1];
    const index_t XStride = b_g_k_c_xs_strides[2 + NDimSpatial];
    const auto CStride    = I1;
    const auto wei_k_yx_c_desc = make_naive_tensor_descriptor(
        make_tuple(K, YX, C), make_tuple(KStride, XStride, CStride));
        make_tuple(K_, ZYX_, C_), make_tuple(KStrideTensorB_, XStride_, CStrideTensorB_));
    const auto wei_gemmn_gemmk_desc = transform_tensor_descriptor(
        wei_k_yx_c_desc,
        make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(YX, C))),
        make_tuple(make_pass_through_transform(K_), make_merge_transform(make_tuple(ZYX_, C_))),
        make_tuple(Sequence<0>{}, Sequence<1, 2>{}),
        make_tuple(Sequence<0>{}, Sequence<1>{}));
...
...
@@ -1075,23 +1123,14 @@ struct TransformConvFwdToGemm
                                      is_same_v<CLayout, tensor_layout::convolution::GNHWK> ||
                                      is_same_v<CLayout, tensor_layout::convolution::GNDHWK>,
                                  bool>::type = false>
__host__ __device__ static auto
MakeCDescriptor_M_N(const ck::Array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
                    const ck::Array<index_t, NDimSpatial + 3>& /* c_g_n_k_wos_strides */)
__host__ __device__ auto MakeCDescriptor_M_N() const
{
    const index_t N     = c_g_n_k_wos_lengths[1];
    const index_t K     = c_g_n_k_wos_lengths[2];
    const index_t NHoWo =
        N * mult_accumulate_n<index_t>(c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1);
    const auto out_gemmm_gemmn_desc =
        make_naive_tensor_descriptor_packed(make_tuple(NHoWo, K));
    return out_gemmm_gemmn_desc;
    return make_naive_tensor_descriptor_packed(make_tuple(NDoHoWo_, K_));
}
template <typename CLayout,
          typename std::enable_if<is_same_v<CLayout, tensor_layout::convolution::G_NW_K> ||
                                      is_same_v<CLayout, tensor_layout::convolution::G_NHW_K> ||
                                      is_same_v<CLayout, tensor_layout::convolution::G_NDHW_K> ||
...
...
@@ -1099,45 +1138,82 @@ struct TransformConvFwdToGemm
                  is_same_v<CLayout, tensor_layout::convolution::NHWGK> ||
                  is_same_v<CLayout, tensor_layout::convolution::NDHWGK>,
              bool>::type = false>
-    __host__ __device__ static auto
-    MakeCDescriptor_M_N(const ck::Array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
-                        const ck::Array<index_t, NDimSpatial + 3>& c_g_n_k_wos_strides)
+    __host__ __device__ auto MakeCDescriptor_M_N() const
     {
-        const index_t N        = c_g_n_k_wos_lengths[1];
-        const index_t K        = c_g_n_k_wos_lengths[2];
-        const auto KStride     = I1;
-        const index_t WoStride = c_g_n_k_wos_strides[NDimSpatial + 2];
-        const index_t NHoWo =
-            N * mult_accumulate_n<index_t>(c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1);
-        const auto out_gemmm_gemmn_desc =
-            make_naive_tensor_descriptor(make_tuple(NHoWo, K), make_tuple(WoStride, KStride));
-
-        return out_gemmm_gemmn_desc;
+        if constexpr(NumGroupsToMerge == 1)
+        {
+            return make_naive_tensor_descriptor(make_tuple(NDoHoWo_, K_),
+                                                make_tuple(WoStride_, KStrideTensorC_));
+        }
+        else
+        {
+            const auto nhwo_groups_k_1_desc = make_naive_tensor_descriptor(
+                make_tuple(NDoHoWo_, NumGroupsToMerge, K_, 1),
+                make_tuple(WoStride_, GStrideTensorC_, KStrideTensorC_, GStrideTensorC_));
+            // Pad the trailing 1 up to NumGroupsToMerge
+            const auto padded_desc = transform_tensor_descriptor(
+                nhwo_groups_k_1_desc,
+                make_tuple(make_pass_through_transform(NDoHoWo_),
+                           make_pass_through_transform(NumGroupsToMerge),
+                           make_pass_through_transform(K_),
+                           make_pad_transform(1, 0, NumGroupsToMerge - 1)),
+                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
+                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
+            // Only the matrices on the diagonal are needed. Xor returns 0 for equal values,
+            // so any matrix that is not on the diagonal ends up in the padding. To avoid a
+            // modulo after the xor, NumGroupsToMerge must be a power of 2.
+            static_assert(NumGroupsToMerge == 1 || NumGroupsToMerge == 2 ||
+                          NumGroupsToMerge == 4 || NumGroupsToMerge == 8 ||
+                          NumGroupsToMerge == 16 || NumGroupsToMerge == 32 ||
+                          NumGroupsToMerge == 64);
+            const auto unmerged_padded_desc = transform_tensor_descriptor(
+                padded_desc,
+                make_tuple(make_pass_through_transform(NDoHoWo_),
+                           make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
+                           make_pass_through_transform(K_)),
+                make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2>{}),
+                make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2>{}));
+            // Merge to M, N
+            return transform_tensor_descriptor(
+                unmerged_padded_desc,
+                make_tuple(make_merge_transform(make_tuple(NDoHoWo_, NumGroupsToMerge)),
+                           make_merge_transform(make_tuple(K_, NumGroupsToMerge))),
+                make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}),
+                make_tuple(Sequence<0>{}, Sequence<1>{}));
+        }
     }
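The xor transform above keeps only the "diagonal" group pairs; everything else is routed into the slot added by make_pad_transform and never materialized. A minimal standalone illustration (plain C++; the group count is an arbitrary power of 2):

```cpp
#include <cstdio>

int main()
{
    const int num_groups = 4; // must be a power of 2 so no modulo is needed after the xor

    // 0 on the diagonal (the tile that is kept); non-zero everywhere else
    // (those tiles fall into the padding).
    for(int g0 = 0; g0 < num_groups; ++g0)
    {
        for(int g1 = 0; g1 < num_groups; ++g1)
            std::printf("%d ", g0 ^ g1);
        std::printf("\n");
    }
    return 0;
}
```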
    // for output bias
    template <typename CLayout,
              typename std::enable_if<
                  is_same_v<CLayout, tensor_layout::convolution::G_K>, bool>::type = false>
-    __host__ __device__ static auto
-    MakeCDescriptor_M_N(const ck::Array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
-                        const ck::Array<index_t, NDimSpatial + 3>& c_g_n_k_wos_strides)
+    __host__ __device__ auto MakeCDescriptor_M_N() const
     {
-        const index_t N       = c_g_n_k_wos_lengths[1];
-        const index_t K       = c_g_n_k_wos_lengths[2];
-        const index_t KStride = c_g_n_k_wos_strides[2];
-        const index_t NHoWo =
-            N * mult_accumulate_n<index_t>(c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1);
         const auto out_gemmm_gemmn_desc =
-            make_naive_tensor_descriptor(make_tuple(NHoWo, K), make_tuple(I0, KStride));
+            make_naive_tensor_descriptor(make_tuple(NDoHoWo_, K_), make_tuple(I0, KStrideTensorC_));

         return out_gemmm_gemmn_desc;
     }

+    public:
+    index_t N_;
+
+    private:
+    const index_t Di_, Hi_, Wi_;
+    const index_t Do_, Ho_, Wo_;
+    const index_t Z_, Y_, X_;
+    const index_t K_, C_;
+    const index_t DiStride_, HiStride_, WiStride_;
+    const index_t WoStride_;
+    const index_t XStride_;
+    const index_t CStrideTensorA_, CStrideTensorB_, KStrideTensorB_, KStrideTensorC_;
+    const index_t NStrideTensorA_;
+    const index_t GStrideTensorA_, GStrideTensorB_, GStrideTensorC_;
+    const index_t ConvStrideD_, ConvStrideH_, ConvStrideW_;
+    const index_t ConvDilationD_, ConvDilationH_, ConvDilationW_;
+    const index_t InLeftPadD_, InLeftPadH_, InLeftPadW_;
+    const index_t InRightPadD_, InRightPadH_, InRightPadW_;
+    const index_t ZYX_;
+    index_t NDoHoWo_;
};

// wrapper class to call member functions on the TransformConvToGemm struct at runtime
...
@@ -1149,26 +1225,22 @@ struct TransformConv
    template <index_t NDimSpatial,
              device::ConvolutionForwardSpecialization ConvForwardSpecialization>
-    auto transform_func(ck::Array<index_t, NDimSpatial + 3> out_lengths,
-                        ck::Array<index_t, NDimSpatial + 3> out_strides,
-                        TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization> conv_fwd_to_gemm)
+    auto transform_func(TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization> conv_fwd_to_gemm)
    {
        if(NDimSpatial == 2)
        {
            return conv_fwd_to_gemm
-                .template MakeCDescriptor_M_N<ck::tensor_layout::convolution::NHWGK>(out_lengths, out_strides);
+                .template MakeCDescriptor_M_N<ck::tensor_layout::convolution::NHWGK>();
        }
        else if(NDimSpatial == 3)
        {
            return conv_fwd_to_gemm
-                .template MakeCDescriptor_M_N<tensor_layout::convolution::NDHWGK>(out_lengths, out_strides);
+                .template MakeCDescriptor_M_N<tensor_layout::convolution::NDHWGK>();
        }
        else if(NDimSpatial == 1)
        {
-            return conv_fwd_to_gemm
-                .template MakeCDescriptor_M_N<tensor_layout::convolution::NWGK>(out_lengths, out_strides);
+            return conv_fwd_to_gemm.template MakeCDescriptor_M_N<tensor_layout::convolution::NWGK>();
        }
    }
};
...
include/ck/utility/amd_smfmac.hpp  View file @ 3552041a
...
@@ -16,8 +16,15 @@ struct intrin_smfmac_f32_16x16x32f16<16, 16>
    __device__ static void Run(const half4_t& reg_a,
                               const half8_t& reg_b,
                               const int32_t& reg_idx,
                               FloatC& reg_c)
    {
+#if defined(__gfx94__)
        reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_16x16x32_f16(
            reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], reg_idx, 0, 0);
+#else
+        ignore = reg_a;
+        ignore = reg_b;
+        ignore = reg_c;
+        ignore = reg_idx;
+#endif
    }
};
...
@@ -31,8 +38,15 @@ struct intrin_smfmac_f32_16x16x32bf16<16, 16>
    __device__ static void Run(const bhalf4_t& reg_a,
                               const bhalf8_t& reg_b,
                               const int32_t& reg_idx,
                               FloatC& reg_c)
    {
+#if defined(__gfx94__)
        reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_16x16x32_bf16(
            reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], reg_idx, 0, 0);
+#else
+        ignore = reg_a;
+        ignore = reg_b;
+        ignore = reg_c;
+        ignore = reg_idx;
+#endif
    }
};
...
@@ -46,8 +60,15 @@ struct intrin_smfmac_f32_32x32x16f16<32, 32>
    __device__ static void Run(const half4_t& reg_a,
                               const half8_t& reg_b,
                               const int32_t& reg_idx,
                               FloatC& reg_c)
    {
+#if defined(__gfx94__)
        reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_32x32x16_f16(
            reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], reg_idx, 0, 0);
+#else
+        ignore = reg_a;
+        ignore = reg_b;
+        ignore = reg_c;
+        ignore = reg_idx;
+#endif
    }
};
...
@@ -61,8 +82,15 @@ struct intrin_smfmac_f32_32x32x16bf16<32, 32>
    __device__ static void Run(const bhalf4_t& reg_a,
                               const bhalf8_t& reg_b,
                               const int32_t& reg_idx,
                               FloatC& reg_c)
    {
+#if defined(__gfx94__)
        reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_32x32x16_bf16(
            reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], reg_idx, 0, 0);
+#else
+        ignore = reg_a;
+        ignore = reg_b;
+        ignore = reg_c;
+        ignore = reg_idx;
+#endif
    }
};
...
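These wrappers expose the gfx94x sparse MFMA ("smfmac") builtins, where reg_a holds the compressed 2:4-structured-sparse operand and reg_idx carries sparsity metadata; that reading of the index operand is an assumption for illustration, not something the diff states. A standalone host-side sketch of what packing 2-bit position metadata for a 2:4 pattern generally looks like (not the exact hardware encoding):

```cpp
#include <cstdint>
#include <cstdio>

int main()
{
    // Two groups of four values, each with at most 2 non-zeros (2:4 pattern).
    const float a[8] = {0.f, 1.5f, 0.f, -2.f, 3.f, 0.f, 0.25f, 0.f};

    float packed_vals[4] = {};
    uint32_t idx_bits    = 0; // 2 bits of position metadata per kept value
    int out              = 0;

    for(int group = 0; group < 2; ++group)
        for(int j = 0; j < 4; ++j)
        {
            const float v = a[group * 4 + j];
            if(v != 0.f)
            {
                packed_vals[out] = v;
                idx_bits |= static_cast<uint32_t>(j) << (2 * out);
                ++out;
            }
        }

    std::printf("kept %d values (first = %f), index metadata = 0x%02x\n",
                out, packed_vals[0], idx_bits);
    return 0;
}
```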
include/ck/utility/math_v2.hpp  View file @ 3552041a
...
@@ -839,7 +839,7 @@ inline __device__ T rcp(T x)
template <typename T>
inline __device__ T exp(T x)
{
-    return ck::type_convert<T>(__expf(ck::type_convert<float>(x)));
+    return ck::type_convert<T>(__ocml_exp_f32(ck::type_convert<float>(x)));
};

template <>
...
@@ -851,7 +851,7 @@ inline __device__ half_t exp<half_t>(half_t x)
template <>
inline __device__ float exp<float>(float x)
{
-    return __expf(x);
+    return __ocml_exp_f32(x);
};

template <>
...
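The shape of these wrappers (here and in the ck_tile numeric headers further down) is the same everywhere: widen to float, evaluate once, narrow back to the storage type. A minimal standalone sketch of that pattern, with the standard std::exp standing in for the device intrinsic __ocml_exp_f32:

```cpp
#include <cmath>
#include <cstdio>

// T stands in for half_t, bfloat16_t, fp8_t, etc.; exercised here with float/double only.
template <typename T>
T exp_via_float(T x)
{
    return static_cast<T>(std::exp(static_cast<float>(x)));
}

int main()
{
    std::printf("%f\n", exp_via_float(1.0f));
    std::printf("%f\n", exp_via_float(2.0));
    return 0;
}
```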
include/ck/utility/reduction_operator.hpp  View file @ 3552041a
// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once
...
@@ -52,11 +52,19 @@ struct Add
    __host__ __device__ inline constexpr void operator()(T& a, T b) const
    {
        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
-                          is_same<T, int32_t>::value,
+                          is_same<T, int32_t>::value || is_same<T, half_t>::value,
                      "The data type is not supported by the Add accumulator!");

        a = a + b;
    }

+    __host__ __device__ inline constexpr void operator()(bhalf_t& a, bhalf_t b) const
+    {
+        float a_ = type_convert<float>(a);
+        float b_ = type_convert<float>(b);
+        a        = type_convert<bhalf_t>(a_ + b_);
+    }
};

struct SquaredAdd
...
@@ -104,11 +112,19 @@ struct Mul
    __host__ __device__ inline constexpr void operator()(T& a, T b) const
    {
        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
-                          is_same<T, int32_t>::value,
+                          is_same<T, int32_t>::value || is_same<T, half_t>::value,
                      "The data type is not supported by the Mul accumulator!");

        a = a * b;
    }

+    __host__ __device__ inline constexpr void operator()(bhalf_t& a, bhalf_t b) const
+    {
+        float a_ = type_convert<float>(a);
+        float b_ = type_convert<float>(b);
+        a        = type_convert<bhalf_t>(a_ * b_);
+    }
};

struct Max
...
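The bhalf_t overloads widen both operands to float, perform the operation there, and round back once: bhalf_t is a storage type with very coarse precision, so the arithmetic has to happen in float. A minimal standalone sketch (truncation-based bf16 conversion, which is an assumption for illustration; the library's type_convert may round differently):

```cpp
#include <cstdint>
#include <cstdio>
#include <cstring>

// Toy bf16 <-> float helpers: keep/restore the upper 16 bits of the float encoding.
static uint16_t to_bf16(float f) { uint32_t u; std::memcpy(&u, &f, 4); return uint16_t(u >> 16); }
static float to_float(uint16_t b) { uint32_t u = uint32_t(b) << 16; float f; std::memcpy(&f, &u, 4); return f; }

// Shape of the Add overload above: widen, add in float, narrow the result once.
static uint16_t bf16_add(uint16_t a, uint16_t b) { return to_bf16(to_float(a) + to_float(b)); }

int main()
{
    const uint16_t a = to_bf16(256.0f);
    const uint16_t b = to_bf16(1.0f);
    // With only 8 significand bits, 257 is not representable in bf16; the sum narrows to 256.
    std::printf("bf16(256) + bf16(1) -> %f\n", to_float(bf16_add(a, b)));
    return 0;
}
```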
include/ck_tile/core/arch/amd_buffer_addressing.hpp  View file @ 3552041a
...
@@ -54,233 +54,318 @@ template<> struct buffer_load_trait<4, thread_buffer<bf16_t, 2>> { using payload...
} // namespace impl

// TODO: glc/slc/...
-template <index_t bytes>
+template <index_t bytes, bool pre_nop = false>
struct buffer_load;

#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wundefined-reinterpret-cast"
// TODO: the strict-aliasing rule seems to fail when reinterpret_cast-ing between vector types
// (exp_vector_type(xxx))
-template <>
-struct buffer_load<16>
+template <bool pre_nop>
+struct buffer_load<16, pre_nop>
{
    template <typename T>
    CK_TILE_DEVICE void operator()(T& value,
                                   int32x4_t res /*buffer resource*/,
                                   index_t v_offset,
-                                  index_t s_offset,
+                                  index_t /*s_offset*/,
                                   index_t i_offset /*max 0xFFF*/,
-                                  index_t /*flag*/ = 0)
+                                  index_t /*flag*/       = 0,
+                                  bool_constant<pre_nop> = {})
    {
        static_assert(sizeof(T) == 16);
        using mbuf_t = typename impl::buffer_load_trait<16, T>::payload_t;
-        asm volatile("buffer_load_dwordx4 %0, %1, %2, %3 offen offset:%4"
-                     : "+v"(reinterpret_cast<mbuf_t&>(value))
-                     : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset)
-                     : "memory");
+        if constexpr(pre_nop)
+            asm volatile("s_nop 4\n"
+                         "buffer_load_dwordx4 %0, %1, %2, 0 offen offset:%3"
+                         : "+v"(reinterpret_cast<mbuf_t&>(value))
+                         : "v"(v_offset), "s"(res), "n"(i_offset)
+                         : "memory");
+        else
+            asm volatile("buffer_load_dwordx4 %0, %1, %2, 0 offen offset:%3"
+                         : "+v"(reinterpret_cast<mbuf_t&>(value))
+                         : "v"(v_offset), "s"(res), "n"(i_offset)
+                         : "memory");
    }
};

-template <>
-struct buffer_load<8>
+template <bool pre_nop>
+struct buffer_load<8, pre_nop>
{
    template <typename T>
    CK_TILE_DEVICE void operator()(T& value,
                                   int32x4_t res /*buffer resource*/,
                                   index_t v_offset,
-                                  index_t s_offset,
+                                  index_t /*s_offset*/,
                                   index_t i_offset /*max 0xFFF*/,
-                                  index_t /*flag*/ = 0)
+                                  index_t /*flag*/       = 0,
+                                  bool_constant<pre_nop> = {})
    {
        static_assert(sizeof(T) == 8);
        using mbuf_t = typename impl::buffer_load_trait<8, T>::payload_t;
-        asm volatile("buffer_load_dwordx2 %0, %1, %2, %3 offen offset:%4"
-                     : "+v"(reinterpret_cast<mbuf_t&>(value))
-                     : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset)
-                     : "memory");
+        if constexpr(pre_nop)
+            asm volatile("s_nop 4\n"
+                         "buffer_load_dwordx2 %0, %1, %2, 0 offen offset:%3"
+                         : "+v"(reinterpret_cast<mbuf_t&>(value))
+                         : "v"(v_offset), "s"(res), "n"(i_offset)
+                         : "memory");
+        else
+            asm volatile("buffer_load_dwordx2 %0, %1, %2, 0 offen offset:%3"
+                         : "+v"(reinterpret_cast<mbuf_t&>(value))
+                         : "v"(v_offset), "s"(res), "n"(i_offset)
+                         : "memory");
    }
};

-template <>
-struct buffer_load<4>
+template <bool pre_nop>
+struct buffer_load<4, pre_nop>
{
    template <typename T>
    CK_TILE_DEVICE void operator()(T& value,
                                   int32x4_t res /*buffer resource*/,
                                   index_t v_offset,
-                                  index_t s_offset,
+                                  index_t /*s_offset*/,
                                   index_t i_offset /*max 0xFFF*/,
-                                  index_t /*flag*/ = 0)
+                                  index_t /*flag*/       = 0,
+                                  bool_constant<pre_nop> = {})
    {
        static_assert(sizeof(T) == 4);
        using mbuf_t = typename impl::buffer_load_trait<4, T>::payload_t;
-        asm volatile("buffer_load_dword %0, %1, %2, %3 offen offset:%4"
-                     : "+v"(reinterpret_cast<mbuf_t&>(value))
-                     : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset)
-                     : "memory");
+        if constexpr(pre_nop)
+            asm volatile("s_nop 4\n"
+                         "buffer_load_dword %0, %1, %2, 0 offen offset:%3"
+                         : "+v"(reinterpret_cast<mbuf_t&>(value))
+                         : "v"(v_offset), "s"(res), "n"(i_offset)
+                         : "memory");
+        else
+            asm volatile("buffer_load_dword %0, %1, %2, 0 offen offset:%3"
+                         : "+v"(reinterpret_cast<mbuf_t&>(value))
+                         : "v"(v_offset), "s"(res), "n"(i_offset)
+                         : "memory");
    }
};

-template <>
-struct buffer_load<2>
+template <bool pre_nop>
+struct buffer_load<2, pre_nop>
{
    template <typename T>
    CK_TILE_DEVICE void operator()(T& value,
                                   int32x4_t res /*buffer resource*/,
                                   index_t v_offset,
-                                  index_t s_offset,
+                                  index_t /*s_offset*/,
                                   index_t i_offset /*max 0xFFF*/,
-                                  index_t /*flag*/ = 0)
+                                  index_t /*flag*/       = 0,
+                                  bool_constant<pre_nop> = {})
    {
        static_assert(sizeof(T) == 4);
        // subdword is buggy, use dword buf and convert manually
        using mbuf_t = typename impl::buffer_load_trait<2, T>::payload_t;
-        asm volatile("buffer_load_ushort %0, %1, %2, %3 offen offset:%4"
-                     : "+v"(reinterpret_cast<mbuf_t&>(value))
-                     : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset)
-                     : "memory");
+        if constexpr(pre_nop)
+            asm volatile("s_nop 4\n"
+                         "buffer_load_ushort %0, %1, %2, 0 offen offset:%3"
+                         : "+v"(reinterpret_cast<mbuf_t&>(value))
+                         : "v"(v_offset), "s"(res), "n"(i_offset)
+                         : "memory");
+        else
+            asm volatile("buffer_load_ushort %0, %1, %2, 0 offen offset:%3"
+                         : "+v"(reinterpret_cast<mbuf_t&>(value))
+                         : "v"(v_offset), "s"(res), "n"(i_offset)
+                         : "memory");
    }
};

-template <>
-struct buffer_load<1>
+template <bool pre_nop>
+struct buffer_load<1, pre_nop>
{
    template <typename T>
    CK_TILE_DEVICE void operator()(T& value,
                                   int32x4_t res /*buffer resource*/,
                                   index_t v_offset,
-                                  index_t s_offset,
+                                  index_t /*s_offset*/,
                                   index_t i_offset /*max 0xFFF*/,
-                                  index_t /*flag*/ = 0)
+                                  index_t /*flag*/       = 0,
+                                  bool_constant<pre_nop> = {})
    {
        static_assert(sizeof(T) == 4);
        using mbuf_t = typename impl::buffer_load_trait<1, T>::payload_t;
-        asm volatile("buffer_load_ubyte %0, %1, %2, %3 offen offset:%4"
-                     : "+v"(reinterpret_cast<mbuf_t&>(value))
-                     : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset)
-                     : "memory");
+        if constexpr(pre_nop)
+            asm volatile("s_nop 4\n"
+                         "buffer_load_ubyte %0, %1, %2, 0 offen offset:%3"
+                         : "+v"(reinterpret_cast<mbuf_t&>(value))
+                         : "v"(v_offset), "s"(res), "n"(i_offset)
+                         : "memory");
+        else
+            asm volatile("buffer_load_ubyte %0, %1, %2, 0 offen offset:%3"
+                         : "+v"(reinterpret_cast<mbuf_t&>(value))
+                         : "v"(v_offset), "s"(res), "n"(i_offset)
+                         : "memory");
    }
};
-template <index_t bytes>
+template <index_t bytes, bool pre_nop = false>
struct buffer_load_if;

-template <>
-struct buffer_load_if<16>
+template <bool pre_nop>
+struct buffer_load_if<16, pre_nop>
{
    template <typename T>
    CK_TILE_DEVICE void operator()(T& value,
                                   int32x4_t res /*buffer resource*/,
                                   index_t v_offset,
-                                  index_t s_offset,
+                                  index_t /*s_offset*/,
                                   index_t i_offset /*max 0xFFF*/,
-                                  index_t flag = 0)
+                                  index_t flag           = 0,
+                                  bool_constant<pre_nop> = {})
    {
        static_assert(sizeof(T) == 16);
        auto saved_exec = __builtin_amdgcn_read_exec();
        using mbuf_t    = typename impl::buffer_load_trait<16, T>::payload_t;
        static_assert(sizeof(mbuf_t) == sizeof(T));
-        asm volatile("v_cmpx_le_u32 exec, 1, %5\n"
-                     "buffer_load_dwordx4 %0, %1, %2, %3 offen offset:%4\n"
-                     "s_mov_b64 exec %6"
-                     : "+v"(reinterpret_cast<mbuf_t&>(value))
-                     : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset), "v"(flag), "s"(saved_exec)
-                     : "memory");
+        if constexpr(pre_nop)
+            asm volatile("s_nop 4\n"
+                         "v_cmpx_le_u32 exec, 1, %4\n"
+                         "buffer_load_dwordx4 %0, %1, %2, 0 offen offset:%3\n"
+                         "s_mov_b64 exec %5"
+                         : "+v"(reinterpret_cast<mbuf_t&>(value))
+                         : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec)
+                         : "memory");
+        else
+            asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
+                         "buffer_load_dwordx4 %0, %1, %2, 0 offen offset:%3\n"
+                         "s_mov_b64 exec %5"
+                         : "+v"(reinterpret_cast<mbuf_t&>(value))
+                         : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec)
+                         : "memory");
    }
};

-template <>
-struct buffer_load_if<8>
+template <bool pre_nop>
+struct buffer_load_if<8, pre_nop>
{
    template <typename T>
    CK_TILE_DEVICE void operator()(T& value,
                                   int32x4_t res /*buffer resource*/,
                                   index_t v_offset,
-                                  index_t s_offset,
+                                  index_t /*s_offset*/,
                                   index_t i_offset /*max 0xFFF*/,
-                                  index_t flag = 0)
+                                  index_t flag           = 0,
+                                  bool_constant<pre_nop> = {})
    {
        static_assert(sizeof(T) == 8);
        auto saved_exec = __builtin_amdgcn_read_exec();
        using mbuf_t    = typename impl::buffer_load_trait<8, T>::payload_t;
-        asm volatile("v_cmpx_le_u32 exec, 1, %5\n"
-                     "buffer_load_dwordx2 %0, %1, %2, %3 offen offset:%4\n"
-                     "s_mov_b64 exec %6"
-                     : "+v"(reinterpret_cast<mbuf_t&>(value))
-                     : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset), "v"(flag), "s"(saved_exec)
-                     : "memory");
+        if constexpr(pre_nop)
+            asm volatile("s_nop 4\n"
+                         "v_cmpx_le_u32 exec, 1, %4\n"
+                         "buffer_load_dwordx2 %0, %1, %2, 0 offen offset:%3\n"
+                         "s_mov_b64 exec %5"
+                         : "+v"(reinterpret_cast<mbuf_t&>(value))
+                         : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec)
+                         : "memory");
+        else
+            asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
+                         "buffer_load_dwordx2 %0, %1, %2, 0 offen offset:%3\n"
+                         "s_mov_b64 exec %5"
+                         : "+v"(reinterpret_cast<mbuf_t&>(value))
+                         : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec)
+                         : "memory");
    }
};

-template <>
-struct buffer_load_if<4>
+template <bool pre_nop>
+struct buffer_load_if<4, pre_nop>
{
    template <typename T>
    CK_TILE_DEVICE void operator()(T& value,
                                   int32x4_t res /*buffer resource*/,
                                   index_t v_offset,
-                                  index_t s_offset,
+                                  index_t /*s_offset*/,
                                   index_t i_offset /*max 0xFFF*/,
-                                  index_t flag = 0)
+                                  index_t flag           = 0,
+                                  bool_constant<pre_nop> = {})
    {
        static_assert(sizeof(T) == 4);
        auto saved_exec = __builtin_amdgcn_read_exec();
        using mbuf_t    = typename impl::buffer_load_trait<4, T>::payload_t;
-        asm volatile("v_cmpx_le_u32 exec, 1, %5\n"
-                     "buffer_load_dword %0, %1, %2, %3 offen offset:%4\n"
-                     "s_mov_b64 exec %6"
-                     : "+v"(reinterpret_cast<mbuf_t&>(value))
-                     : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset), "v"(flag), "s"(saved_exec)
-                     : "memory");
+        if constexpr(pre_nop)
+            asm volatile("s_nop 4\n"
+                         "v_cmpx_le_u32 exec, 1, %4\n"
+                         "buffer_load_dword %0, %1, %2, 0 offen offset:%3\n"
+                         "s_mov_b64 exec %5"
+                         : "+v"(reinterpret_cast<mbuf_t&>(value))
+                         : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec)
+                         : "memory");
+        else
+            asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
+                         "buffer_load_dword %0, %1, %2, 0 offen offset:%3\n"
+                         "s_mov_b64 exec %5"
+                         : "+v"(reinterpret_cast<mbuf_t&>(value))
+                         : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec)
+                         : "memory");
    }
};

-template <>
-struct buffer_load_if<2>
+template <bool pre_nop>
+struct buffer_load_if<2, pre_nop>
{
    template <typename T>
    CK_TILE_DEVICE void operator()(T& value,
                                   int32x4_t res /*buffer resource*/,
                                   index_t v_offset,
-                                  index_t s_offset,
+                                  index_t /*s_offset*/,
                                   index_t i_offset /*max 0xFFF*/,
-                                  index_t flag = 0)
+                                  index_t flag           = 0,
+                                  bool_constant<pre_nop> = {})
    {
        static_assert(sizeof(T) == 4);
        auto saved_exec = __builtin_amdgcn_read_exec();
        using mbuf_t    = typename impl::buffer_load_trait<2, T>::payload_t;
-        asm volatile("v_cmpx_le_u32 exec, 1, %5\n"
-                     "buffer_load_ushort %0, %1, %2, %3 offen offset:%4\n"
-                     "s_mov_b64 exec %6"
-                     : "+v"(reinterpret_cast<mbuf_t&>(value))
-                     : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset), "v"(flag), "s"(saved_exec)
-                     : "memory");
+        if constexpr(pre_nop)
+            asm volatile("s_nop 4\n"
+                         "v_cmpx_le_u32 exec, 1, %4\n"
+                         "buffer_load_ushort %0, %1, %2, 0 offen offset:%3\n"
+                         "s_mov_b64 exec %5"
+                         : "+v"(reinterpret_cast<mbuf_t&>(value))
+                         : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec)
+                         : "memory");
+        else
+            asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
+                         "buffer_load_ushort %0, %1, %2, 0 offen offset:%3\n"
+                         "s_mov_b64 exec %5"
+                         : "+v"(reinterpret_cast<mbuf_t&>(value))
+                         : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec)
+                         : "memory");
    }
};

-template <>
-struct buffer_load_if<1>
+template <bool pre_nop>
+struct buffer_load_if<1, pre_nop>
{
    template <typename T>
    CK_TILE_DEVICE void operator()(T& value,
                                   int32x4_t res /*buffer resource*/,
                                   index_t v_offset,
-                                  index_t s_offset,
+                                  index_t /*s_offset*/,
                                   index_t i_offset /*max 0xFFF*/,
-                                  index_t flag = 0)
+                                  index_t flag           = 0,
+                                  bool_constant<pre_nop> = {})
    {
        static_assert(sizeof(T) == 4);
        auto saved_exec = __builtin_amdgcn_read_exec();
        using mbuf_t    = typename impl::buffer_load_trait<1, T>::payload_t;
-        asm volatile("v_cmpx_le_u32 exec, 1, %5\n"
-                     "buffer_load_ubyte %0, %1, %2, %3 offen offset:%4\n"
-                     "s_mov_b64 exec %6"
-                     : "+v"(reinterpret_cast<mbuf_t&>(value))
-                     : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset), "v"(flag), "s"(saved_exec)
-                     : "memory");
+        if constexpr(pre_nop)
+            asm volatile("s_nop 4\n"
+                         "v_cmpx_le_u32 exec, 1, %4\n"
+                         "buffer_load_ubyte %0, %1, %2, 0 offen offset:%3\n"
+                         "s_mov_b64 exec %5"
+                         : "+v"(reinterpret_cast<mbuf_t&>(value))
+                         : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec)
+                         : "memory");
+        else
+            asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
+                         "buffer_load_ubyte %0, %1, %2, 0 offen offset:%3\n"
+                         "s_mov_b64 exec %5"
+                         : "+v"(reinterpret_cast<mbuf_t&>(value))
+                         : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec)
+                         : "memory");
    }
};
#pragma clang diagnostic pop // "-Wundefined-reinterpret-cast"
...
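The pre_nop switch is resolved entirely at compile time via the trailing bool_constant tag and if constexpr; no branch survives into the generated code. A minimal standalone sketch of that dispatch pattern (plain C++, no GPU intrinsics; bool_constant is modeled with std::integral_constant):

```cpp
#include <cstdio>
#include <type_traits>

template <bool B>
using bool_constant = std::integral_constant<bool, B>;

template <bool pre_nop = false>
void load(int& value, bool_constant<pre_nop> = {})
{
    if constexpr(pre_nop)
        std::printf("emit nop-style preamble, then load\n");
    else
        std::printf("plain load\n");
    value = 42; // the "load"
}

int main()
{
    int v;
    load(v);                        // default: no preamble
    load(v, bool_constant<true>{}); // caller opts in at compile time
    return 0;
}
```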
@@ -294,17 +379,16 @@ struct buffer_store<16>
    CK_TILE_DEVICE void operator()(const T& value,
                                   int32x4_t res /*buffer resource*/,
                                   index_t v_offset,
-                                  index_t s_offset,
+                                  index_t /*s_offset*/,
                                   index_t i_offset /*max 0xFFF*/,
                                   index_t /*flag*/ = 1)
    {
        static_assert(sizeof(T) == 16);
        using mbuf_t = fp32x4_t;
-        asm volatile("buffer_store_dwordx4 %0, %1, %2, %3 offen offset:%4"
-                     :
-                     : "v"(bit_cast<mbuf_t>(value)), "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset)
-                     : "memory");
+        asm volatile("buffer_store_dwordx4 %0, %1, %2, 0 offen offset:%3"
+                     :
+                     : "v"(bit_cast<mbuf_t>(value)), "v"(v_offset), "s"(res), "n"(i_offset)
+                     : "memory");
    }
};
...
@@ -315,17 +399,16 @@ struct buffer_store<8>
    CK_TILE_DEVICE void operator()(const T& value,
                                   int32x4_t res /*buffer resource*/,
                                   index_t v_offset,
-                                  index_t s_offset,
+                                  index_t /*s_offset*/,
                                   index_t i_offset /*max 0xFFF*/,
                                   index_t /*flag*/ = 1)
    {
        static_assert(sizeof(T) == 8);
        using mbuf_t = fp32x2_t;
-        asm volatile("buffer_store_dwordx2 %0, %1, %2, %3 offen offset:%4"
-                     :
-                     : "v"(bit_cast<mbuf_t>(value)), "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset)
-                     : "memory");
+        asm volatile("buffer_store_dwordx2 %0, %1, %2, 0 offen offset:%3"
+                     :
+                     : "v"(bit_cast<mbuf_t>(value)), "v"(v_offset), "s"(res), "n"(i_offset)
+                     : "memory");
    }
};
...
@@ -336,17 +419,16 @@ struct buffer_store<4>
    CK_TILE_DEVICE void operator()(const T& value,
                                   int32x4_t res /*buffer resource*/,
                                   index_t v_offset,
-                                  index_t s_offset,
+                                  index_t /*s_offset*/,
                                   index_t i_offset /*max 0xFFF*/,
                                   index_t /*flag*/ = 1)
    {
        static_assert(sizeof(T) == 4);
        using mbuf_t = float;
-        asm volatile("buffer_store_dword %0, %1, %2, %3 offen offset:%4"
-                     :
-                     : "v"(bit_cast<mbuf_t>(value)), "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset)
-                     : "memory");
+        asm volatile("buffer_store_dword %0, %1, %2, 0 offen offset:%3"
+                     :
+                     : "v"(bit_cast<mbuf_t>(value)), "v"(v_offset), "s"(res), "n"(i_offset)
+                     : "memory");
    }
};
...
@@ -357,17 +439,16 @@ struct buffer_store<2>
    CK_TILE_DEVICE void operator()(const T& value,
                                   int32x4_t res /*buffer resource*/,
                                   index_t v_offset,
-                                  index_t s_offset,
+                                  index_t /*s_offset*/,
                                   index_t i_offset /*max 0xFFF*/,
                                   index_t /*flag*/ = 1)
    {
        static_assert(sizeof(T) == 2);
        using mbuf_t = short;
-        asm volatile("buffer_store_short %0, %1, %2, %3 offen offset:%4"
-                     :
-                     : "v"(bit_cast<mbuf_t>(value)), "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset)
-                     : "memory");
+        asm volatile("buffer_store_short %0, %1, %2, 0 offen offset:%3"
+                     :
+                     : "v"(bit_cast<mbuf_t>(value)), "v"(v_offset), "s"(res), "n"(i_offset)
+                     : "memory");
    }
};
...
@@ -378,17 +459,16 @@ struct buffer_store<1>
    CK_TILE_DEVICE void operator()(const T& value,
                                   int32x4_t res /*buffer resource*/,
                                   index_t v_offset,
-                                  index_t s_offset,
+                                  index_t /*s_offset*/,
                                   index_t i_offset /*max 0xFFF*/,
                                   index_t /*flag*/ = 1)
    {
        static_assert(sizeof(T) == 4);
        using mbuf_t = float;
-        asm volatile("buffer_store_byte %0, %1, %2, %3 offen offset:%4"
-                     :
-                     : "v"(bit_cast<mbuf_t>(value)), "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset)
-                     : "memory");
+        asm volatile("buffer_store_byte %0, %1, %2, 0 offen offset:%3"
+                     :
+                     : "v"(bit_cast<mbuf_t>(value)), "v"(v_offset), "s"(res), "n"(i_offset)
+                     : "memory");
    }
};
...
@@ -402,21 +482,20 @@ struct buffer_store_if<16>
    CK_TILE_DEVICE void operator()(const T& value,
                                   int32x4_t res /*buffer resource*/,
                                   index_t v_offset,
-                                  index_t s_offset,
+                                  index_t /*s_offset*/,
                                   index_t i_offset /*max 0xFFF*/,
                                   index_t flag = 1)
    {
        static_assert(sizeof(T) == 16);
        auto save_exec = __builtin_amdgcn_read_exec();
        using mbuf_t   = fp32x4_t;
-        asm volatile("v_cmpx_le_u32 exec, 1, %5\n"
-                     "buffer_store_dwordx4 %0, %1, %2, %3 offen offset:%4\n"
-                     "s_mov_b64 exec %6"
+        asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
+                     "buffer_store_dwordx4 %0, %1, %2, 0 offen offset:%3\n"
+                     "s_mov_b64 exec %5"
                     :
                     : "v"(bit_cast<mbuf_t>(value)), "v"(v_offset), "s"(res),
                       "n"(i_offset), "v"(flag), "s"(save_exec)
...
@@ -431,7 +510,7 @@ struct buffer_store_if<8>
    CK_TILE_DEVICE void operator()(const T& value,
                                   int32x4_t res /*buffer resource*/,
                                   index_t v_offset,
-                                  index_t s_offset,
+                                  index_t /*s_offset*/,
                                   index_t i_offset /*max 0xFFF*/,
                                   index_t flag = 1)
    {
...
@@ -439,14 +518,13 @@ struct buffer_store_if<8>
        auto save_exec = __builtin_amdgcn_read_exec();
        // TODO: ugly. rocm-6.0/6.1 seems to need a bit_cast to the same base type to avoid scratch
        using mbuf_t = ext_vector_t<typename T::value_type, T::size()>;
-        asm volatile("v_cmpx_le_u32 exec, 1, %5\n"
-                     "buffer_store_dwordx2 %0, %1, %2, %3 offen offset:%4\n"
-                     "s_mov_b64 exec %6"
+        asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
+                     "buffer_store_dwordx2 %0, %1, %2, 0 offen offset:%3\n"
+                     "s_mov_b64 exec %5"
                     :
                     : "v"(bit_cast<mbuf_t>(value)), "v"(v_offset), "s"(res),
                       "n"(i_offset), "v"(flag), "s"(save_exec)
...
@@ -461,21 +539,20 @@ struct buffer_store_if<4>
    CK_TILE_DEVICE void operator()(const T& value,
                                   int32x4_t res /*buffer resource*/,
                                   index_t v_offset,
-                                  index_t s_offset,
+                                  index_t /*s_offset*/,
                                   index_t i_offset /*max 0xFFF*/,
                                   index_t flag = 1)
    {
        static_assert(sizeof(T) == 4);
        auto save_exec = __builtin_amdgcn_read_exec();
        using mbuf_t   = float;
-        asm volatile("v_cmpx_le_u32 exec, 1, %5\n"
-                     "buffer_store_dword %0, %1, %2, %3 offen offset:%4\n"
-                     "s_mov_b64 exec %6"
+        asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
+                     "buffer_store_dword %0, %1, %2, 0 offen offset:%3\n"
+                     "s_mov_b64 exec %5"
                     :
                     : "v"(bit_cast<mbuf_t>(value)), "v"(v_offset), "s"(res),
                       "n"(i_offset), "v"(flag), "s"(save_exec)
...
@@ -490,21 +567,20 @@ struct buffer_store_if<2>
    CK_TILE_DEVICE void operator()(const T& value,
                                   int32x4_t res /*buffer resource*/,
                                   index_t v_offset,
-                                  index_t s_offset,
+                                  index_t /*s_offset*/,
                                   index_t i_offset /*max 0xFFF*/,
                                   index_t flag = 1)
    {
        static_assert(sizeof(T) == 2);
        auto save_exec = __builtin_amdgcn_read_exec();
        using mbuf_t   = short;
-        asm volatile("v_cmpx_le_u32 exec, 1, %5\n"
-                     "buffer_store_short %0, %1, %2, %3 offen offset:%4\n"
-                     "s_mov_b64 exec %6"
+        asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
+                     "buffer_store_short %0, %1, %2, 0 offen offset:%3\n"
+                     "s_mov_b64 exec %5"
                     :
                     : "v"(bit_cast<mbuf_t>(value)), "v"(v_offset), "s"(res),
                       "n"(i_offset), "v"(flag), "s"(save_exec)
...
@@ -519,21 +595,20 @@ struct buffer_store_if<1>
    CK_TILE_DEVICE void operator()(const T& value,
                                   int32x4_t res /*buffer resource*/,
                                   index_t v_offset,
-                                  index_t s_offset,
+                                  index_t /*s_offset*/,
                                   index_t i_offset /*max 0xFFF*/,
                                   index_t flag = 1)
    {
        static_assert(sizeof(T) == 4);
        auto save_exec = __builtin_amdgcn_read_exec();
        using mbuf_t   = float;
-        asm volatile("v_cmpx_le_u32 exec, 1, %5\n"
-                     "buffer_store_byte %0, %1, %2, %3 offen offset:%4\n"
-                     "s_mov_b64 exec %6"
+        asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
+                     "buffer_store_byte %0, %1, %2, 0 offen offset:%3\n"
+                     "s_mov_b64 exec %5"
                     :
                     : "v"(bit_cast<mbuf_t>(value)), "v"(v_offset), "s"(res),
                       "n"(i_offset), "v"(flag), "s"(save_exec)
...
@@ -901,17 +976,26 @@ llvm_amdgcn_raw_buffer_atomic_max_fp64(double vdata,
                                       int soffset, // dst_wave_addr_offset
                                       int glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fmax.f64");

-CK_TILE_DEVICE void async_buffer_load_dword(void* smem,
-                                            int32x4_t rsrc,
-                                            index_t voffset,
-                                            index_t soffset,
-                                            index_t ioffset /*max 0xFFF*/,
-                                            index_t /*flag*/ = 0)
+template <bool pre_nop = false>
+CK_TILE_DEVICE void async_buffer_load_dword_v(void* smem,
+                                              int32x4_t rsrc,
+                                              index_t voffset,
+                                              index_t /*soffset*/,
+                                              index_t ioffset /*max 0xFFF*/,
+                                              index_t /*flag*/       = 0,
+                                              bool_constant<pre_nop> = {})
{
-    asm volatile("buffer_load_dword %1, %2, %3 offen offset:%4 lds"
-                 : "=r"(smem) /*dummy dependency for smem*/
-                 : "v"(voffset), "s"(rsrc), "s"(soffset), "n"(ioffset)
-                 : "memory");
+    if constexpr(pre_nop)
+        asm volatile("s_nop 4\n"
+                     "buffer_load_dword %1, %2, 0 offen offset:%3 lds"
+                     : "=r"(smem) /*dummy dependency for smem*/
+                     : "v"(voffset), "s"(rsrc), "n"(ioffset)
+                     : "memory");
+    else
+        asm volatile("buffer_load_dword %1, %2, 0 offen offset:%3 lds"
+                     : "=r"(smem) /*dummy dependency for smem*/
+                     : "v"(voffset), "s"(rsrc), "n"(ioffset)
+                     : "memory");
}

CK_TILE_DEVICE void async_buffer_load_fence(index_t cnt = 0)
...
@@ -1223,12 +1307,14 @@ CK_TILE_DEVICE thread_buffer<T, N> amd_buffer_load_impl(int32x4_t src_wave_buffe...
template <typename T,
          index_t N,
          amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
-          bool oob_conditional_check          = true>
+          bool oob_conditional_check          = true,
+          bool pre_nop                        = false>
CK_TILE_DEVICE void amd_buffer_load_raw_impl(thread_buffer<T, N>& dst,
                                             int32x4_t src_wave_buffer_resource,
                                             index_t src_thread_addr_offset,
                                             index_t src_wave_addr_offset,
-                                            index_t flag = 0)
+                                            index_t flag           = 0,
+                                            bool_constant<pre_nop> = {})
{
    constexpr index_t bytes = sizeof(T) * N;
    static_assert(bytes == 1 || bytes == 2 || bytes == 4 || bytes == 8 || bytes == 16,
...
@@ -1237,32 +1323,46 @@ CK_TILE_DEVICE void amd_buffer_load_raw_impl(thread_buffer<T, N>& dst,
    using type = thread_buffer<T, N>;
    if constexpr(oob_conditional_check)
    {
-        buffer_load_if<sizeof(type)>{}(dst,
-                                       src_wave_buffer_resource,
-                                       src_thread_addr_offset,
-                                       src_wave_addr_offset,
-                                       0,
-                                       flag);
+        buffer_load_if<sizeof(type), pre_nop>{}(dst,
+                                                src_wave_buffer_resource,
+                                                src_thread_addr_offset,
+                                                src_wave_addr_offset,
+                                                0,
+                                                flag,
+                                                bool_constant<pre_nop>{});
    }
    else
    {
-        buffer_load<sizeof(type)>{}(dst, src_wave_buffer_resource, src_thread_addr_offset,
-                                    src_wave_addr_offset, 0, flag);
+        buffer_load<sizeof(type), pre_nop>{}(dst, src_wave_buffer_resource, src_thread_addr_offset,
+                                             src_wave_addr_offset, 0, flag, bool_constant<pre_nop>{});
    }
}

template <typename T,
          index_t N,
-          amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default>
+          amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
+          bool pre_nop                        = false>
CK_TILE_DEVICE void amd_async_buffer_load_impl(T* smem,
                                               int32x4_t src_wave_buffer_resource,
                                               index_t src_thread_addr_offset,
                                               index_t src_wave_addr_offset,
-                                               index_t src_immediate_addr_offset = 0)
+                                               index_t src_immediate_addr_offset = 0,
+                                               bool_constant<pre_nop>            = {})
{
    static_assert(sizeof(T) * N == 4, "wrong! not implemented vector size");

-    async_buffer_load_dword(smem, src_wave_buffer_resource, src_thread_addr_offset,
-                            src_wave_addr_offset, src_immediate_addr_offset);
+    async_buffer_load_dword_v(smem, src_wave_buffer_resource, src_thread_addr_offset,
+                              src_wave_addr_offset, src_immediate_addr_offset, 0,
+                              bool_constant<pre_nop>{});
}

template <index_t N,
...
@@ -1909,20 +2009,50 @@ amd_buffer_load_invalid_element_return_customized_value(const T* p_src_wave,
template <typename T,
          index_t N,
          amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
-          bool oob_conditional_check          = true>
+          bool oob_conditional_check          = true,
+          bool pre_nop                        = false>
CK_TILE_DEVICE void amd_buffer_load_raw(thread_buffer<T, N>& dst,
                                        const T* p_src_wave,
                                        index_t src_thread_element_offset,
                                        index_t src_element_space_size,
-                                       index_t is_valid_element = 0)
+                                       index_t is_valid_element = 0,
+                                       bool_constant<pre_nop>   = {})
{
    const int32x4_t src_wave_buffer_resource =
        make_wave_buffer_resource(p_src_wave, src_element_space_size * sizeof(T));

    index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);

-    amd_buffer_load_raw_impl<T, N, coherence, oob_conditional_check>(
-        dst, src_wave_buffer_resource, src_thread_addr_offset, 0, is_valid_element);
+    amd_buffer_load_raw_impl<T, N, coherence, oob_conditional_check, pre_nop>(
+        dst, src_wave_buffer_resource, src_thread_addr_offset, 0, is_valid_element,
+        bool_constant<pre_nop>{});
}

+// This version supports a buffer resource as input argument
+template <typename T,
+          index_t N,
+          amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
+          bool oob_conditional_check          = true,
+          bool pre_nop                        = false>
+CK_TILE_DEVICE void amd_buffer_load_raw(thread_buffer<T, N>& dst,
+                                        const int32x4_t src_wave_buffer_resource,
+                                        index_t src_thread_element_offset,
+                                        index_t is_valid_element = 0,
+                                        bool_constant<pre_nop>   = {})
+{
+    index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
+
+    amd_buffer_load_raw_impl<T, N, coherence, oob_conditional_check, pre_nop>(
+        dst, src_wave_buffer_resource, src_thread_addr_offset, 0, is_valid_element,
+        bool_constant<pre_nop>{});
+}

// unfortunately async copy cannot guarantee that invalid data is zero inside LDS
...
@@ -1931,11 +2061,13 @@ CK_TILE_DEVICE void amd_buffer_load_raw(thread_buffer<T, N>& dst,
// buffer_load OOB still working.
template <typename T,
          index_t N,
-          amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default>
-CK_TILE_DEVICE void amd_async_buffer_load_with_oob(T* smem,
-                                                   const T* p_src_wave,
-                                                   index_t src_thread_element_offset,
-                                                   index_t src_element_space_size)
+          amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
+          bool pre_nop                        = false>
+CK_TILE_DEVICE void amd_async_buffer_load_with_oob_raw(T* smem,
+                                                       const T* p_src_wave,
+                                                       index_t src_thread_element_offset,
+                                                       index_t src_element_space_size,
+                                                       bool_constant<pre_nop> = {})
{
    const int32x4_t src_wave_buffer_resource =
        make_wave_buffer_resource(p_src_wave, src_element_space_size * sizeof(T));
...
@@ -1943,7 +2075,23 @@ CK_TILE_DEVICE void amd_async_buffer_load_with_oob(T* smem,
    index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);

    amd_async_buffer_load_impl<T, N, coherence>(
-        smem, src_wave_buffer_resource, src_thread_addr_offset, 0, 0);
+        smem, src_wave_buffer_resource, src_thread_addr_offset, 0, 0, bool_constant<pre_nop>{});
}

+// This version supports a buffer resource as input argument
+template <typename T,
+          index_t N,
+          amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
+          bool pre_nop                        = false>
+CK_TILE_DEVICE void amd_async_buffer_load_with_oob_raw(T* smem,
+                                                       const int32x4_t src_wave_buffer_resource,
+                                                       index_t src_thread_element_offset,
+                                                       bool_constant<pre_nop> = {})
+{
+    index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
+
+    amd_async_buffer_load_impl<T, N, coherence>(
+        smem, src_wave_buffer_resource, src_thread_addr_offset, 0, 0, bool_constant<pre_nop>{});
+}

// buffer_store requires:
...
include/ck_tile/core/arch/arch.hpp  View file @ 3552041a
...
@@ -82,14 +82,12 @@ CK_TILE_DEVICE void block_sync_lds_direct_load()
    " ::);
}

-CK_TILE_DEVICE void s_nop()
+CK_TILE_DEVICE void s_nop(index_t cnt = 0)
{
#if 1
-    asm volatile("\
-    s_nop 0 \n \
-    " ::);
+    asm volatile("s_nop %0" : : "n"(cnt) :);
#else
-    __builtin_amdgcn_sched_barrier(0);
+    __builtin_amdgcn_sched_barrier(cnt);
#endif
}
...
include/ck_tile/core/config.hpp  View file @ 3552041a
...
@@ -21,6 +21,7 @@
#define __gfx12__
#endif

+#include "hip/hip_version.h"
#ifndef CK_TILE_DONT_USE_HIP_RUNTIME_HEADERS
#include "hip/hip_runtime.h"
#include "hip/hip_fp16.h"
...
@@ -147,6 +148,14 @@
#define CK_TILE_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE 1
#endif

+#ifndef CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE
+#if HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR == 1 && HIP_VERSION_PATCH >= 40091
+#define CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE 1
+#else
+#define CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE 0
+#endif
+#endif
+
#ifndef CK_TILE_DEBUG_LOG
#define CK_TILE_DEBUG_LOG 0
#endif
...
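The guard auto-detects the affected HIP release but can still be forced either way with a -D flag on the compile line, since the outer #ifndef only fires when the macro is not already defined. A minimal standalone sketch of the same version-gated, overridable pattern (macro and version values below are stand-ins, not the real configuration):

```cpp
#include <cstdio>

// Stand-in values; in the real header these come from hip/hip_version.h.
#ifndef HIP_VERSION_MAJOR
#define HIP_VERSION_MAJOR 6
#define HIP_VERSION_MINOR 1
#define HIP_VERSION_PATCH 40093
#endif

// Same shape as the guard above: auto-detect, but honor -DMY_WORKAROUND_FLAG=0/1.
#ifndef MY_WORKAROUND_FLAG
#if HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR == 1 && HIP_VERSION_PATCH >= 40091
#define MY_WORKAROUND_FLAG 1
#else
#define MY_WORKAROUND_FLAG 0
#endif
#endif

int main()
{
    std::printf("workaround enabled: %d\n", MY_WORKAROUND_FLAG);
    return 0;
}
```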
include/ck_tile/core/numeric/bfloat16.hpp  View file @ 3552041a
...
@@ -331,7 +331,10 @@ bfloat16_t sqrt(bfloat16_t x)
};

CK_TILE_DEVICE
-bfloat16_t exp(bfloat16_t x) { return static_cast<bfloat16_t>(__expf(static_cast<float>(x))); };
+bfloat16_t exp(bfloat16_t x)
+{
+    return static_cast<bfloat16_t>(__ocml_exp_f32(static_cast<float>(x)));
+};

CK_TILE_DEVICE
bfloat16_t exp2(bfloat16_t x) { return static_cast<bfloat16_t>(exp2f(static_cast<float>(x))); };
...
include/ck_tile/core/numeric/float8.hpp  View file @ 3552041a
...
@@ -835,7 +835,7 @@ CK_TILE_DEVICE
fp8_t sqrt(fp8_t x) { return static_cast<fp8_t>(__builtin_amdgcn_sqrtf(static_cast<float>(x))); };

CK_TILE_DEVICE
-fp8_t exp(fp8_t x) { return static_cast<fp8_t>(__expf(static_cast<float>(x))); };
+fp8_t exp(fp8_t x) { return static_cast<fp8_t>(__ocml_exp_f32(static_cast<float>(x))); };

CK_TILE_DEVICE
fp8_t exp2(fp8_t x) { return static_cast<fp8_t>(exp2f(static_cast<float>(x))); };
...
@@ -860,7 +860,7 @@ CK_TILE_DEVICE
bf8_t sqrt(bf8_t x) { return static_cast<bf8_t>(__builtin_amdgcn_sqrtf(static_cast<float>(x))); };

CK_TILE_DEVICE
-bf8_t exp(bf8_t x) { return static_cast<bf8_t>(__expf(static_cast<float>(x))); };
+bf8_t exp(bf8_t x) { return static_cast<bf8_t>(__ocml_exp_f32(static_cast<float>(x))); };

CK_TILE_DEVICE
bf8_t exp2(bf8_t x) { return static_cast<bf8_t>(exp2f(static_cast<float>(x))); };
...
include/ck_tile/core/numeric/half.hpp  View file @ 3552041a
...
@@ -374,7 +374,7 @@ half_t sqrt(half_t x)
};

CK_TILE_DEVICE
-half_t exp(half_t x) { return static_cast<half_t>(__expf(static_cast<float>(x))); };
+half_t exp(half_t x) { return static_cast<half_t>(__ocml_exp_f32(static_cast<float>(x))); };

CK_TILE_DEVICE
half_t exp2(half_t x) { return static_cast<half_t>(exp2f(static_cast<float>(x))); };
...
include/ck_tile/core/numeric/math.hpp  View file @ 3552041a
...
@@ -519,7 +519,7 @@ CK_TILE_DEVICE
double sqrt(double x) { return __builtin_amdgcn_sqrt(x); };

CK_TILE_DEVICE
-float exp(float x) { return __expf(x); };
+float exp(float x) { return __ocml_exp_f32(x); };

CK_TILE_HOST
float exp(float x) { return std::expf(x); }
...
include/ck_tile/core/tensor/buffer_view.hpp  View file @ 3552041a
...
@@ -69,6 +69,8 @@ struct buffer_view<address_space_enum::generic,
    {
    }

+    CK_TILE_HOST_DEVICE void init_raw() {}
+
    CK_TILE_DEVICE static constexpr address_space_enum get_address_space()
    {
        return address_space_enum::generic;
...
@@ -224,25 +226,36 @@ struct buffer_view<address_space_enum::global,
    T* p_data_ = nullptr;
    BufferSizeType buffer_size_;
+    int32x4_t cached_buf_res_;
    remove_cvref_t<T> invalid_element_value_ = T{0};

    CK_TILE_HOST_DEVICE constexpr buffer_view()
-        : p_data_{}, buffer_size_{}, invalid_element_value_{}
+        : p_data_{}, buffer_size_{}, cached_buf_res_{0}, invalid_element_value_{}
    {
    }

    CK_TILE_HOST_DEVICE constexpr buffer_view(T* p_data, BufferSizeType buffer_size)
-        : p_data_{p_data}, buffer_size_{buffer_size}, invalid_element_value_{0}
+        : p_data_{p_data}, buffer_size_{buffer_size}, cached_buf_res_{0}, invalid_element_value_{0}
    {
    }

    CK_TILE_HOST_DEVICE constexpr buffer_view(T* p_data,
                                              BufferSizeType buffer_size,
                                              T invalid_element_value)
-        : p_data_{p_data}, buffer_size_{buffer_size}, invalid_element_value_{invalid_element_value}
+        : p_data_{p_data},
+          buffer_size_{buffer_size},
+          cached_buf_res_{0},
+          invalid_element_value_{invalid_element_value}
    {
    }

+    // This is intentionally not constexpr (it calls an intrinsic internally).
+    // Must be called for buffers that need *_raw load/store.
+    CK_TILE_HOST_DEVICE void init_raw()
+    {
+        cached_buf_res_ = make_wave_buffer_resource(p_data_, buffer_size_ * sizeof(type));
+    }
+
    CK_TILE_DEVICE static constexpr address_space_enum get_address_space()
    {
        return address_space_enum::global;
...
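The idea behind cached_buf_res_ is to build the wave buffer resource descriptor once in init_raw() and then reuse it on every raw access, instead of recomputing it from the pointer and size per load. A minimal standalone sketch of that "build once, reuse per access" shape (plain C++; the descriptor here is just a toy struct, not a real SRD):

```cpp
#include <cstdint>
#include <cstdio>

// Toy stand-in for the 4-dword buffer resource descriptor.
struct fake_resource { const void* base; uint32_t bytes; };

struct toy_buffer_view
{
    const float* p_data  = nullptr;
    uint32_t size        = 0;
    fake_resource cached = {};

    // Analogue of init_raw(): not done in the constructor, so plain (non-raw)
    // usage never pays for it; must be called before any *_raw access.
    void init_raw() { cached = fake_resource{p_data, size * uint32_t(sizeof(float))}; }

    float get_raw(uint32_t i) const
    {
        // a real implementation would hand `cached` to a buffer_load instruction
        return static_cast<const float*>(cached.base)[i];
    }
};

int main()
{
    float data[4] = {1.f, 2.f, 3.f, 4.f};
    toy_buffer_view v{data, 4};
    v.init_raw();
    std::printf("%f\n", v.get_raw(2));
    return 0;
}
```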
@@ -333,12 +346,15 @@ struct buffer_view<address_space_enum::global,
    // i is offset of T, not X. i should be aligned to X
    template <typename X,
              bool oob_conditional_check = true,
+             bool pre_nop               = false,
              typename std::enable_if<
                  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                               typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
                  bool>::type = false>
-    CK_TILE_DEVICE constexpr auto get_raw(remove_cvref_t<X>& dst, index_t i, bool is_valid_element) const
+    CK_TILE_DEVICE constexpr auto get_raw(remove_cvref_t<X>& dst,
+                                          index_t i,
+                                          bool is_valid_element,
+                                          bool_constant<pre_nop> = {}) const
    {
        constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
...
@@ -349,18 +365,21 @@ struct buffer_view<address_space_enum::global,
        constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;

-        amd_buffer_load_raw<remove_cvref_t<T>, t_per_x, Coherence, oob_conditional_check>(
-            dst, p_data_, i, buffer_size_, is_valid_element);
+        amd_buffer_load_raw<remove_cvref_t<T>, t_per_x, Coherence, oob_conditional_check, pre_nop>(
+            dst, cached_buf_res_, i, is_valid_element, bool_constant<pre_nop>{});
    }

    // i is offset of T, not X. i should be aligned to X
    template <typename X,
+             bool pre_nop = false,
              typename std::enable_if<
                  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                               typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
                  bool>::type = false>
-    CK_TILE_DEVICE constexpr auto async_get(remove_cvref_t<T>* smem, index_t i, bool /*is_valid_element*/) const
+    CK_TILE_DEVICE constexpr auto async_get_raw(remove_cvref_t<T>* smem,
+                                                index_t i,
+                                                bool /*is_valid_element*/,
+                                                bool_constant<pre_nop> = {}) const
    {
        // X is vector of T
        constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
...
@@ -371,8 +390,8 @@ struct buffer_view<address_space_enum::global,
        constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;

-        amd_async_buffer_load_with_oob<remove_cvref_t<T>, t_per_x, Coherence>(
-            smem, p_data_, i, buffer_size_);
+        amd_async_buffer_load_with_oob_raw<remove_cvref_t<T>, t_per_x, Coherence>(
+            smem, cached_buf_res_, i, bool_constant<pre_nop>{});
    }

    // i is offset of T, not X. i should be aligned to X
...
@@ -627,6 +646,8 @@ struct buffer_view<address_space_enum::lds,
    {
    }

+    CK_TILE_HOST_DEVICE void init_raw() {}
+
    CK_TILE_DEVICE static constexpr address_space_enum get_address_space()
    {
        return address_space_enum::lds;
...
@@ -909,6 +930,8 @@ struct buffer_view<address_space_enum::vgpr,
    {
    }

+    CK_TILE_HOST_DEVICE void init_raw() {}
+
    CK_TILE_DEVICE static constexpr address_space_enum get_address_space()
    {
        return address_space_enum::vgpr;
...
include/ck_tile/core/tensor/load_tile.hpp  View file @ 3552041a
...
@@ -36,30 +36,37 @@ template <typename T,
          typename WindowLengths_,
          typename TileDistribution_,
          index_t NumCoord,
-          bool oob_conditional_check = true>
+          bool oob_conditional_check = true,
+          bool pre_nop               = false>
CK_TILE_DEVICE auto load_tile_raw(T& tile,
                                  const tile_window_with_static_distribution<BottomTensorView_,
                                                                             WindowLengths_,
                                                                             TileDistribution_,
                                                                             NumCoord>& tile_window,
-                                 bool_constant<oob_conditional_check> = {})
+                                 bool_constant<oob_conditional_check> = {},
+                                 bool_constant<pre_nop>               = {})
{
-    tile_window.load_raw(tile, bool_constant<oob_conditional_check>{});
+    tile_window.load_raw(tile, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
}

template <typename LdsTileWindow_,
          typename BottomTensorView_,
          typename WindowLengths_,
          typename TileDistribution_,
-          index_t NumCoord>
+          index_t NumCoord,
+          bool oob_conditional_check = true,
+          bool pre_nop               = false>
CK_TILE_DEVICE auto
async_load_tile_raw(LdsTileWindow_&& lds_tile,
                    const tile_window_with_static_distribution<BottomTensorView_,
                                                               WindowLengths_,
                                                               TileDistribution_,
-                                                              NumCoord>& tile_window)
+                                                              NumCoord>& tile_window,
+                    bool_constant<oob_conditional_check> = {},
+                    bool_constant<pre_nop>               = {})
{
-    return tile_window.async_load(lds_tile);
+    return tile_window.async_load_raw(
+        lds_tile, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
}

CK_TILE_DEVICE auto async_load_fence(index_t cnt = 0)
...
include/ck_tile/core/tensor/null_tile_window.hpp  View file @ 3552041a
...
@@ -35,6 +35,8 @@ struct null_tile_window

    CK_TILE_DEVICE constexpr auto get_window_origin() const { return BottomTensorIndex{}; }

+    CK_TILE_DEVICE void init_raw() {}
+
    WindowLengths window_lengths_;
};
...
include/ck_tile/core/tensor/tensor_view.hpp  View file @ 3552041a
...
@@ -36,6 +36,8 @@ struct tensor_view
    {
    }

+    CK_TILE_HOST_DEVICE void init_raw() { buf_.init_raw(); }
+
    CK_TILE_HOST_DEVICE constexpr auto& get_tensor_descriptor() const { return desc_; }

    CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_dimension()
...
@@ -85,30 +87,34 @@ struct tensor_view
    // "coord" is coordinate of DataType, not X. "coord" should be aligned to X
    template <typename X,
              bool oob_conditional_check = true,
+             bool pre_nop               = false,
              typename std::enable_if<
                  std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                                 typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
                  bool>::type = false>
-    CK_TILE_HOST_DEVICE void get_vectorized_elements_raw(remove_cvref_t<X>& dst,
-                                                         const TensorCoord& coord,
-                                                         bool_constant<oob_conditional_check> = {}) const
+    CK_TILE_HOST_DEVICE void get_vectorized_elements_raw(remove_cvref_t<X>& dst,
+                                                         const TensorCoord& coord,
+                                                         bool_constant<oob_conditional_check> = {},
+                                                         bool_constant<pre_nop>               = {}) const
    {
-        return buf_.template get_raw<X, oob_conditional_check>(
-            dst,
-            coord.get_offset(),
-            coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord));
+        return buf_.template get_raw<X, oob_conditional_check, pre_nop>(
+            dst,
+            coord.get_offset(),
+            coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
+            bool_constant<pre_nop>{});
    }

    template <typename X,
+             bool pre_nop = false,
              typename std::enable_if<
                  std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                                 typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
                  bool>::type = false>
-    CK_TILE_HOST_DEVICE constexpr void async_get_vectorized_elements(remove_cvref_t<DataType>* smem,
-                                                                     const TensorCoord& coord) const
+    CK_TILE_HOST_DEVICE constexpr void async_get_vectorized_elements_raw(
+        remove_cvref_t<DataType>* smem, const TensorCoord& coord, bool_constant<pre_nop> = {}) const
    {
-        return buf_.template async_get<X>(smem, coord.get_offset(), true /*not used*/);
+        return buf_.template async_get_raw<X>(
+            smem, coord.get_offset(), true /*not used*/, bool_constant<pre_nop>{});
    }

    // X is vector of DataType.
...
include/ck_tile/core/tensor/tile_elementwise.hpp  View file @ 3552041a
...
@@ -76,23 +76,63 @@ CK_TILE_DEVICE void set_tile(null_tensor&, const T&)
// TODO: prefer to use per-dword value to set a tensor, in case compiler not doing well with
// sub-dword tensor...
-template <typename DstrTensors, index_t v>
-CK_TILE_DEVICE void set_tile(DstrTensors& dstr_tensor, number<v>)
+template <typename DstrTensors, index_t v, bool skip_subdword_opt = false>
+CK_TILE_DEVICE void
+set_tile(DstrTensors& dstr_tensor, number<v>, bool_constant<skip_subdword_opt> = {})
{
-    constexpr index_t tensor_bytes =
-        DstrTensors::get_thread_buffer_size() * sizeof(typename DstrTensors::DataType);
-    if constexpr(v == 0 && tensor_bytes % 4 == 0)
+    using elem_type                = typename DstrTensors::DataType;
+    constexpr index_t elem_size    = sizeof(elem_type);
+    constexpr index_t tensor_bytes = DstrTensors::get_thread_buffer_size() * elem_size;
+
+    // # bytes per write = 4
+    if constexpr(v == 0 && tensor_bytes % 4 == 0 && !skip_subdword_opt)
    {
+#if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE
+        auto& buffer = dstr_tensor.get_thread_buffer();
+
+        static_for<0, tensor_bytes / 4, 1>{}([&](auto i_write) {
+            if constexpr(elem_size == 1)
+            {
+                // # elements per write = 4
+                constexpr auto values = ext_vector_t<elem_type, 4>{0, 0, 0, 0};
+
+                buffer[i_write * 4 + 0] = values.x;
+                buffer[i_write * 4 + 1] = values.y;
+                buffer[i_write * 4 + 2] = values.z;
+                buffer[i_write * 4 + 3] = values.w;
+            }
+            else if constexpr(elem_size == 2)
+            {
+                // # elements per write = 2
+                constexpr auto values = ext_vector_t<elem_type, 2>{0, 0};
+
+                buffer[i_write * 2 + 0] = values.x;
+                buffer[i_write * 2 + 1] = values.y;
+            }
+            else if constexpr(elem_size == 4)
+            {
+                // # elements per write = 1
+                constexpr elem_type value = 0;
+
+                buffer[i_write] = value;
+            }
+            else
+            {
+                static_assert(false, "type not supported");
+            }
+        });
+#else
        using dvec_t = array<index_t, tensor_bytes / 4>;
        auto& tensor = reinterpret_cast<dvec_t&>(dstr_tensor.get_thread_buffer());
        for(auto i = 0; i < tensor.size(); i++)
            tensor.get(i) = v;
+#endif
    }
    else
    {
-        tile_elementwise_inout(
-            [](auto& x) { x = type_convert<typename DstrTensors::DataType, index_t>(v); },
-            dstr_tensor);
+        tile_elementwise_inout([](auto& x) { x = type_convert<elem_type, index_t>(v); },
+                               dstr_tensor);
    }
}
...
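The v == 0 fast path above clears the thread buffer a dword at a time (4, 2, or 1 elements per write depending on element size) instead of element by element. A minimal standalone sketch of that dword-chunked fill (plain C++; memcpy stands in for the per-dword register writes):

```cpp
#include <cstdint>
#include <cstdio>
#include <cstring>

int main()
{
    unsigned char buf[16]; // e.g. 16 one-byte elements -> 4 dwords

    // generic path: element-wise
    for(unsigned char& b : buf)
        b = 0;

    // dword path: 4 bytes per write
    const uint32_t zero = 0;
    for(int i = 0; i < 4; ++i)
        std::memcpy(buf + 4 * i, &zero, sizeof(zero));

    std::printf("first byte: %u\n", static_cast<unsigned>(buf[0]));
    return 0;
}
```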
include/ck_tile/core/tensor/tile_window.hpp  View file @ 3552041a
...
@@ -344,9 +344,10 @@ struct tile_window_with_static_distribution
        return dst_tensor;
    }

-    template <typename DstTile, bool oob_conditional_check = true>
-    CK_TILE_DEVICE void load_raw(DstTile& dst_tensor,
-                                 bool_constant<oob_conditional_check> = {}) const
+    template <typename DstTile, bool oob_conditional_check = true, bool pre_nop = false>
+    CK_TILE_DEVICE void load_raw(DstTile& dst_tensor,
+                                 bool_constant<oob_conditional_check> = {},
+                                 bool_constant<pre_nop>               = {}) const
    {
        using Traits = load_store_traits;
...
@@ -373,7 +374,13 @@ struct tile_window_with_static_distribution
            auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];

            static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
-                constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
+                constexpr auto iAccess  = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
+                constexpr auto pre_nop_ = [&]() {
+                    if constexpr(pre_nop && iCoord == 0 && iCoordAccess == 0)
+                        return bool_constant<true>{};
+                    else
+                        return bool_constant<false>{};
+                }();

                // data index [y0, y1, ...]
                constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
...
@@ -384,7 +391,8 @@ struct tile_window_with_static_distribution
                get_bottom_tensor_view().template get_vectorized_elements_raw<vector_t>(
                    dst_vec_tbuf.template at<d / Traits::ScalarPerVector>(),
                    bottom_tensor_thread_coord,
-                    bool_constant<oob_conditional_check>{});
+                    bool_constant<oob_conditional_check>{},
+                    pre_nop_);

                // move thread coordinate
                if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
...
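The pre_nop_ selection above is an immediately-invoked constexpr lambda: only the very first access of the first pre-computed coordinate gets the nop preamble, and the decision is made entirely at compile time. A minimal standalone sketch of the same idiom (plain C++, with std::integral_constant standing in for ck_tile::bool_constant):

```cpp
#include <cstdio>
#include <type_traits>

template <bool B>
using bool_constant = std::integral_constant<bool, B>;

template <bool pre_nop, int iCoord, int iCoordAccess>
void access()
{
    constexpr auto pre_nop_ = [&]() {
        if constexpr(pre_nop && iCoord == 0 && iCoordAccess == 0)
            return bool_constant<true>{};
        else
            return bool_constant<false>{};
    }();
    std::printf("coord %d access %d -> pre_nop = %d\n", iCoord, iCoordAccess, int(pre_nop_));
}

int main()
{
    access<true, 0, 0>(); // 1: first access of first coordinate
    access<true, 0, 1>(); // 0
    access<true, 1, 0>(); // 0
    return 0;
}
```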
@@ -399,12 +407,17 @@ struct tile_window_with_static_distribution
            }
        });
    });
+
+#if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE
+        asm volatile("; this inline asm is workaround to prevent compiler from using too much "
+                     "scratch memory" ::);
+#endif
    }

    // TODO: currently async load only implemented in inline asm
-    template <typename LdsTileWindow_, bool oob_conditional_check = true>
-    CK_TILE_DEVICE auto async_load(LdsTileWindow_&& lds_tile,
-                                   bool_constant<oob_conditional_check> = {}) const
+    template <typename LdsTileWindow_, bool oob_conditional_check = true, bool pre_nop = false>
+    CK_TILE_DEVICE auto async_load_raw(LdsTileWindow_&& lds_tile,
+                                       bool_constant<oob_conditional_check> = {},
+                                       bool_constant<pre_nop>               = {}) const
    {
        using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
        // using LdsTensorView = typename LdsTileWindow::BottomTensorView;
...
@@ -449,11 +462,17 @@ struct tile_window_with_static_distribution
            auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];

            static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
-                constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
+                constexpr auto iAccess  = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
+                constexpr auto pre_nop_ = [&]() {
+                    if constexpr(pre_nop && iCoord == 0 && iCoordAccess == 0)
+                        return bool_constant<true>{};
+                    else
+                        return bool_constant<false>{};
+                }();

                // read from bottom tensor
-                get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>(
-                    smem, bottom_tensor_thread_coord);
+                get_bottom_tensor_view().template async_get_vectorized_elements_raw<vector_t>(
+                    smem, bottom_tensor_thread_coord, pre_nop_);

                // move thread coordinate
                if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
...
@@ -668,6 +687,67 @@ struct tile_window_with_static_distribution
        });
    }

+    CK_TILE_DEVICE void set_window_origin(const BottomTensorIndex& new_window_origin)
+    {
+        window_origin_ = new_window_origin;
+
+#if 0 // debug
+        // TODO: this uses more registers for FA but fewer for GEMM; needs investigation
+        // only support warp-tile and block-tile
+        static_assert(NDimP == 1 or NDimP == 2, "wrong!");
+
+        WindowAdaptorCoord window_adaptor_thread_coord_tmp;
+
+        if constexpr(NDimP == 1)
+        {
+            window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
+                tile_dstr_.get_ps_ys_to_xs_adaptor(), AdaptorTopIndex{get_lane_id(), 0});
+        }
+        else if constexpr(NDimP == 2)
+        {
+            window_adaptor_thread_coord_tmp =
+                make_tensor_adaptor_coordinate(tile_dstr_.get_ps_ys_to_xs_adaptor(),
+                                               AdaptorTopIndex{get_warp_id(), get_lane_id(), 0});
+        }
+#else
+        // TODO: this uses fewer registers for FA but more for GEMM; needs investigation
+        const auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
+            tile_dstr_.get_ps_ys_to_xs_adaptor(),
+            container_concat(detail::get_partition_index(tile_dstr_), array<index_t, NDimY>{0}));
+#endif
+
+        BottomTensorIndex bottom_tensor_thread_origin_idx_tmp =
+            window_origin_ + window_adaptor_thread_coord_tmp.get_bottom_index();
+
+        const auto bottom_tensor_thread_coord_tmp = make_tensor_coordinate(
+            bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp);
+
+        // pre-compute NumCoord (WindowAdaptorCoord, BottomTensorCoord) bundles to speed up
+        // future load/store() calls (might allocate more registers)
+        using Traits = load_store_traits;
+        using SFC_Ys = typename Traits::SFC_Ys;
+
+        static_for<0, NumCoord, 1>{}([&](auto iCoord) {
+            auto window_adaptor_thread_coord = window_adaptor_thread_coord_tmp;
+            auto bottom_tensor_thread_coord  = bottom_tensor_thread_coord_tmp;
+
+            constexpr auto idx_diff_ys =
+                SFC_Ys::get_step_between(number<0>{}, number<iCoord * NumAccessPerCoord>{});
+
+            constexpr auto idx_diff_ps_ys =
+                container_concat(array<index_t, NDimP>{0}, idx_diff_ys);
+
+            move_window_adaptor_and_bottom_tensor_thread_coordinate(
+                window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
+
+            pre_computed_coords_(iCoord) =
+                make_tuple(window_adaptor_thread_coord, bottom_tensor_thread_coord);
+        });
+    }
+
+    CK_TILE_HOST_DEVICE void init_raw() { bottom_tensor_view_.init_raw(); }
+
    // this is the bottom tensor view
    // [x0', x1', ...] ==> [offset]
    BottomTensorView bottom_tensor_view_;
...