gaoqiong / composable_kernel

Commit 2b8e3ece, authored Apr 17, 2020 by Jing Zhang
parent 9b4fdeee

    add fp16

Showing 6 changed files with 743 additions and 6 deletions
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_xdlops_fp16_nchw_kcyx_nkhw.hpp  +189 -0
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_fp16_bfp16.hpp  +354 -0
composable_kernel/include/tensor_operation/xdlops_gemm.hpp  +1 -1
composable_kernel/include/utility/float_type.nvidia.hpp.in  +15 -2
driver/include/device_convolution_implicit_gemm_v4r4_xdlops_fp16_nchw_kcyx_nkhw.hpp  +180 -0
driver/src/conv_driver.cpp  +4 -3

composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_xdlops_fp16_nchw_kcyx_nkhw.hpp  (new file, mode 100644)
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_XDLOPS_FP16_NCHW_KCYX_NKHW_HPP
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_XDLOPS_FP16_NCHW_KCYX_NKHW_HPP

#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_xdlops_fp16_bfp16.hpp"

namespace ck {

// GemmM = K
// GemmN = N * Ho * Wo
// GemmK = C * Y * X
template <index_t GridSize,
          index_t BlockSize,
          class Float,
          class AccDataType,
          class InGlobalDesc,
          class WeiGlobalDesc,
          class OutGlobalDesc,
          class ConvStrides,
          class ConvDilations,
          class LeftPads,
          class RightPads,
          index_t GemmMPerBlock,
          index_t GemmNPerBlock,
          index_t GemmKPerBlock,
          index_t GemmKPACK,
          index_t GemmMPerWave,
          index_t GemmNPerWave,
          index_t GemmThreadGemmDataPerReadM,
          index_t GemmThreadGemmDataPerReadN,
          class GemmABlockCopyThreadSliceLengths_GemmK_GemmM_GemmKPACK,
          class GemmABlockCopyThreadClusterLengths_GemmK_GemmM_GemmKPACK,
          index_t GemmABlockCopySrcDataPerRead_GemmKPACK,
          index_t GemmABlockCopyDstDataPerWrite_GemmKPACK,
          class GemmBBlockCopyThreadSliceLengths_GemmK_GemmN_GemmKPACK,
          class GemmBBlockCopyThreadClusterLengths_GemmK_GemmN_GemmKPACK,
          index_t GemmBBlockCopySrcDataPerRead_GemmN,
          index_t GemmBBlockCopyDstDataPerWrite_GemmKPACK>
struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_fwd_fp16_nchw_kcyx_nkhw
{
    __device__ void Run(const Float* const __restrict__ p_in_global,
                        const Float* const __restrict__ p_wei_global,
                        Float* const __restrict__ p_out_global) const
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};

        constexpr auto in_n_c_hi_wi_global_desc  = InGlobalDesc{};
        constexpr auto wei_k_c_y_x_global_desc   = WeiGlobalDesc{};
        constexpr auto out_n_k_ho_wo_global_desc = OutGlobalDesc{};

        constexpr index_t N  = in_n_c_hi_wi_global_desc.GetLength(I0);
        constexpr index_t C  = in_n_c_hi_wi_global_desc.GetLength(I1);
        constexpr index_t Hi = in_n_c_hi_wi_global_desc.GetLength(I2);
        constexpr index_t Wi = in_n_c_hi_wi_global_desc.GetLength(I3);

        constexpr index_t K  = out_n_k_ho_wo_global_desc.GetLength(I1);
        constexpr index_t Ho = out_n_k_ho_wo_global_desc.GetLength(I2);
        constexpr index_t Wo = out_n_k_ho_wo_global_desc.GetLength(I3);

        constexpr index_t Y = wei_k_c_y_x_global_desc.GetLength(I2);
        constexpr index_t X = wei_k_c_y_x_global_desc.GetLength(I3);

        constexpr index_t ConvStrideH = ConvStrides{}[0];
        constexpr index_t ConvStrideW = ConvStrides{}[1];

        constexpr index_t ConvDilationH = ConvDilations{}[0];
        constexpr index_t ConvDilationW = ConvDilations{}[1];

        constexpr index_t GemmM = K;
        constexpr index_t GemmK = (C * Y * X) / GemmKPACK;
        constexpr index_t GemmN = N * Ho * Wo;

        static_assert(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 &&
                          GemmK % GemmKPerBlock == 0,
                      "wrong! cannot divide work evenly among block");
        // sanity-check for vectorized memory load
        static_assert((Wo == 1 || (ConvStrideW == 1 || GemmBBlockCopySrcDataPerRead_GemmN == 1)) &&
                          (X == 1 || ConvDilationW % GemmBBlockCopySrcDataPerRead_GemmN == 0),
                      "wrong! alignment requirement for vectorized global load of input tensor "
                      "will be violated");
        // input tensor
        // global mem
        constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
            in_n_c_hi_wi_global_desc,
            make_tuple(PassThrough<N>{},
                       PassThrough<C>{},
                       Pad<Sequence<Hi, Wi>, LeftPads, RightPads>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));

        constexpr index_t Hip = in_n_c_hip_wip_global_desc.GetLengths()[2];
        constexpr index_t Wip = in_n_c_hip_wip_global_desc.GetLengths()[3];

        constexpr auto in_n_c_y_ho_x_wo_global_desc = transform_tensor_descriptor(
            in_n_c_hip_wip_global_desc,
            make_tuple(PassThrough<N>{},
                       PassThrough<C>{},
                       Embed<Hip, Sequence<Y, Ho>, Sequence<ConvDilationH, ConvStrideH, 0>>{},
                       Embed<Wip, Sequence<X, Wo>, Sequence<ConvDilationW, ConvStrideW, 0>>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));

        constexpr auto in_gemmk_gemmn_global_desc = transform_tensor_descriptor(
            in_n_c_y_ho_x_wo_global_desc,
            make_tuple(Merge<Sequence<C, Y, X>>{}, Merge<Sequence<N, Ho, Wo>>{}),
            make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));

        constexpr auto in_gemmk_gemmkpack_gemmn_global_desc = transform_tensor_descriptor(
            in_gemmk_gemmn_global_desc,
            make_tuple(UnMerge<Sequence<GemmK, GemmKPACK>>{}, PassThrough<GemmN>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0, 1>{}, Sequence<2>{}));

        constexpr auto in_gemmk_gemmn_gemmkpack_global_desc = transform_tensor_descriptor(
            in_gemmk_gemmkpack_gemmn_global_desc,
            make_tuple(PassThrough<GemmK>{}, PassThrough<GemmN>{}, PassThrough<GemmKPACK>{}),
            make_tuple(Sequence<0>{}, Sequence<2>{}, Sequence<1>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));

        constexpr auto wei_gemmk_gemmm_global_desc = reorder_tensor_descriptor_given_upper2lower(
            unfold_tensor_descriptor(wei_k_c_y_x_global_desc, I1, I3), Sequence<1, 0>{});

        constexpr auto wei_gemmm_gemmk_gemmkpack_global_desc = transform_tensor_descriptor(
            wei_gemmk_gemmm_global_desc,
            make_tuple(PassThrough<K>{}, UnMerge<Sequence<GemmK, GemmKPACK>>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0>{}, Sequence<1, 2>{}));

        constexpr auto wei_gemmk_gemmm_gemmkpack_global_desc = transform_tensor_descriptor(
            wei_gemmm_gemmk_gemmkpack_global_desc,
            make_tuple(PassThrough<GemmK>{}, PassThrough<K>{}, PassThrough<GemmKPACK>{}),
            make_tuple(Sequence<1>{}, Sequence<0>{}, Sequence<2>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));

        constexpr auto out_gemmm_gemmn_global_desc = transform_tensor_descriptor(
            out_n_k_ho_wo_global_desc,
            make_tuple(PassThrough<K>{}, Merge<Sequence<N, Ho, Wo>>{}),
            make_tuple(Sequence<1>{}, Sequence<0, 2, 3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));

        // GEMM
        constexpr auto gridwise_gemm = GridwiseGemmTransposedANormalBNormalCXdlopsFp16Bfp16_v1<
            GridSize,
            BlockSize,
            Float,
            AccDataType,
            Float,
            decltype(wei_gemmk_gemmm_gemmkpack_global_desc),
            decltype(in_gemmk_gemmn_gemmkpack_global_desc),
            decltype(out_gemmm_gemmn_global_desc),
            GemmMPerBlock,
            GemmNPerBlock,
            GemmKPerBlock,
            GemmMPerWave,
            GemmNPerWave,
            GemmThreadGemmDataPerReadM,
            GemmThreadGemmDataPerReadN,
            GemmABlockCopyThreadSliceLengths_GemmK_GemmM_GemmKPACK,
            GemmABlockCopyThreadClusterLengths_GemmK_GemmM_GemmKPACK,
            Sequence<1, 0, 2>,
            Sequence<1, 0, 2>,
            Sequence<0, 1, 2>,
            2,
            GemmABlockCopySrcDataPerRead_GemmKPACK,
            GemmABlockCopyDstDataPerWrite_GemmKPACK,
            GemmBBlockCopyThreadSliceLengths_GemmK_GemmN_GemmKPACK,
            GemmBBlockCopyThreadClusterLengths_GemmK_GemmN_GemmKPACK,
            Sequence<0, 2, 1>,
            Sequence<0, 2, 1>,
            Sequence<0, 1, 2>,
            1,
            GemmBBlockCopySrcDataPerRead_GemmN,
            GemmBBlockCopyDstDataPerWrite_GemmKPACK,
            InMemoryDataOperation::Set>{};

        gridwise_gemm.Run(p_wei_global, p_in_global, p_out_global);
    }
};

} // namespace ck
#endif
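The new kernel maps the forward NCHW convolution onto a single GEMM with GemmM = K, GemmN = N * Ho * Wo and GemmK = (C * Y * X) / GemmKPACK, where GemmKPACK groups the fp16 values consumed together by one xdlops operand. A minimal host-side sketch of that dimension mapping, assuming made-up problem sizes and a hypothetical conv_out_size helper (neither is part of this commit):

#include <cassert>
#include <cstdio>

// Hypothetical helper: output size of a padded, strided, dilated convolution.
int conv_out_size(int in, int pad_l, int pad_r, int dilation, int filter, int stride)
{
    return (in + pad_l + pad_r - dilation * (filter - 1) - 1) / stride + 1;
}

int main()
{
    // Example problem (made-up sizes): NCHW input, KCYX weights, stride 1, no padding.
    const int N = 128, C = 64, Hi = 28, Wi = 28;
    const int K = 256, Y = 3, X = 3;
    const int Ho = conv_out_size(Hi, 0, 0, 1, Y, 1); // 26
    const int Wo = conv_out_size(Wi, 0, 0, 1, X, 1); // 26

    const int GemmKPACK = 4; // 4 fp16 values packed per xdlops operand

    const int GemmM = K;                       // 256
    const int GemmN = N * Ho * Wo;             // 128 * 26 * 26 = 86528
    const int GemmK = (C * Y * X) / GemmKPACK; // 576 / 4 = 144

    // The kernel's static_asserts require the tile sizes to divide these evenly,
    // e.g. with the driver's 128x64x4 tile:
    assert(GemmM % 128 == 0 && GemmN % 64 == 0 && GemmK % 4 == 0);

    printf("GemmM %d, GemmN %d, GemmK %d\n", GemmM, GemmN, GemmK);
    return 0;
}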
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_fp16_bfp16.hpp  (new file, mode 100644)
#ifndef CK_GRIDWISE_GEMM_XDLOPS_FP16_BFP16_HPP
#define CK_GRIDWISE_GEMM_XDLOPS_FP16_BFP16_HPP

#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_generic_tensor_slice_copy.hpp"
#include "threadwise_generic_tensor_slice_copy.hpp"
#include "blockwise_gemm_xdlops.hpp"

namespace ck {

template <index_t GridSize,
          index_t BlockSize,
          class ABFloat,
          class AccFloat,
          class CFloat,
          class AGlobalDesc,
          class BGlobalDesc,
          class CGlobalDesc,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t MPerWave,
          index_t NPerWave,
          index_t GemmDataPerReadM,
          index_t GemmDataPerReadN,
          class ABlockCopyThreadSliceLengths_K_M_KPACK,
          class ABlockCopyThreadClusterLengths_K_M_KPACK,
          class ABlockCopyThreadClusterArrangeOrder,
          class ABlockCopySrcAccessOrder,
          class ABlockCopyDstAccessOrder,
          index_t ABlockCopySrcVectorReadDim,
          index_t ABlockCopySrcDataPerRead,
          index_t ABlockCopyDstDataPerWrite_KPACK,
          class BBlockCopyThreadSliceLengths_K_N_KPACK,
          class BBlockCopyThreadClusterLengths_K_N_KPACK,
          class BBlockCopyThreadClusterArrangeOrder,
          class BBlockCopySrcAccessOrder,
          class BBlockCopyDstAccessOrder,
          index_t BBlockCopySrcVectorReadDim,
          index_t BBlockCopySrcDataPerRead,
          index_t BBlockCopyDstDataPerWrite_KPACK,
          InMemoryDataOperation OutputMemOp>
struct GridwiseGemmTransposedANormalBNormalCXdlopsFp16Bfp16_v1
{
    __device__ void Run(const ABFloat* const __restrict__ p_a_global,
                        const ABFloat* const __restrict__ p_b_global,
                        CFloat* const __restrict__ p_c_global) const
    {
        constexpr auto b_k_n_kpack_global_desc = BGlobalDesc{};
        constexpr auto a_k_m_kpack_global_desc = AGlobalDesc{};
        constexpr auto c_m_n_global_desc       = CGlobalDesc{};

        constexpr auto True = integral_constant<bool, true>{};

        constexpr auto K     = b_k_n_kpack_global_desc.GetLengths()[0];
        constexpr auto N     = b_k_n_kpack_global_desc.GetLengths()[1];
        constexpr auto M     = a_k_m_kpack_global_desc.GetLengths()[1];
        constexpr auto KPACK = b_k_n_kpack_global_desc.GetLengths()[2];

        // divide block work by [M, N]
        static_assert(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0,
                      "wrong! cannot divide work evenly among block");

        constexpr index_t MBlockWork = M / MPerBlock;
        constexpr index_t NBlockWork = N / NPerBlock;

        constexpr index_t MWaves = MPerBlock / MPerWave;
        constexpr index_t NWaves = NPerBlock / NPerWave;

        constexpr auto block_work_desc =
            make_cluster_descriptor(Sequence<MBlockWork, NBlockWork>{});

        const auto block_work_id = block_work_desc.CalculateClusterIndex(get_block_1d_id());

        const index_t k_block_data_on_global = block_work_id[0] * MPerBlock;
        const index_t b_block_data_on_global = block_work_id[1] * NPerBlock;

        // LDS mem
        constexpr index_t max_align = math::lcm(BBlockCopyDstDataPerWrite_KPACK,
                                                ABlockCopyDstDataPerWrite_KPACK,
                                                KPACK * GemmDataPerReadM,
                                                KPACK * GemmDataPerReadN);

        // LDS
        // be careful of LDS alignment
        constexpr auto a_k_m_kpack_block_desc = make_native_tensor_descriptor_aligned(
            Sequence<KPerBlock, MPerBlock, KPACK>{}, Number<max_align>{});

        auto a_blockwise_copy = BlockwiseGenericTensorSliceCopy_v4<
            BlockSize,
            decltype(a_k_m_kpack_global_desc),
            decltype(a_k_m_kpack_block_desc),
            decltype(a_k_m_kpack_block_desc.GetLengths()),
            ABlockCopyThreadSliceLengths_K_M_KPACK,
            ABlockCopyThreadClusterLengths_K_M_KPACK,
            ABlockCopyThreadClusterArrangeOrder,
            ABlockCopySrcAccessOrder,
            ABlockCopyDstAccessOrder,
            ABlockCopySrcVectorReadDim, // Src dim to be read in vector form (M dimension)
            2,                          // Dst dim to be written in vector form (KPACK dimension)
            ABlockCopySrcDataPerRead,
            ABlockCopyDstDataPerWrite_KPACK,
            AddressSpace::Generic,
            AddressSpace::Vgpr,
            AddressSpace::Lds,
            InMemoryDataOperation::Set>({0, k_block_data_on_global, 0}, {0, 0, 0});

        constexpr auto b_k_n_kpack_block_desc = make_native_tensor_descriptor_aligned(
            Sequence<KPerBlock, NPerBlock, KPACK>{}, Number<max_align>{});

        // input blockwise copy
        auto b_blockwise_copy = BlockwiseGenericTensorSliceCopy_v4<
            BlockSize,
            decltype(b_k_n_kpack_global_desc),
            decltype(b_k_n_kpack_block_desc),
            decltype(b_k_n_kpack_block_desc.GetLengths()),
            BBlockCopyThreadSliceLengths_K_N_KPACK,
            BBlockCopyThreadClusterLengths_K_N_KPACK,
            BBlockCopyThreadClusterArrangeOrder,
            BBlockCopySrcAccessOrder,
            BBlockCopyDstAccessOrder,
            BBlockCopySrcVectorReadDim, // Src dim to be read in vector form (N dimension)
            2,                          // Dst dim to be written in vector form (KPACK dimension)
            BBlockCopySrcDataPerRead,
            BBlockCopyDstDataPerWrite_KPACK,
            AddressSpace::Generic,
            AddressSpace::Vgpr,
            AddressSpace::Lds,
            InMemoryDataOperation::Set>({0, b_block_data_on_global, 0}, {0, 0, 0});
        // GEMM definition
        //   c_mtx += transpose(a_mtx) * b_mtx
        //     a_mtx[KPerBlock, MPerBlock] is in LDS
        //     b_mtx[KPerBlock, NPerBlock] is in LDS
        //     c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
        //       register
        constexpr auto a_k_m_block_mtx_desc =
            make_ConstantMatrixDescriptor_packed(Number<KPerBlock>{}, Number<MPerBlock>{});
        constexpr auto b_k_n_block_mtx_desc =
            make_ConstantMatrixDescriptor_packed(Number<KPerBlock>{}, Number<NPerBlock>{});

        const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_xdlops<
            BlockSize,
            decltype(a_k_m_block_mtx_desc),
            decltype(b_k_n_block_mtx_desc),
            ABFloat,
            MPerWave,
            NPerWave,
            MWaves,
            NWaves,
            GemmDataPerReadM,
            GemmDataPerReadN>{};

        constexpr auto c_k_thread_mtx_desc = blockwise_gemm.GetThreadMatrixCDescriptor();

        constexpr index_t a_block_space =
            math::integer_least_multiple(a_k_m_kpack_block_desc.GetElementSpace(), max_align);

        constexpr index_t b_block_space =
            math::integer_least_multiple(b_k_n_kpack_block_desc.GetElementSpace(), max_align);

        __shared__ ABFloat p_a_block_double[2 * a_block_space];
        __shared__ ABFloat p_b_block_double[2 * b_block_space];

        // register allocation for output
        AccFloat p_c_thread[c_k_thread_mtx_desc.GetElementSpace()];

        // zero out threadwise output
        threadwise_matrix_set_zero(c_k_thread_mtx_desc, p_c_thread);

        blockwise_gemm.XdlopsMatrixCSetZero();

        // LDS double buffer: preload data into LDS
        {
            a_blockwise_copy.Run(p_a_global, p_a_block_double);
            b_blockwise_copy.Run(p_b_global, p_b_block_double);
        }

        using blockwise_a_copy_src_step = Sequence<KPerBlock, 0, 0>;
        using blockwise_b_copy_src_step = Sequence<KPerBlock, 0, 0>;

        // LDS double buffer: main body
        for(index_t k_block_data_begin = 0; k_block_data_begin + 2 * KPerBlock < K;
            k_block_data_begin += 2 * KPerBlock)
        {
#pragma unroll
            for(index_t iloop = 0; iloop < 2; ++iloop)
            {
                const bool even_loop = (iloop % 2 == 0);

                ABFloat* p_a_block_now =
                    even_loop ? p_a_block_double : p_a_block_double + a_block_space;
                ABFloat* p_b_block_now =
                    even_loop ? p_b_block_double : p_b_block_double + b_block_space;

                ABFloat* p_a_block_next =
                    even_loop ? p_a_block_double + a_block_space : p_a_block_double;
                ABFloat* p_b_block_next =
                    even_loop ? p_b_block_double + b_block_space : p_b_block_double;

                ABFloat p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
                ABFloat p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];

                a_blockwise_copy.MoveSrcSliceWindow(blockwise_a_copy_src_step{}, True);
                b_blockwise_copy.MoveSrcSliceWindow(blockwise_b_copy_src_step{}, True);

                __syncthreads();
                // LDS double buffer: load next data from device mem
                a_blockwise_copy.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer);
                b_blockwise_copy.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer);

                // LDS double buffer: GEMM on current data
                // Vectorize the pointer to match with how half/bfloat16 datatypes are
                // processed in gemm operation. Half type packs 4 half values while
                // bfloat16 packs 2 bfloat16 values. Since gemm's matrix A and B
                // 2D indexes are computed with vectorized value in mind (e.g. float, half2, half4),
                // we recast datatype from a single half to 4 packed half/2 packed bfloat16
                // respectively.
                auto p_a_block_vec = reinterpret_cast<const half4_t*>(p_a_block_now);
                auto p_b_block_vec = reinterpret_cast<const half4_t*>(p_b_block_now);
                blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, p_c_thread);

                // LDS double buffer: store next data to LDS
                a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer, p_a_block_next);
                b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer, p_b_block_next);
            }
        }

        // LDS double buffer: tail
        {
            constexpr bool has_two_iteration_left = (K % (2 * KPerBlock) == 0);

            if(has_two_iteration_left) // if has 2 iteration left
            {
                ABFloat p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
                ABFloat p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];

                a_blockwise_copy.MoveSrcSliceWindow(blockwise_a_copy_src_step{}, True);
                b_blockwise_copy.MoveSrcSliceWindow(blockwise_b_copy_src_step{}, True);

                __syncthreads();

                // LDS double buffer: load last data from device mem
                a_blockwise_copy.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer);
                b_blockwise_copy.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer);

                // LDS double buffer: GEMM on 2nd-last data
                auto p_a_block_vec = reinterpret_cast<const half4_t*>(p_a_block_double);
                auto p_b_block_vec = reinterpret_cast<const half4_t*>(p_b_block_double);
                blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, p_c_thread);

                // LDS double buffer: store last data to LDS
                a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer,
                                                      p_a_block_double + a_block_space);
                b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer,
                                                      p_b_block_double + b_block_space);

                __syncthreads();

                // LDS double buffer: GEMM on current data
                p_a_block_vec = reinterpret_cast<const half4_t*>(p_a_block_double + a_block_space);
                p_b_block_vec = reinterpret_cast<const half4_t*>(p_b_block_double + b_block_space);
                blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, p_c_thread);
            }
            else // if has 1 iteration left
            {
                __syncthreads();

                // LDS double buffer: GEMM on last data
                auto p_a_block_vec = reinterpret_cast<const half4_t*>(p_a_block_double);
                auto p_b_block_vec = reinterpret_cast<const half4_t*>(p_b_block_double);
                blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, p_c_thread);
            }
        }
        // load data from xdlops_acc_regs
        blockwise_gemm.XdlopsMatrixCRead(p_c_thread);

        // copy output: register to global memory
        {
            constexpr auto OutputLayout = blockwise_gemm.GetOutputLayout();
            constexpr index_t K0        = OutputLayout.M1();
            constexpr index_t K1        = OutputLayout.N1();
            constexpr index_t K2        = OutputLayout.M0();

            constexpr auto out_k0_k1_k2_b_global_desc = transform_tensor_descriptor(
                c_m_n_global_desc,
                make_tuple(UnMerge<Sequence<K0, K1, K2>>{}, PassThrough<N>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}));

            // src descriptor
            constexpr auto out_k0_k1_k2_b_thread_desc =
                make_native_tensor_descriptor_packed(Sequence<K0, 1, K2, 1>{});

            using OutThreadCopySliceLengths = Sequence<K0, 1, K2, 1>;

            constexpr index_t BlkSize = OutputLayout.GetBlkSize();
            constexpr index_t NumBlks = OutputLayout.GetNumBlks();

            for(index_t i = 0; i < NumBlks; ++i)
            {
                // calculate origin of thread output tensor on global memory
                // blockwise GEMM c matrix starting index
                const auto c_thread_mtx_on_block = blockwise_gemm.GetBeginOfThreadMatrixC(i);

                const index_t k_thread_data_on_global =
                    k_block_data_on_global + c_thread_mtx_on_block.row;

                const index_t b_thread_data_on_global =
                    b_block_data_on_global + c_thread_mtx_on_block.col;

                ThreadwiseGenericTensorSliceCopy_v4r2<
                    decltype(out_k0_k1_k2_b_thread_desc),
                    decltype(out_k0_k1_k2_b_global_desc),
                    OutThreadCopySliceLengths,
                    arithmetic_sequence_gen<0, 4, 1>::type,
                    3,
                    1,
                    1,
                    AddressSpace::Vgpr,
                    is_same<AccFloat, CFloat>::value ? AddressSpace::Global : AddressSpace::Generic,
                    OutputMemOp>({0, 0, 0, 0},
                                 {k_thread_data_on_global / (K2 * K1),
                                  k_thread_data_on_global % (K2 * K1) / K2,
                                  k_thread_data_on_global % K2,
                                  b_thread_data_on_global})
                    .Run(p_c_thread + i * BlkSize, p_c_global);
            }
        }
    }
};

} // namespace ck
#endif
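The Run body above hides global-memory latency by ping-ponging between the two halves of p_a_block_double / p_b_block_double: while the blockwise GEMM consumes the "now" half, the next K-slice is staged through registers into the "next" half, and a tail step finishes the last one or two slices. A minimal sketch of just that schedule, with assumed sizes and scalar stand-ins in place of the kernel's copy and GEMM objects:

#include <cstdio>

int main()
{
    // Assumed sizes for illustration only; the real kernel works on LDS tiles, not ints.
    const int K = 32, KPerBlock = 4;

    int buffer[2] = {0, 0}; // stand-in for the two halves of the LDS double buffer
    buffer[0]     = 0;      // preload: the first K-slice goes into half 0

    int k_loaded = 0; // K-offset of the slice loaded most recently
    for(int k_block_data_begin = 0; k_block_data_begin + 2 * KPerBlock < K;
        k_block_data_begin += 2 * KPerBlock)
    {
        for(int iloop = 0; iloop < 2; ++iloop)
        {
            const bool even_loop = (iloop % 2 == 0);
            const int now        = even_loop ? 0 : 1; // half consumed by the GEMM this step
            const int next       = even_loop ? 1 : 0; // half refilled for the next step

            k_loaded += KPerBlock;
            buffer[next] = k_loaded; // "load next data" + "store next data to LDS"
            printf("GEMM on K-slice %d (buffer half %d)\n", buffer[now], now);
        }
    }

    // tail: the remaining one or two slices are consumed without issuing further loads
    printf("tail consumes %d remaining slice(s)\n", (K % (2 * KPerBlock) == 0) ? 2 : 1);
    return 0;
}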
composable_kernel/include/tensor_operation/xdlops_gemm.hpp  (+1 -1)

@@ -810,7 +810,7 @@ struct XdlopsGemm_t
                        (mfma_type.group_size * mfma_type.num_input_blks);
                    index_t bindex = blk_td;
                    p_c_thread[m + c_off] += inner_product_with_conversion<FloatC>{}(
                        p_a_wave[aindex + a_off], p_b_wave[bindex + b_off]);
                }
            }
        }
composable_kernel/include/utility/float_type.nvidia.hpp.in  (+15 -2)

@@ -18,8 +18,6 @@ typedef float float16_t __attribute__((ext_vector_type(16)));
 typedef float float32_t __attribute__((ext_vector_type(32)));

 // float16
-typedef float half4_t __attribute__((ext_vector_type(2)));
-typedef float half8_t __attribute__((ext_vector_type(4)));

 // bfloat16
 typedef ushort ushort2_t __attribute__((ext_vector_type(2)));
@@ -28,6 +26,7 @@ typedef ushort ushort8_t __attribute__((ext_vector_type(8)));
 // float16
 using half2_t = half2;
+using half4_t = float2;

 template <class T, index_t N>
 struct vector_type
@@ -164,6 +163,20 @@ struct inner_product_with_conversion
         return acc;
     }

+    __device__ T operator()(half4_t a, half4_t b) const
+    {
+        const half* p_a_half = reinterpret_cast<const half*>(&a);
+        const half* p_b_half = reinterpret_cast<const half*>(&b);
+
+        T acc = 0;
+
+        for(index_t v = 0; v < 4; ++v)
+        {
+            acc += convert(p_a_half[v]) * convert(p_b_half[v]);
+        }
+
+        return acc;
+    }
 };
 } // namespace ck
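The overload added to inner_product_with_conversion treats a half4_t (declared above as float2, i.e. 8 bytes holding four fp16 lanes) as four scalar halves, converts each lane, and accumulates the products. A self-contained sketch of the same unpack-and-accumulate pattern, assuming uint16_t as a stand-in for half and a placeholder convert_stub that is not a real fp16 conversion:

#include <cstdint>
#include <cstdio>
#include <cstring>

// Stand-ins: half4_t is declared as float2 in the header, i.e. two 32-bit floats
// acting as raw storage for four 16-bit half values.
struct float2_storage { float x, y; };
using half_bits = std::uint16_t; // raw 16-bit pattern standing in for __half

// Placeholder for the header's convert(...): widens the raw bits, NOT a real fp16->fp32 conversion.
float convert_stub(half_bits h) { return static_cast<float>(h); }

int main()
{
    static_assert(sizeof(float2_storage) == 4 * sizeof(half_bits),
                  "4 packed halves must fit exactly in the float2 storage type");

    float2_storage a{}, b{};
    half_bits a_vals[4] = {1, 2, 3, 4};
    half_bits b_vals[4] = {5, 6, 7, 8};
    std::memcpy(&a, a_vals, sizeof(a));
    std::memcpy(&b, b_vals, sizeof(b));

    // Same shape as the added operator(): unpack, convert each lane, accumulate.
    const half_bits* p_a = reinterpret_cast<const half_bits*>(&a);
    const half_bits* p_b = reinterpret_cast<const half_bits*>(&b);
    float acc = 0.f;
    for(int v = 0; v < 4; ++v)
        acc += convert_stub(p_a[v]) * convert_stub(p_b[v]);

    printf("acc = %f\n", acc); // 1*5 + 2*6 + 3*7 + 4*8 = 70
    return 0;
}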
driver/include/device_convolution_implicit_gemm_v4r4_xdlops_fp16_nchw_kcyx_nkhw.hpp  (new file, mode 100644)
#include <unistd.h>
#include "device.hpp"
#include "tensor.hpp"
#include "gridwise_convolution_kernel_wrapper.hpp"
#include "gridwise_convolution_implicit_gemm_v4r4_xdlops_fp16_nchw_kcyx_nkhw.hpp"

template <class T,
          class InDesc,
          class WeiDesc,
          class OutDesc,
          class ConvStrides,
          class ConvDilations,
          class InLeftPads,
          class InRightPads>
void device_convolution_implicit_gemm_v4r4_xdlops_fp16_nchw_kcyx_nkhw(InDesc,
                                                                       const Tensor<T>& in_nchw,
                                                                       WeiDesc,
                                                                       const Tensor<T>& wei_kcyx,
                                                                       OutDesc,
                                                                       Tensor<T>& out_nkhw,
                                                                       ConvStrides,
                                                                       ConvDilations,
                                                                       InLeftPads,
                                                                       InRightPads,
                                                                       ck::index_t nrepeat)
{
    using namespace ck;

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    constexpr auto in_nchw_desc =
        make_native_tensor_descriptor(InDesc::GetLengths(), InDesc::GetStrides());
    constexpr auto wei_kcyx_desc =
        make_native_tensor_descriptor(WeiDesc::GetLengths(), WeiDesc::GetStrides());
    constexpr auto out_nkhw_desc =
        make_native_tensor_descriptor(OutDesc::GetLengths(), OutDesc::GetStrides());

    constexpr index_t N  = out_nkhw_desc.GetLength(I0);
    constexpr index_t K  = out_nkhw_desc.GetLength(I1);
    constexpr index_t Ho = out_nkhw_desc.GetLength(I2);
    constexpr index_t Wo = out_nkhw_desc.GetLength(I3);

    std::size_t data_sz = sizeof(T);

    DeviceMem in_nchw_device_buf(data_sz * in_nchw.mDesc.GetElementSpace());
    DeviceMem wei_kcyx_device_buf(data_sz * wei_kcyx.mDesc.GetElementSpace());
    DeviceMem out_nkhw_device_buf(data_sz * out_nkhw.mDesc.GetElementSpace());

    in_nchw_device_buf.ToDevice(in_nchw.mData.data());
    wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data());
    out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());
    // cdata = 64, BlockSize = 256, 128x128x16
    constexpr index_t BlockSize = 128;

    constexpr index_t GemmMPerBlock = 128;
    constexpr index_t GemmNPerBlock = 64;
    constexpr index_t GemmKPerBlock = 4;
    constexpr index_t GemmKPACK     = 4;
    constexpr index_t GemmMPerWave  = 64;
    constexpr index_t GemmNPerWave  = 64;

    constexpr index_t ThreadGemmDataPerReadM = 1;
    constexpr index_t ThreadGemmDataPerReadN = 1;

    using GemmABlockCopyThreadSliceLengths_GemmK_GemmM_GemmKPACK   = Sequence<1, 4, 4>;
    using GemmABlockCopyThreadClusterLengths_GemmK_GemmM_GemmKPACK = Sequence<4, 32, 1>;

    constexpr index_t GemmABlockCopySrcDataPerRead_GemmKPACK  = 1;
    constexpr index_t GemmABlockCopyDstDataPerWrite_GemmKPACK = 1;

    using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN_GemmKPACK   = Sequence<1, 2, 4>;
    using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN_GemmKPACK = Sequence<4, 32, 1>;

    constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN      = 1;
    constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmKPACK = 1;

    constexpr index_t GemmM = K;
    constexpr index_t GemmN = N * Ho * Wo;

    constexpr index_t GridSize = math::integer_divide_ceil(GemmM, GemmMPerBlock) *
                                 math::integer_divide_ceil(GemmN, GemmNPerBlock);

    printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);

    constexpr auto gridwise_conv =
        GridwiseConvolutionImplicitGemm_v4r4_xdlops_fwd_fp16_nchw_kcyx_nkhw<
            GridSize,
            BlockSize,
            half,
            float,
            decltype(in_nchw_desc),
            decltype(wei_kcyx_desc),
            decltype(out_nkhw_desc),
            ConvStrides,
            ConvDilations,
            InLeftPads,
            InRightPads,
            GemmMPerBlock,
            GemmNPerBlock,
            GemmKPerBlock,
            GemmKPACK,
            GemmMPerWave,
            GemmNPerWave,
            ThreadGemmDataPerReadM,
            ThreadGemmDataPerReadN,
            GemmABlockCopyThreadSliceLengths_GemmK_GemmM_GemmKPACK,
            GemmABlockCopyThreadClusterLengths_GemmK_GemmM_GemmKPACK,
            GemmABlockCopySrcDataPerRead_GemmKPACK,
            GemmABlockCopyDstDataPerWrite_GemmKPACK,
            GemmBBlockCopyThreadSliceLengths_GemmK_GemmN_GemmKPACK,
            GemmBBlockCopyThreadClusterLengths_GemmK_GemmN_GemmKPACK,
            GemmBBlockCopySrcDataPerRead_GemmN,
            GemmBBlockCopyDstDataPerWrite_GemmKPACK>{};

    for(index_t i = 0; i < 10; ++i)
    {
        float time =
            launch_and_time_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
                                   dim3(GridSize),
                                   dim3(BlockSize),
                                   0,
                                   0,
                                   static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
                                   static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
                                   static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));

        printf("Elapsed time : %f ms, %f TFlop/s \n",
               time,
               (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
                   (std::size_t(1000) * 1000 * 1000) / time);
    }
    // warm up
    printf("Warm up running %d times...\n", nrepeat);
    for(index_t i = 0; i < nrepeat; ++i)
    {
        launch_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
                      dim3(GridSize),
                      dim3(BlockSize),
                      0,
                      0,
                      static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
                      static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
                      static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
    }

    printf("Start running %d times...\n", nrepeat);

    cudaDeviceSynchronize();
    auto start = std::chrono::steady_clock::now();

    for(index_t i = 0; i < nrepeat; ++i)
    {
        launch_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
                      dim3(GridSize),
                      dim3(BlockSize),
                      0,
                      0,
                      static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
                      static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
                      static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
    }

    cudaDeviceSynchronize();
    auto end = std::chrono::steady_clock::now();

    float ave_time = std::chrono::duration<float, std::milli>(end - start).count() / nrepeat;

    printf("Average elapsed time : %f ms, %f TFlop/s \n",
           ave_time,
           (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
               (std::size_t(1000) * 1000 * 1000) / ave_time);

    out_nkhw_device_buf.FromDevice(out_nkhw.mData.data());
}
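GridSize in the driver above is simply the number of GemmMPerBlock x GemmNPerBlock output tiles needed to cover the GemmM x GemmN GEMM. A quick host-side check of that arithmetic, assuming example problem sizes that are not taken from this commit:

#include <cstdio>

// Same rounding-up division the driver obtains from math::integer_divide_ceil.
int integer_divide_ceil(int a, int b) { return (a + b - 1) / b; }

int main()
{
    // Tile sizes hard-coded in the driver above; N/K/Ho/Wo are example values only.
    const int GemmMPerBlock = 128, GemmNPerBlock = 64;
    const int N = 128, K = 256, Ho = 26, Wo = 26;

    const int GemmM = K;
    const int GemmN = N * Ho * Wo;

    const int GridSize = integer_divide_ceil(GemmM, GemmMPerBlock) *
                         integer_divide_ceil(GemmN, GemmNPerBlock);

    // 256/128 = 2 tiles along GemmM, 86528/64 = 1352 tiles along GemmN -> 2704 workgroups
    printf("GridSize = %d\n", GridSize);
    return 0;
}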
driver/src/conv_driver.cpp  (+4 -3)

@@ -21,6 +21,7 @@
 #include "device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp"
 #include "device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
 #include "device_convolution_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp"
+#include "device_convolution_implicit_gemm_v4r4_xdlops_fp16_nchw_kcyx_nkhw.hpp"

 int main(int argc, char* argv[])
 {
@@ -524,8 +525,8 @@ int main(int argc, char* argv[])
     print_sequence("ConvStrides", ConvStrides{});
     print_sequence("ConvDilations", ConvDilations{});

-    using in_data_t  = float;
-    using out_data_t = float;
+    using in_data_t  = half;
+    using out_data_t = half;

     Tensor<in_data_t> in_nchw(make_TensorDescriptor(in_nchw_desc));
     Tensor<in_data_t> wei_kcyx(make_TensorDescriptor(wei_kcyx_desc));
     Tensor<out_data_t> out_nkhw_host(make_TensorDescriptor(out_nkhw_desc));
@@ -616,7 +617,7 @@ int main(int argc, char* argv[])
                                                               RightPads{},
                                                               nrepeat);
 #elif 1
-    device_convolution_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw(in_nchw_desc,
+    device_convolution_implicit_gemm_v4r4_xdlops_fp16_nchw_kcyx_nkhw(in_nchw_desc,
                                                                 in_nchw,
                                                                 wei_kcyx_desc,
                                                                 wei_kcyx,