gaoqiong / composable_kernel · Commits

Commit 2185affb
Authored Jul 26, 2019 by Tejash Shah
Parent: c15ff3c8

Added fp16 support in implicit gemm

Showing 19 changed files with 1780 additions and 355 deletions:
  composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4_fp16_bfp16_nchw_kcyx_nkhw_lds_double_buffer.hpp  (+418, -0)
  composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4_nchw_kc1x1_nkhw_lds_double_buffer.hpp  (+461, -0)
  composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4_nchw_kcyx_nkhw_lds_double_buffer.hpp  (+10, -23)
  composable_kernel/include/tensor_description/ConstantMergedTensorDescriptor.hpp  (+4, -4)
  composable_kernel/include/tensor_description/ConstantTensorDescriptor.hpp  (+13, -15)
  composable_kernel/include/tensor_operation/blockwise_gemm.hpp  (+72, -35)
  composable_kernel/include/tensor_operation/blockwise_generic_tensor_slice_copy.hpp  (+57, -60)
  composable_kernel/include/tensor_operation/threadwise_gemm.hpp  (+75, -11)
  composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy.hpp  (+26, -8)
  composable_kernel/include/utility/Sequence.hpp  (+17, -18)
  composable_kernel/include/utility/amd_inline_asm.hpp  (+52, -0)
  composable_kernel/include/utility/bfloat16_dev.hpp  (+125, -0)
  composable_kernel/include/utility/common_header.hpp  (+6, -0)
  composable_kernel/include/utility/float_types.h  (+111, -0)
  composable_kernel/include/utility/integral_constant.hpp  (+14, -48)
  composable_kernel/include/utility/math.hpp  (+10, -6)
  composable_kernel/include/utility/vector_type.hpp  (+119, -8)
  driver/include/device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw.hpp  (+181, -110)
  driver/src/driver.cpp  (+9, -9)
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4_fp16_bfp16_nchw_kcyx_nkhw_lds_double_buffer.hpp (new file, 0 → 100644)
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4_FP16_BFP16_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4_FP16_BFP16_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER
#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "ConstantMergedTensorDescriptor.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_generic_tensor_slice_copy.hpp"
#include "blockwise_gemm.hpp"
#include "threadwise_generic_tensor_slice_copy.hpp"

namespace ck {

// define B = merge(N0, Ho, Wo)
template <index_t GridSize,
          index_t BlockSize,
          class Float,
          class AccDataType,
          class InGlobalDesc,
          class WeiGlobalDesc,
          class OutGlobalDesc,
          class ConvStrides,
          class ConvDilations,
          index_t BPerBlock,
          index_t KPerBlock,
          index_t EPerBlock,
          index_t N1,
          index_t N2,
          index_t ES,
          index_t GemmMPerThreadSubC,
          index_t GemmNPerThreadSubC,
          index_t GemmMLevel0Cluster,
          index_t GemmNLevel0Cluster,
          index_t GemmMLevel1Cluster,
          index_t GemmNLevel1Cluster,
          index_t GemmKPerThreadLoop,
          index_t GemmDataPerReadA,
          index_t GemmDataPerReadB,
          class InBlockCopySubLengths_E_N1_B_N2_ES,
          class InBlockCopyClusterLengths_E_N1_B_N2_ES,
          class InBlockCopyThreadClusterArrangeOrder,
          class InBlockCopySrcAccessOrder,
          class InBlockCopyDstAccessOrder,
          index_t InBlockCopySrcDataPerRead_B,
          index_t InBlockCopyDstDataPerWrite_N2,
          class WeiBlockCopySubLengths_E_K_ES,
          class WeiBlockCopyClusterLengths_E_K_ES,
          class WeiBlockCopyThreadClusterArrangeOrder,
          class WeiBlockCopySrcAccessOrder,
          class WeiBlockCopyDstAccessOrder,
          index_t WeiBlockCopySrcDataPerRead_E,
          index_t WeiBlockCopyDstDataPerWrite_K>
struct GridwiseConvolutionImplicitGemm_v4_fp16_bfp16_nchw_kcyx_nkhw_lds_double_buffer
{
    __device__ void Run(const Float* const __restrict__ p_in_global,
                        const Float* const __restrict__ p_wei_global,
                        Float* const __restrict__ p_out_global) const
    {
        // this is a mess
        // TODO: find a more elegant way of specifying (or calculating) performance parameters
        static_assert(N2 == GemmNPerThreadSubC, "wrong!");
        static_assert((N1 * N2 * BPerBlock) %
                              (GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster) ==
                          0,
                      "wrong!");

        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};
        constexpr auto I4 = Number<4>{};
        constexpr auto I5 = Number<5>{};

        constexpr auto True = integral_constant<bool, true>{};

        constexpr auto in_n_c_h_w_global_desc  = InGlobalDesc{};
        constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{};
        constexpr auto out_n_k_h_w_global_desc = OutGlobalDesc{};

        constexpr index_t N = in_n_c_h_w_global_desc.GetLength(I0);
        constexpr index_t C = in_n_c_h_w_global_desc.GetLength(I1);

        constexpr index_t K  = out_n_k_h_w_global_desc.GetLength(I1);
        constexpr index_t Ho = out_n_k_h_w_global_desc.GetLength(I2);
        constexpr index_t Wo = out_n_k_h_w_global_desc.GetLength(I3);

        constexpr index_t Y = wei_k_c_y_x_global_desc.GetLength(I2);
        constexpr index_t X = wei_k_c_y_x_global_desc.GetLength(I3);

        static_assert(N % (N1 * N2) == 0, "wrong! cannot divide N evenly among thread");

        constexpr index_t N0 = N / (N1 * N2);

        constexpr index_t B = N0 * Ho * Wo;

        // ES=1 for float32, =2 for bfloat16, =4 for float16
        static_assert(C % ES == 0, "C needs to be multiple of vectorized C (ES)");
        constexpr auto nonVectorizedC = C / ES;
        constexpr index_t E           = nonVectorizedC * Y * X;

        // divide block work by [K, B]
        static_assert(K % KPerBlock == 0 && B % BPerBlock == 0 && E % (2 * EPerBlock) == 0,
                      "wrong! cannot divide work evenly among block");

        constexpr index_t KBlockWork = K / KPerBlock;
        constexpr index_t BBlockWork = B / BPerBlock;

        constexpr auto block_work_desc =
            make_ConstantTensorDescriptor_packed(Sequence<KBlockWork, BBlockWork>{});

        const auto block_work_multi_id =
            block_work_desc.GetMultiIndexFrom1dIndex(get_block_1d_id());

        const index_t k_block_data_on_global = block_work_multi_id[0] * KPerBlock;
        const index_t b_block_data_on_global = block_work_multi_id[1] * BPerBlock;

        // input tensor
        //   tensor descriptor in device memory [N0, N1, N2, Ho, Wo, {2C/4C}]
        constexpr auto in_n0_n1_n2_h_w_2cor4c_global_desc =
            in_n_c_h_w_global_desc.StridedSlice(I2, Number<Ho>{}, Number<ConvStrides::Get(I0)>{})
                .StridedSlice(I3, Number<Wo>{}, Number<ConvStrides::Get(I1)>{})
                .Fold(I1, Number<nonVectorizedC>{})
                .Fold(I0, Number<N1>{}, Number<N2>{})
                .Extract(Sequence<0, 1, 2, 3, 5, 6>{})
                .ReorderGivenNew2Old(Sequence<0, 1, 2, 4, 5, 3>{});

        //   batch descriptor for device memory
        constexpr auto in_c_y_x_global_desc =
            in_n_c_h_w_global_desc.StridedSlice(I2, Number<Y>{}, Number<ConvDilations::Get(I0)>{})
                .StridedSlice(I3, Number<X>{}, Number<ConvDilations::Get(I1)>{})
                .Fold(I1, Number<nonVectorizedC>{})
                .Extract(Sequence<2, 3, 4>{});

        //   merged tensor descriptor in device memory [E, N1, B, N2, {2E/4E}], src of blockwise
        //   copy
        constexpr auto in_e_n1_b_n2_2eor4e_global_merged_desc = make_ConstantMergedTensorDescriptor(
            in_c_y_x_global_desc.Embed(in_n0_n1_n2_h_w_2cor4c_global_desc),
            Sequence<0, 1, 2>{},
            Sequence<4>{},
            Sequence<3, 6, 7>{},
            Sequence<5>{},
            Sequence<8>{});

        //   memory layout descriptor in LDS [E, N1, B, N2, {2C/4C}], dst of blockwise copy
        //   be careful of LDS alignment
        constexpr auto in_e_n1_b_n2_2eor4e_block_desc = make_ConstantTensorDescriptor_aligned(
            Sequence<EPerBlock, N1, BPerBlock, N2, ES>{}, Number<InBlockCopyDstDataPerWrite_N2>{});

        // this check is ad-hoc
        // TODO: need to properly implement tensor descriptor with multiple alignment
        // requirements
        static_assert(in_e_n1_b_n2_2eor4e_block_desc.GetStride(I1) % GemmDataPerReadB == 0,
                      "GemmDataPerReadB alignment requirement is not satisfied");

        // input blockwise copy
        //   slice a merged tensor, reorder and copy to a normal tensor
        //   this copy operator already has blockwise offset built-in
        auto blockwise_in_copy = BlockwiseGenericTensorSliceCopy_v1<
            BlockSize,
            Float,
            decltype(in_e_n1_b_n2_2eor4e_global_merged_desc),
            decltype(in_e_n1_b_n2_2eor4e_block_desc),
            decltype(in_e_n1_b_n2_2eor4e_block_desc.GetLengths()),
            InBlockCopySubLengths_E_N1_B_N2_ES,
            InBlockCopyClusterLengths_E_N1_B_N2_ES,
            InBlockCopyThreadClusterArrangeOrder,
            InBlockCopySrcAccessOrder,
            InBlockCopyDstAccessOrder,
            InBlockCopySrcDataPerRead_B,
            InBlockCopyDstDataPerWrite_N2>({0, 0, b_block_data_on_global, 0, 0},
                                           {0, 0, 0, 0, 0});

        // weight tensor
        //   tensor descriptor in device memory, src of blockwise copy
        constexpr auto wei_e_k_2eor4e_global_desc =
            wei_k_c_y_x_global_desc.Fold(I1, Number<nonVectorizedC>{})
                .Unfold(I2, I4)
                .ReorderGivenNew2Old(Sequence<2, 0, 1>{});

        //   tensor descriptor in LDS, dst of blockwise copy
        //   be careful of LDS alignment
        constexpr auto wei_e_k_2eor4e_block_desc = make_ConstantTensorDescriptor_aligned(
            Sequence<EPerBlock, KPerBlock, ES>{},
            Number<math::lcm(WeiBlockCopyDstDataPerWrite_K, GemmDataPerReadA)>{});

        // operator for blockwise copy of weight into LDS
        //   slice a tensor, and copy it into another tensor
        //   this copy operator already has blockwise offset built-in
        auto blockwise_wei_copy = BlockwiseGenericTensorSliceCopy_v1<
            BlockSize,
            Float,
            decltype(wei_e_k_2eor4e_global_desc),
            decltype(wei_e_k_2eor4e_block_desc),
            decltype(wei_e_k_2eor4e_block_desc.GetLengths()),
            WeiBlockCopySubLengths_E_K_ES,
            WeiBlockCopyClusterLengths_E_K_ES,
            WeiBlockCopyThreadClusterArrangeOrder,
            WeiBlockCopySrcAccessOrder,
            WeiBlockCopyDstAccessOrder,
            WeiBlockCopySrcDataPerRead_E,
            WeiBlockCopyDstDataPerWrite_K>({0, k_block_data_on_global, 0}, {0, 0, 0});

        // GEMM definition
        //   c_mtx += transpose(a_mtx) * b_mtx
        //     a_mtx[EPerBlock, KPerBlock] is in LDS of type float / bfloat16 vec2 / float16 vec4
        //     b_mtx[EPerBlock, N1 * BPerBlock * N2] is in LDS of type float / bfloat16 vec2 /
        //       float16 vec4
        //     c_mtx[KPerBlock, N1 * BPerBlock * N2] is distributed among threads, and saved in
        //       register
        constexpr auto a_e_k_block_mtx_desc =
            make_ConstantMatrixDescriptor(Number<EPerBlock>{}, Number<KPerBlock>{});

        constexpr auto b_e_n1bn2_block_mtx_desc =
            make_ConstantMatrixDescriptor(Number<EPerBlock>{}, Number<N1 * BPerBlock * N2>{});

        // sanity check
        static_assert(
            KPerBlock % (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster) == 0,
            "wrong!");

        constexpr index_t GemmMRepeat =
            KPerBlock / (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster);

        // c_thread_mtx definition: this is a mess
        // TODO: find a more elegant way of defining c_thread_mtx
        constexpr auto c_k0k2_n1n2_thread_mtx_desc = make_ConstantMatrixDescriptor(
            Number<GemmMRepeat * GemmMPerThreadSubC>{}, Number<N1 * N2>{});

        const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
            BlockSize,
            decltype(a_e_k_block_mtx_desc),
            decltype(b_e_n1bn2_block_mtx_desc),
            decltype(c_k0k2_n1n2_thread_mtx_desc),
            GemmMPerThreadSubC,
            GemmNPerThreadSubC,
            GemmMLevel0Cluster,
            GemmNLevel0Cluster,
            GemmMLevel1Cluster,
            GemmNLevel1Cluster,
            GemmKPerThreadLoop,
            GemmDataPerReadA,
            GemmDataPerReadB>{};

        // LDS allocation for input and weight: be careful of alignment
        constexpr index_t max_align = math::lcm(InBlockCopyDstDataPerWrite_N2,
                                                WeiBlockCopyDstDataPerWrite_K,
                                                GemmDataPerReadA,
                                                GemmDataPerReadB);

        constexpr index_t in_block_space = math::integer_least_multiple(
            in_e_n1_b_n2_2eor4e_block_desc.GetElementSpace(), max_align);

        constexpr index_t wei_block_space =
            math::integer_least_multiple(wei_e_k_2eor4e_block_desc.GetElementSpace(), max_align);

        __shared__ Float p_in_block_double[2 * in_block_space];
        __shared__ Float p_wei_block_double[2 * wei_block_space];

        // register allocation for output
        AccDataType p_out_thread[c_k0k2_n1n2_thread_mtx_desc.GetElementSpace()];

        // zero out threadwise output
        threadwise_matrix_set_zero(c_k0k2_n1n2_thread_mtx_desc, p_out_thread);

        const Float* p_wei_block_on_global = p_wei_global;

        // LDS double buffer: preload data into LDS
        {
            blockwise_in_copy.Run(p_in_global, p_in_block_double);
            blockwise_wei_copy.Run(p_wei_global, p_wei_block_double);
        }

        // LDS double buffer: main body
        for(index_t e_block_data_begin = 0; e_block_data_begin + 2 * EPerBlock < E;
            e_block_data_begin += 2 * EPerBlock)
        {
            // hcc compilation error: loop not unrolled: the optimizer was unable to perform the
            // requested transformation; the transformation might be disabled or specified as part
            // of an unsupported transformation ordering [-Werror,-Wpass-failed=transform-warning]
            // #pragma unroll
            for(index_t iloop = 0; iloop < 2; ++iloop)
            {
                const bool even_loop = (iloop % 2 == 0);

                Float* p_in_block_now =
                    even_loop ? p_in_block_double : p_in_block_double + in_block_space;
                Float* p_wei_block_now =
                    even_loop ? p_wei_block_double : p_wei_block_double + wei_block_space;

                Float* p_in_block_next =
                    even_loop ? p_in_block_double + in_block_space : p_in_block_double;
                Float* p_wei_block_next =
                    even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double;

                Float p_in_register_clipboard[blockwise_in_copy.GetRegisterClipboardSize()];
                Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];

                blockwise_in_copy.MoveSlicingWindowOnSourceTensor(I0, Number<EPerBlock>{}, True);
                p_wei_block_on_global += EPerBlock * wei_e_k_2eor4e_global_desc.GetStride(I0);

                __syncthreads();

                // LDS double buffer: load next data from device mem
                blockwise_in_copy.RunLoadRegisterClipboard(p_in_global, p_in_register_clipboard);
                blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_block_on_global,
                                                            p_wei_register_clipboard);

                // LDS double buffer: GEMM on current data
                blockwise_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread);

                // LDS double buffer: store next data to LDS
                blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard,
                                                            p_in_block_next);
                blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard,
                                                             p_wei_block_next);
            }
        }

        // LDS double buffer: tail
        {
            Float p_in_register_clipboard[blockwise_in_copy.GetRegisterClipboardSize()];
            Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];

            // even iteration
            blockwise_in_copy.MoveSlicingWindowOnSourceTensor(I0, Number<EPerBlock>{}, True);
            p_wei_block_on_global += EPerBlock * wei_e_k_2eor4e_global_desc.GetStride(I0);

            __syncthreads();

            // LDS double buffer: load next data from device mem
            blockwise_in_copy.RunLoadRegisterClipboard(p_in_global, p_in_register_clipboard);
            blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_block_on_global,
                                                        p_wei_register_clipboard);

            // LDS double buffer: GEMM on current data
            blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);

            // LDS double buffer: store next data to LDS
            blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard,
                                                        p_in_block_double + in_block_space);
            blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard,
                                                         p_wei_block_double + wei_block_space);

            // odd iteration
            __syncthreads();

            // LDS double buffer: GEMM on current data
            blockwise_gemm.Run(p_wei_block_double + wei_block_space,
                               p_in_block_double + in_block_space,
                               p_out_thread);
        }

        // copy output: register to global memory
        {
            constexpr index_t K2 = GemmMPerThreadSubC;
            constexpr index_t K1 = GemmMLevel0Cluster * GemmMLevel1Cluster;

            // define tensor descriptor for threadwise copy
            //   output memory layout descriptor in register
            constexpr auto out_k0_k1_k2_n1_n0_h_w_n2_thread_mem_desc =
                make_ConstantTensorDescriptor_packed(
                    Sequence<KPerBlock / (K1 * K2), 1, K2, N1, 1, 1, 1, N2>{});

            //   output tensor descriptor in register, src of threadwise copy
            constexpr auto out_n0_n1_n2_k0_k1_k2_h_w_thread_desc =
                out_k0_k1_k2_n1_n0_h_w_n2_thread_mem_desc.ReorderGivenNew2Old(
                    Sequence<4, 3, 7, 0, 1, 2, 5, 6>{});

            //   output memory layout descriptor in device memory, dst of threadwise copy
            constexpr auto out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc =
                out_n_k_h_w_global_desc.Fold(I1, Number<K1>{}, Number<K2>{})
                    .Fold(I0, Number<N1>{}, Number<N2>{});

            // calculate origin of thread output tensor on global memory
            //   blockwise GEMM c matrix starting index
            const auto c_thread_mtx_on_block =
                blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());

            const index_t k_thread_data_on_global =
                k_block_data_on_global + c_thread_mtx_on_block.row;

            const index_t b_thread_data_on_global =
                b_block_data_on_global + c_thread_mtx_on_block.col / N2;

            //   output merged global tensor descriptor, for calculating origin of thread tensor
            //   in global memory
            constexpr auto out_k_n1_b_n2_global_merged_desc = make_ConstantMergedTensorDescriptor(
                out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc.Unfold(I3, I5),
                Sequence<3>{},
                Sequence<1>{},
                Sequence<0, 4, 5>{},
                Sequence<2>{});

            //   origin of dst in device memory
            Float* p_out_thread_on_global =
                p_out_global +
                out_k_n1_b_n2_global_merged_desc.GetOffsetFromMultiIndex(
                    k_thread_data_on_global, 0, b_thread_data_on_global, 0);

            threadwise_generic_tensor_slice_copy_v1(
                out_n0_n1_n2_k0_k1_k2_h_w_thread_desc,
                p_out_thread,
                {0, 0, 0, 0, 0, 0, 0, 0},
                out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc,
                p_out_thread_on_global,
                {0, 0, 0, 0, 0, 0, 0, 0},
                out_n0_n1_n2_k0_k1_k2_h_w_thread_desc.GetLengths(),
                arithmetic_sequence_gen<0, 8, 1>::type{},
                Number<1>{});
        }
    }
};

} // namespace ck
#endif
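
As a side note on the ES bookkeeping in this new kernel (not part of the commit itself): C must be divisible by the vector width ES (1 for float32, 2 for bfloat16, 4 for float16), and the blockwise copies then see a reduced GEMM-K extent E = (C / ES) * Y * X, with the packed sub-channels carried in the trailing {2C/4C} dimension of the descriptors. A minimal host-side sketch with made-up example shapes:

    #include <cassert>
    #include <cstdio>

    int main()
    {
        // hypothetical example shapes, not taken from the commit
        const int C = 64, Y = 3, X = 3; // input channels, filter height/width
        const int ES = 4;               // fp16 case: 4 halves packed per "vectorized C" element

        assert(C % ES == 0);            // mirrors "C needs to be multiple of vectorized C (ES)"
        const int nonVectorizedC = C / ES;
        const int E = nonVectorizedC * Y * X; // GEMM-K extent seen by the blockwise copies

        std::printf("nonVectorizedC = %d, E = %d\n", nonVectorizedC, E); // prints 16, 144
        return 0;
    }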
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4_nchw_kc1x1_nkhw_lds_double_buffer.hpp (new file, 0 → 100644)
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4_NCHW_KC1x1_NKHW_LDS_DOUBLE_BUFFER
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4_NCHW_KC1x1_NKHW_LDS_DOUBLE_BUFFER
#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "ConstantMergedTensorDescriptor.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_generic_tensor_slice_copy.hpp"
#include "blockwise_gemm.hpp"
#include "threadwise_generic_tensor_slice_copy.hpp"

namespace ck {

template <bool isForw, class DescForw, class DescBack>
struct GetDesc
{
    typename std::conditional<isForw, DescForw, DescBack>::type Desc;
};

// define B = merge(N0, Ho, Wo)
template <index_t GridSize,
          index_t BlockSize,
          class Float,
          class AccDataType,
          class InGlobalDesc,
          class WeiGlobalDesc,
          class OutGlobalDesc,
          class ConvStrides,
          index_t Direction,
          index_t BPerBlock,
          index_t KPerBlock,
          index_t EPerBlock,
          index_t N1,
          index_t N2,
          index_t GemmMPerThreadSubC,
          index_t GemmNPerThreadSubC,
          index_t GemmMLevel0Cluster,
          index_t GemmNLevel0Cluster,
          index_t GemmMLevel1Cluster,
          index_t GemmNLevel1Cluster,
          index_t GemmKPerThreadLoop,
          index_t GemmDataPerReadA,
          index_t GemmDataPerReadB,
          class InBlockCopySubLengths_E_N1_B_N2,
          class InBlockCopyClusterLengths_E_N1_B_N2,
          class InBlockCopyThreadClusterArrangeOrder,
          class InBlockCopySrcAccessOrder,
          class InBlockCopyDstAccessOrder,
          index_t InBlockCopySrcDataPerRead_B,
          index_t InBlockCopyDstDataPerWrite_N2,
          class WeiBlockCopySubLengths_E_K,
          class WeiBlockCopyClusterLengths_E_K,
          class WeiBlockCopyThreadClusterArrangeOrder,
          class WeiBlockCopySrcAccessOrder,
          class WeiBlockCopyDstAccessOrder,
          index_t WeiBlockCopySrcDataPerRead_E,
          index_t WeiBlockCopyDstDataPerWrite_K>
struct GridwiseConvolutionImplicitGemm_v4_nchw_kc1x1_nkhw_lds_double_buffer
{
    __device__ void Run(const Float* const __restrict__ p_in_global,
                        const Float* const __restrict__ p_wei_global,
                        Float* const __restrict__ p_out_global) const
    {
        // this is a mess
        // TODO: find a more elegant way of specifying (or calculating) performance parameters
        static_assert(N2 == GemmNPerThreadSubC, "wrong!");
        static_assert((N1 * N2 * BPerBlock) %
                              (GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster) ==
                          0,
                      "wrong!");

        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};
        constexpr auto I4 = Number<4>{};
        constexpr auto I5 = Number<5>{};
        constexpr auto I6 = Number<6>{};
        constexpr auto I7 = Number<7>{};

        constexpr auto True = integral_constant<bool, true>{};

        constexpr auto in_n_c_h_w_global_desc  = InGlobalDesc{};
        constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{};
        constexpr auto out_n_k_h_w_global_desc = OutGlobalDesc{};

        constexpr index_t N = in_n_c_h_w_global_desc.GetLength(I0);
        constexpr index_t C = in_n_c_h_w_global_desc.GetLength(I1);

        constexpr index_t K  = out_n_k_h_w_global_desc.GetLength(I1);
        constexpr index_t Ho = out_n_k_h_w_global_desc.GetLength(I2);
        constexpr index_t Wo = out_n_k_h_w_global_desc.GetLength(I3);

        constexpr index_t Y = 1;
        constexpr index_t X = 1;

        static_assert(N % (N1 * N2) == 0, "wrong! cannot divide N evenly among thread");

        constexpr index_t N0 = N / (N1 * N2);

        constexpr index_t B = N0 * Ho * Wo;

        constexpr index_t E = C * Y * X;

        // divide block work by [K, B]
        static_assert(K % KPerBlock == 0 && B % BPerBlock == 0 && E % (2 * EPerBlock) == 0,
                      "wrong! cannot divide work evenly among block");

        constexpr index_t KBlockWork = K / KPerBlock;
        constexpr index_t BBlockWork = B / BPerBlock;

        constexpr auto block_work_desc =
            make_ConstantTensorDescriptor_packed(Sequence<KBlockWork, BBlockWork>{});

        const auto block_work_multi_id =
            block_work_desc.GetMultiIndexFrom1dIndex(get_block_1d_id());

        const index_t k_block_data_on_global = block_work_multi_id[0] * KPerBlock;
        const index_t b_block_data_on_global = block_work_multi_id[1] * BPerBlock;

        // input tensor
        //   tensor descriptor in device memory [N0, N1, N2, Ho, Wo]
        constexpr bool isForward = Direction == 1;

        constexpr auto in_n0_n1_n2_h_w_global_desc_forw =
            in_n_c_h_w_global_desc.StridedSlice(I2, Number<Ho>{}, Number<ConvStrides::Get(I0)>{})
                .StridedSlice(I3, Number<Wo>{}, Number<ConvStrides::Get(I1)>{})
                .Fold(I0, Number<N1>{}, Number<N2>{})
                .Extract(Sequence<0, 1, 2, 4, 5>{});

        constexpr auto in_n0_n1_n2_h_w_global_desc_back =
            in_n_c_h_w_global_desc.Fold(I0, Number<N1>{}, Number<N2>{})
                .Extract(Sequence<0, 1, 2, 4, 5>{});

        constexpr auto in_n0_n1_n2_h_w_global_desc =
            GetDesc<isForward,
                    decltype(in_n0_n1_n2_h_w_global_desc_forw),
                    decltype(in_n0_n1_n2_h_w_global_desc_back)>{}
                .Desc;

        //   batch descriptor for device memory
        constexpr auto in_c_y_x_global_desc = in_n_c_h_w_global_desc.Slice(I2, Number<1>{})
                                                  .Slice(I3, Number<X>{})
                                                  .Extract(Sequence<1, 2, 3>{});

        //   merged tensor descriptor in device memory [E, N1, B, N2], src of blockwise copy
        constexpr auto in_e_n1_b_n2_global_merged_desc = make_ConstantMergedTensorDescriptor(
            in_c_y_x_global_desc.Embed(in_n0_n1_n2_h_w_global_desc),
            Sequence<0, 1, 2>{},
            Sequence<4>{},
            Sequence<3, 6, 7>{},
            Sequence<5>{});

        //   memory layout descriptor in LDS [E, N1, B, N2], dst of blockwise copy
        //   be careful of LDS alignment
        constexpr auto in_e_n1_b_n2_block_desc = make_ConstantTensorDescriptor_aligned(
            Sequence<EPerBlock, N1, BPerBlock, N2>{}, Number<InBlockCopyDstDataPerWrite_N2>{});

        // this check is ad-hoc
        // TODO: need to properly implement tensor descriptor with multiple alignment
        // requirements
        static_assert(in_e_n1_b_n2_block_desc.GetStride(I1) % GemmDataPerReadB == 0,
                      "GemmDataPerReadB alignment requirement is not satisfied");

        // input blockwise copy
        //   slice a merged tensor, reorder and copy to a normal tensor
        //   this copy operator already has blockwise offset built-in
        auto blockwise_in_copy = BlockwiseGenericTensorSliceCopy_v1<
            BlockSize,
            Float,
            decltype(in_e_n1_b_n2_global_merged_desc),
            decltype(in_e_n1_b_n2_block_desc),
            decltype(in_e_n1_b_n2_block_desc.GetLengths()),
            InBlockCopySubLengths_E_N1_B_N2,
            InBlockCopyClusterLengths_E_N1_B_N2,
            InBlockCopyThreadClusterArrangeOrder,
            InBlockCopySrcAccessOrder,
            InBlockCopyDstAccessOrder,
            InBlockCopySrcDataPerRead_B,
            InBlockCopyDstDataPerWrite_N2>({0, 0, b_block_data_on_global, 0}, {0, 0, 0, 0});

        // weight tensor
        //   tensor descriptor in device memory, src of blockwise copy
#if 1
        constexpr auto wei_e_k_global_desc =
            wei_k_c_y_x_global_desc.ReorderGivenNew2Old(Sequence<1, 0>{});
#else
        constexpr auto wei_e_k_global_desc = wei_k_c_y_x_global_desc;
#endif

        //   tensor descriptor in LDS, dst of blockwise copy
        //   be careful of LDS alignment
        constexpr auto wei_e_k_block_desc = make_ConstantTensorDescriptor_aligned(
            Sequence<EPerBlock, KPerBlock>{},
            Number<math::lcm(WeiBlockCopyDstDataPerWrite_K, GemmDataPerReadA)>{});

        // operator for blockwise copy of weight into LDS
        //   slice a tensor, and copy it into another tensor
        //   this copy operator already has blockwise offset built-in
        auto blockwise_wei_copy = BlockwiseGenericTensorSliceCopy_v1<
            BlockSize,
            Float,
            decltype(wei_e_k_global_desc),
            decltype(wei_e_k_block_desc),
            decltype(wei_e_k_block_desc.GetLengths()),
            WeiBlockCopySubLengths_E_K,
            WeiBlockCopyClusterLengths_E_K,
            WeiBlockCopyThreadClusterArrangeOrder,
            WeiBlockCopySrcAccessOrder,
            WeiBlockCopyDstAccessOrder,
            WeiBlockCopySrcDataPerRead_E,
            WeiBlockCopyDstDataPerWrite_K>({0, k_block_data_on_global}, {0, 0});

        // GEMM definition
        //   c_mtx += transpose(a_mtx) * b_mtx
        //     a_mtx[EPerBlock, KPerBlock] is in LDS
        //     b_mtx[EPerBlock, N1 * BPerBlock * N2] is in LDS
        //     c_mtx[KPerBlock, N1 * BPerBlock * N2] is distributed among threads, and saved in
        //       register
        constexpr auto a_e_k_block_mtx_desc =
            make_ConstantMatrixDescriptor(Number<EPerBlock>{},
                                          Number<KPerBlock>{},
                                          Number<wei_e_k_block_desc.GetStride(I0)>{});

        constexpr auto b_e_n1bn2_block_mtx_desc =
            make_ConstantMatrixDescriptor(Number<EPerBlock>{},
                                          Number<N1 * BPerBlock * N2>{},
                                          Number<in_e_n1_b_n2_block_desc.GetStride(I0)>{});

        // sanity check
        static_assert(
            KPerBlock % (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster) == 0,
            "wrong!");

        constexpr index_t GemmMRepeat =
            KPerBlock / (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster);

        // c_thread_mtx definition: this is a mess
        // TODO: find a more elegant way of defining c_thread_mtx
        constexpr auto c_k0k2_n1n2_thread_mtx_desc = make_ConstantMatrixDescriptor(
            Number<GemmMRepeat * GemmMPerThreadSubC>{}, Number<N1 * N2>{});

        const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
            BlockSize,
            decltype(a_e_k_block_mtx_desc),
            decltype(b_e_n1bn2_block_mtx_desc),
            decltype(c_k0k2_n1n2_thread_mtx_desc),
            GemmMPerThreadSubC,
            GemmNPerThreadSubC,
            GemmMLevel0Cluster,
            GemmNLevel0Cluster,
            GemmMLevel1Cluster,
            GemmNLevel1Cluster,
            GemmKPerThreadLoop,
            GemmDataPerReadA,
            GemmDataPerReadB>{};

        // LDS allocation for input and weight: be careful of alignment
        constexpr index_t max_align = math::lcm(InBlockCopyDstDataPerWrite_N2,
                                                WeiBlockCopyDstDataPerWrite_K,
                                                GemmDataPerReadA,
                                                GemmDataPerReadB);

        constexpr index_t in_block_space =
            math::integer_least_multiple(in_e_n1_b_n2_block_desc.GetElementSpace(), max_align);

        constexpr index_t wei_block_space =
            math::integer_least_multiple(wei_e_k_block_desc.GetElementSpace(), max_align);

        __shared__ Float p_in_block_double[2 * in_block_space];
        __shared__ Float p_wei_block_double[2 * wei_block_space];

        // register allocation for output
        AccDataType p_out_thread[c_k0k2_n1n2_thread_mtx_desc.GetElementSpace()];

        // zero out threadwise output
        threadwise_matrix_set_zero(c_k0k2_n1n2_thread_mtx_desc, p_out_thread);

        const Float* p_wei_block_on_global = p_wei_global;

        // LDS double buffer: preload data into LDS
        {
            blockwise_in_copy.Run(p_in_global, p_in_block_double);
            blockwise_wei_copy.Run(p_wei_global, p_wei_block_double);
        }

        // LDS double buffer: main body
        for(index_t e_block_data_begin = 0; e_block_data_begin + 2 * EPerBlock < E;
            e_block_data_begin += 2 * EPerBlock)
        {
#pragma unroll
            for(index_t iloop = 0; iloop < 2; ++iloop)
            {
                const bool even_loop = (iloop % 2 == 0);

                Float* p_in_block_now =
                    even_loop ? p_in_block_double : p_in_block_double + in_block_space;
                Float* p_wei_block_now =
                    even_loop ? p_wei_block_double : p_wei_block_double + wei_block_space;

                Float* p_in_block_next =
                    even_loop ? p_in_block_double + in_block_space : p_in_block_double;
                Float* p_wei_block_next =
                    even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double;

                Float p_in_register_clipboard[blockwise_in_copy.GetRegisterClipboardSize()];
                Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];

                blockwise_in_copy.MoveSlicingWindowOnSourceTensor(I0, Number<EPerBlock>{}, True);
                p_wei_block_on_global += EPerBlock * wei_e_k_global_desc.GetStride(I0);

                __syncthreads();

                // LDS double buffer: load next data from device mem
                blockwise_in_copy.RunLoadRegisterClipboard(p_in_global, p_in_register_clipboard);
                blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_block_on_global,
                                                            p_wei_register_clipboard);

                // LDS double buffer: GEMM on current data
                blockwise_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread);

                // LDS double buffer: store next data to LDS
                blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard,
                                                            p_in_block_next);
                blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard,
                                                             p_wei_block_next);
            }
        }

        // LDS double buffer: tail
        {
            Float p_in_register_clipboard[blockwise_in_copy.GetRegisterClipboardSize()];
            Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];

            // even iteration
            blockwise_in_copy.MoveSlicingWindowOnSourceTensor(I0, Number<EPerBlock>{}, True);
            p_wei_block_on_global += EPerBlock * wei_e_k_global_desc.GetStride(I0);

            __syncthreads();

            // LDS double buffer: load next data from device mem
            blockwise_in_copy.RunLoadRegisterClipboard(p_in_global, p_in_register_clipboard);
            blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_block_on_global,
                                                        p_wei_register_clipboard);

            // LDS double buffer: GEMM on current data
            blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);

            // LDS double buffer: store next data to LDS
            blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard,
                                                        p_in_block_double + in_block_space);
            blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard,
                                                         p_wei_block_double + wei_block_space);

            // odd iteration
            __syncthreads();

            // LDS double buffer: GEMM on current data
            blockwise_gemm.Run(p_wei_block_double + wei_block_space,
                               p_in_block_double + in_block_space,
                               p_out_thread);
        }

        // copy output: register to global memory
        {
            constexpr index_t K2 = GemmMPerThreadSubC;
            constexpr index_t K1 = GemmMLevel0Cluster * GemmMLevel1Cluster;

            // define tensor descriptor for threadwise copy
            //   output memory layout descriptor in register
            constexpr auto out_k0_k1_k2_n1_n0_h_w_n2_thread_mem_desc =
                make_ConstantTensorDescriptor_packed(
                    Sequence<KPerBlock / (K1 * K2), 1, K2, N1, 1, 1, 1, N2>{});

            //   output tensor descriptor in register, src of threadwise copy
            constexpr auto out_n0_n1_n2_k0_k1_k2_h_w_thread_desc =
                out_k0_k1_k2_n1_n0_h_w_n2_thread_mem_desc.ReorderGivenNew2Old(
                    Sequence<4, 3, 7, 0, 1, 2, 5, 6>{});

            //   output memory layout descriptor in device memory, dst of threadwise copy
            constexpr auto out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc_forw =
                out_n_k_h_w_global_desc.Fold(I1, Number<K1>{}, Number<K2>{})
                    .Fold(I0, Number<N1>{}, Number<N2>{});

            constexpr auto out_lengths_new = Sequence<
                out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc_forw.GetLength(I0),
                out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc_forw.GetLength(I1),
                out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc_forw.GetLength(I2),
                out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc_forw.GetLength(I3),
                out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc_forw.GetLength(I4),
                out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc_forw.GetLength(I5),
                math::integer_divide_ceil(
                    out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc_forw.GetLength(I6),
                    ConvStrides{}.Get(I0)),
                math::integer_divide_ceil(
                    out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc_forw.GetLength(I7),
                    ConvStrides{}.Get(I1))>{};

            constexpr auto out_strides_new = Sequence<
                out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc_forw.GetStride(I0),
                out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc_forw.GetStride(I1),
                out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc_forw.GetStride(I2),
                out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc_forw.GetStride(I3),
                out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc_forw.GetStride(I4),
                out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc_forw.GetStride(I5),
                out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc_forw.GetStride(I6) *
                    ConvStrides{}.Get(I0),
                out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc_forw.GetStride(I7) *
                    ConvStrides{}.Get(I1)>{};

            constexpr auto out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc_back =
                make_ConstantTensorDescriptor(out_lengths_new, out_strides_new);

            constexpr auto out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc =
                GetDesc<isForward,
                        decltype(out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc_forw),
                        decltype(out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc_back)>{}
                    .Desc;

            // calculate origin of thread output tensor on global memory
            //   blockwise GEMM c matrix starting index
            const auto c_thread_mtx_on_block =
                blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());

            const index_t k_thread_data_on_global =
                k_block_data_on_global + c_thread_mtx_on_block.row;

            const index_t b_thread_data_on_global =
                b_block_data_on_global + c_thread_mtx_on_block.col / N2;

            //   output merged global tensor descriptor, for calculating origin of thread tensor
            //   in global memory
            constexpr auto out_k_n1_b_n2_global_merged_desc = make_ConstantMergedTensorDescriptor(
                out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc.Unfold(I3, I5),
                Sequence<3>{},
                Sequence<1>{},
                Sequence<0, 4, 5>{},
                Sequence<2>{});

            //   origin of dst in device memory
            Float* p_out_thread_on_global =
                p_out_global +
                out_k_n1_b_n2_global_merged_desc.GetOffsetFromMultiIndex(
                    k_thread_data_on_global, 0, b_thread_data_on_global, 0);

            threadwise_generic_tensor_slice_copy_v1(
                out_n0_n1_n2_k0_k1_k2_h_w_thread_desc,
                p_out_thread,
                {0, 0, 0, 0, 0, 0, 0, 0},
                out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc,
                p_out_thread_on_global,
                {0, 0, 0, 0, 0, 0, 0, 0},
                out_n0_n1_n2_k0_k1_k2_h_w_thread_desc.GetLengths(),
                arithmetic_sequence_gen<0, 8, 1>::type{},
                Number<1>{});
        }
    }
};

} // namespace ck
#endif
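
The GetDesc helper defined at the top of this new file simply exposes a member whose type is picked by std::conditional, so the forward and backward (Direction) paths can share one kernel body. A standalone sketch of the same pattern, using stand-in descriptor structs rather than the real ck descriptors:

    #include <type_traits>

    struct ForwardDesc  { static constexpr int id = 1; };
    struct BackwardDesc { static constexpr int id = 2; };

    template <bool isForw, class DescForw, class DescBack>
    struct GetDesc
    {
        // the member's type is DescForw when isForw is true, DescBack otherwise
        typename std::conditional<isForw, DescForw, DescBack>::type Desc;
    };

    // decltype of the member access recovers the selected descriptor type
    using Forw = decltype(GetDesc<true, ForwardDesc, BackwardDesc>{}.Desc);
    using Back = decltype(GetDesc<false, ForwardDesc, BackwardDesc>{}.Desc);
    static_assert(Forw::id == 1 && Back::id == 2, "std::conditional picks the type");

    int main() { return 0; }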
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4_nchw_kcyx_nkhw_lds_double_buffer.hpp (modified)
@@ -15,6 +15,7 @@ namespace ck {
 template <index_t GridSize,
           index_t BlockSize,
           class Float,
+          class AccDataType,
           class InGlobalDesc,
           class WeiGlobalDesc,
           class OutGlobalDesc,
@@ -50,7 +51,8 @@ template <index_t GridSize,
           index_t WeiBlockCopyDstDataPerWrite_K>
 struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw_lds_double_buffer
 {
-    __device__ void Run(const Float* const __restrict__ p_in_global,
+    __device__ void Run(const Float* const __restrict__ p_in_global,
                         const Float* const __restrict__ p_wei_global,
                         Float* const __restrict__ p_out_global) const
     {
@@ -84,12 +86,6 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw_lds_double_buffer
         constexpr index_t Y = wei_k_c_y_x_global_desc.GetLength(I2);
         constexpr index_t X = wei_k_c_y_x_global_desc.GetLength(I3);

-        constexpr index_t ConvStrideH = ConvStrides{}[0];
-        constexpr index_t ConvStrideW = ConvStrides{}[1];
-
-        constexpr index_t ConvDilationH = ConvDilations{}[0];
-        constexpr index_t ConvDilationW = ConvDilations{}[1];
-
         static_assert(N % (N1 * N2) == 0, "wrong! cannot divice N evenly among thread");

         constexpr index_t N0 = N / (N1 * N2);
@@ -98,14 +94,6 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw_lds_double_buffer
         constexpr index_t E = C * Y * X;

-        // sanity-check for vectorized memory load
-        static_assert(ConvStrideW == 1 || InBlockCopySrcDataPerRead_B == 1,
-                      "wrong! global vector load of input tensor is wrong");
-
-        static_assert((X == 1 || ConvDilationW % InBlockCopySrcDataPerRead_B == 0),
-                      "wrong! aligment requirement for vectorized global load of input tensor will "
-                      "be violated");
-
         // divide block work by [K, B]
         static_assert(K % KPerBlock == 0 && B % BPerBlock == 0 && E % (2 * EPerBlock) == 0,
                       "wrong! cannot divide work evenly among block");
@@ -125,15 +113,15 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw_lds_double_buffer
         // input tensor
         //   tensor descriptor in device memory [N0, N1, N2, Ho, Wo]
         constexpr auto in_n0_n1_n2_h_w_global_desc =
-            in_n_c_h_w_global_desc.StridedSlice(I2, Number<Ho>{}, Number<ConvStrideH>{})
-                .StridedSlice(I3, Number<Wo>{}, Number<ConvStrideW>{})
+            in_n_c_h_w_global_desc.StridedSlice(I2, Number<Ho>{}, Number<ConvStrides::Get(I0)>{})
+                .StridedSlice(I3, Number<Wo>{}, Number<ConvStrides::Get(I1)>{})
                 .Fold(I0, Number<N1>{}, Number<N2>{})
                 .Extract(Sequence<0, 1, 2, 4, 5>{});

         // batch descritpor for device memory
         constexpr auto in_c_y_x_global_desc =
-            in_n_c_h_w_global_desc.StridedSlice(I2, Number<Y>{}, Number<ConvDilationH>{})
-                .StridedSlice(I3, Number<X>{}, Number<ConvDilationW>{})
+            in_n_c_h_w_global_desc.StridedSlice(I2, Number<Y>{}, Number<ConvDilations::Get(I0)>{})
+                .StridedSlice(I3, Number<X>{}, Number<ConvDilations::Get(I1)>{})
                 .Extract(Sequence<1, 2, 3>{});

         // merged tensor descriptor in device memory [E, N1, B, N2], src of blockwise copy
@@ -260,7 +248,7 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw_lds_double_buffer
         __shared__ Float p_wei_block_double[2 * wei_block_space];

         // register allocation for output
-        Float p_out_thread[c_k0k2_n1n2_thread_mtx_desc.GetElementSpace()];
+        AccDataType p_out_thread[c_k0k2_n1n2_thread_mtx_desc.GetElementSpace()];

         // zero out threadwise output
         threadwise_matrix_set_zero(c_k0k2_n1n2_thread_mtx_desc, p_out_thread);
@@ -332,7 +320,6 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw_lds_double_buffer
             blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_block_on_global,
                                                         p_wei_register_clipboard);

-
             // LDS double buffer: GEMM on current data
             blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);

             // LDS double buffer: store next data to LDS
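
Among the hunks above, the ConvStrideH/ConvStrideW and ConvDilationH/ConvDilationW constants are replaced by ConvStrides::Get(I0)/Get(I1) and ConvDilations::Get(I0)/Get(I1), so the values can sit directly inside Number<...>, i.e. be used as template arguments. A simplified stand-in (not ck's actual Sequence, whose Get takes a Number<> argument) showing why a static constexpr accessor is template-argument friendly:

    template <int... Is>
    struct Sequence
    {
        // compile-time element access; the local array is a constant expression
        template <int I>
        static constexpr int Get()
        {
            constexpr int data[] = {Is...};
            return data[I];
        }
    };

    template <int N>
    struct Number { static constexpr int value = N; };

    using ConvStrides = Sequence<2, 2>; // hypothetical stride-2 convolution

    // Get<0>() is a constant expression, so it can appear as a template argument:
    using StrideH = Number<ConvStrides::Get<0>()>;
    static_assert(StrideH::value == 2, "compile-time access to the stride");

    int main() { return 0; }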
composable_kernel/include/tensor_description/ConstantMergedTensorDescriptor.hpp (modified)
@@ -37,7 +37,7 @@ struct ConstantMergedTensorDescriptor
         return OriginalTensorDesc{};
     }

-    __host__ __device__ static constexpr auto GetNumOfDimension() { return Number<nDim>{}; }
+    __host__ __device__ static constexpr index_t GetNumOfDimension() { return nDim; }

     template <index_t IDim>
     __host__ __device__ static constexpr auto GetContainedOriginalDimensions(Number<IDim>)
@@ -52,7 +52,7 @@ struct ConstantMergedTensorDescriptor
     }

     template <index_t IDim>
-    __host__ __device__ static constexpr auto GetLength(Number<IDim>)
+    __host__ __device__ static constexpr index_t GetLength(Number<IDim>)
     {
         constexpr auto original_dims_partial = std::get<IDim>(mOriginalDimMergeSeqs);
@@ -60,7 +60,7 @@ struct ConstantMergedTensorDescriptor
     }

     template <index_t IDim>
-    __host__ __device__ static constexpr auto GetStride(Number<IDim>)
+    __host__ __device__ static constexpr index_t GetStride(Number<IDim>)
     {
         static_assert(!ContainMultipleOriginalDimensions(Number<IDim>{}),
                       "wrong! stride of a merged dimension is undefined");
@@ -75,7 +75,7 @@ struct ConstantMergedTensorDescriptor
         return Sequence<OriginalTensorDesc::Extract(OriginalDimMergeSeqs{}).GetElementSize()...>{};
     }

-    __host__ __device__ static constexpr auto GetElementSize()
+    __host__ __device__ static constexpr index_t GetElementSize()
     {
         return OriginalTensorDesc::GetElementSize();
     }
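
The getters changed in this file now return a plain index_t instead of Number<N>{}, ck's type-encoded integer. A small illustration of the difference, using std::integral_constant in place of ck's Number (so this is an illustration with hypothetical function names, not the library's own API):

    #include <type_traits>
    using index_t = int;

    // Type-encoded constant: the value lives in the *type*, so it survives being
    // passed through `auto` parameters and can seed further template machinery.
    constexpr auto NumOfDimensionAsType() { return std::integral_constant<index_t, 4>{}; }

    // Plain value: simpler to consume in arithmetic and static_assert conditions.
    constexpr index_t NumOfDimensionAsValue() { return 4; }

    static_assert(decltype(NumOfDimensionAsType())::value == 4, "value carried by the type");
    static_assert(NumOfDimensionAsValue() == 4, "value carried by constexpr evaluation");

    int main() { return 0; }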
composable_kernel/include/tensor_description/ConstantTensorDescriptor.hpp (modified)
@@ -43,22 +43,22 @@ struct ConstantTensorDescriptor
         return Sequence<IDim>{};
     }

-    __host__ __device__ static constexpr auto GetNumOfDimension() { return Number<nDim>{}; }
+    __host__ __device__ static constexpr index_t GetNumOfDimension() { return nDim; }

     __host__ __device__ static constexpr auto GetLengths() { return Lengths{}; }

     __host__ __device__ static constexpr auto GetStrides() { return Strides{}; }

-    template <class IDim>
-    __host__ __device__ static constexpr auto GetLength(IDim)
+    template <index_t I>
+    __host__ __device__ static constexpr index_t GetLength(Number<I>)
     {
-        return Lengths::Get(IDim{});
+        return Lengths::Get(Number<I>{});
     }

-    template <class IDim>
-    __host__ __device__ static constexpr auto GetStride(IDim)
+    template <index_t I>
+    __host__ __device__ static constexpr index_t GetStride(Number<I>)
     {
-        return Strides::Get(IDim{});
+        return Strides::Get(Number<I>{});
     }

     struct lambda_AreDimensionsContinuous
@@ -102,18 +102,17 @@ struct ConstantTensorDescriptor
         return false;
     }

-    __host__ __device__ static constexpr auto GetElementSize()
+    __host__ __device__ static constexpr index_t GetElementSize()
     {
-        return Number<accumulate_on_sequence(Lengths{}, math::multiplies<index_t>{}, Number<1>{})>{};
+        return accumulate_on_sequence(Lengths{}, math::multiplies<index_t>{}, Number<1>{});
     }

-    __host__ __device__ static constexpr auto GetElementSpace()
+    __host__ __device__ static constexpr index_t GetElementSpace()
     {
         constexpr index_t element_space_unaligned = accumulate_on_sequence(
             (GetLengths() - Number<1>{}) * GetStrides(), math::plus<index_t>{}, Number<1>{});

-        return Number<element_space_unaligned>{};
+        return element_space_unaligned;
     }

     // emulate constexpr lambda
@@ -157,14 +156,13 @@ struct ConstantTensorDescriptor
     }

     template <index_t... Is>
-    __host__ __device__ static constexpr auto GetOffsetFromMultiIndex(Sequence<Is...>)
+    __host__ __device__ static constexpr index_t GetOffsetFromMultiIndex(Sequence<Is...>)
     {
         static_assert(sizeof...(Is) == nDim, "wrong! Dimension not consistent");

         constexpr auto multi_id = Sequence<Is...>{};

-        return Number<accumulate_on_sequence(multi_id * GetStrides(), math::plus<index_t>{}, Number<0>{})>{};
+        return accumulate_on_sequence(multi_id * GetStrides(), math::plus<index_t>{}, Number<0>{});
     }

     // emulate constexpr lambda
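
GetOffsetFromMultiIndex, which now returns index_t directly, is the usual strided-layout dot product: offset = Σ multi_id[i] * stride[i]. A standalone sketch with made-up NCHW lengths and strides (not the library's implementation):

    #include <array>
    #include <cstddef>
    using index_t = int;

    // offset = sum_i multi_id[i] * stride[i]
    template <std::size_t N>
    constexpr index_t offset_from_multi_index(const std::array<index_t, N>& multi_id,
                                              const std::array<index_t, N>& strides)
    {
        index_t offset = 0;
        for(std::size_t i = 0; i < N; ++i)
            offset += multi_id[i] * strides[i];
        return offset;
    }

    // hypothetical packed NCHW tensor of lengths {N, C, H, W} = {1, 8, 4, 4}
    constexpr std::array<index_t, 4> strides{8 * 4 * 4, 4 * 4, 4, 1};
    static_assert(offset_from_multi_index<4>({0, 2, 1, 3}, strides) == 2 * 16 + 1 * 4 + 3,
                  "multi-index (0,2,1,3) maps to linear offset 39");

    int main() { return 0; }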
composable_kernel/include/tensor_operation/blockwise_gemm.hpp (modified)
@@ -142,17 +142,15 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
         constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
         constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;

-        // assertion for inline asm
-        static_assert(is_same<FloatA, float>{} && is_same<FloatB, float>{} &&
-                          is_same<FloatC, float>{},
-                      "Run_amd_asm only deal with float");
-
         static_assert(MPerThreadSubC == 4 && NPerThreadSubC == 4 && KPerThreadLoop == 1 &&
                           MPerThread == 8 && NPerThread == 8,
                       "Run_amd_asm cannot deal with this GEMM shape yet");

         static_assert(DataPerReadA == 4 && DataPerReadB == 4, "Run_amd_asm only do float4 read");

+        // If A and B datatype is float
+        static_if<std::is_same<FloatA, float>::value && std::is_same<FloatB, float>::value>{}(
+            [&](auto) {
                 using Float4 = vector_type<float, 4>::MemoryType;

                 Float4* reg_a = reinterpret_cast<Float4*>(p_a_thread);
@@ -183,6 +181,41 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
                 }

                 outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
                 outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
+            }).Else([&](auto) {
+                // If A and B datatype is bfloat16/float16
+                using Half4x4 = vector_type<vector_type<half, 4>, 4>;
+                using Float4  = vector_type<float, 4>::MemoryType;
+
+                Half4x4* reg_a = reinterpret_cast<Half4x4*>(p_a_thread);
+                Half4x4* reg_b = reinterpret_cast<Half4x4*>(p_b_thread);
+                Float4* reg_c  = reinterpret_cast<Float4*>(p_c_thread);
+
+                reg_a[0] = *reinterpret_cast<const Half4x4*>(&p_a_block[mMyThreadOffsetA]);
+                reg_b[0] = *reinterpret_cast<const Half4x4*>(&p_b_block[mMyThreadOffsetB]);
+                reg_b[1] = *reinterpret_cast<const Half4x4*>(
+                    &p_b_block[mMyThreadOffsetB + NPerLevel1Cluster]);
+                reg_a[1] = *reinterpret_cast<const Half4x4*>(
+                    &p_a_block[mMyThreadOffsetA + MPerLevel1Cluster]);
+
+                outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
+                outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);
+
+#pragma unroll
+                for(index_t k = 1; k < K; ++k)
+                {
+                    reg_a[0] = *reinterpret_cast<const Half4x4*>(
+                        &p_a_block[mMyThreadOffsetA + k * M]);
+                    outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
+                    reg_b[0] = *reinterpret_cast<const Half4x4*>(
+                        &p_b_block[mMyThreadOffsetB + k * N]);
+                    outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
+                    reg_b[1] = *reinterpret_cast<const Half4x4*>(
+                        &p_b_block[mMyThreadOffsetB + k * N + NPerLevel1Cluster]);
+                    reg_a[1] = *reinterpret_cast<const Half4x4*>(
+                        &p_a_block[mMyThreadOffsetA + k * M + MPerLevel1Cluster]);
+                    outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
+                    outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);
+                }
+
+                outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
+                outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
+            });
     }
 #endif
@@ -204,11 +237,11 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
         constexpr index_t NPerThread = c_thread_mtx.NCol();

         // thread A, B for GEMM
-        constexpr auto a_thread_mtx =
-            make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<MPerThread>{});
+        constexpr auto a_thread_mtx = make_ConstantMatrixDescriptor(
+            Number<KPerThreadLoop>{}, Number<MPerThread>{}, Number<MPerThread>{});

-        constexpr auto b_thread_mtx =
-            make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<NPerThread>{});
+        constexpr auto b_thread_mtx = make_ConstantMatrixDescriptor(
+            Number<KPerThreadLoop>{}, Number<NPerThread>{}, Number<NPerThread>{});

         // thread A-sub, B-sub for copy
         constexpr auto a_thread_sub_mtx = make_ConstantMatrixDescriptor(
@@ -415,7 +448,11 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
     {
 #if CK_USE_AMD_INLINE_ASM && CK_BLOCKWISE_GEMM_USE_AMD_INLINE_ASM
-        Run_amd_asm(p_a_block, p_b_block, p_c_thread);
+        static_if<std::is_same<FloatA, ushort>::value && std::is_same<FloatB, ushort>::value>{}(
+            [&](auto) { Run_source(p_a_block, p_b_block, p_c_thread); })
+            .Else([&](auto) {
+                // If A and B datatype is bfloat16/float16
+                Run_amd_asm(p_a_block, p_b_block, p_c_thread);
+            });
 #else
         Run_source(p_a_block, p_b_block, p_c_thread);
 #endif
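
The new fp16/bf16 branch above loads 4x4 register tiles of A and B and calls outerProduct4x4 to accumulate into float registers. The helper itself is not shown in this hunk; the sketch below is a plain-C++ reference for the rank-1 update such a call presumably performs (c[i][j] += a[i] * b[j]), with float standing in for the packed half vectors, so it illustrates the math rather than the actual intrinsic or inline-asm implementation.

    // c (4x4, fp32 accumulator) += outer product of a (length 4) and b (length 4).
    // In the kernel, a and b would hold fp16/bf16 values accumulated in fp32.
    inline void outer_product_4x4_ref(const float a[4], const float b[4], float c[4][4])
    {
        for(int i = 0; i < 4; ++i)
            for(int j = 0; j < 4; ++j)
                c[i][j] += a[i] * b[j];
    }

    int main()
    {
        float a[4]    = {1, 2, 3, 4};
        float b[4]    = {1, 0, 1, 0};
        float c[4][4] = {};
        outer_product_4x4_ref(a, b, c); // c[2][0] == 3, c[3][2] == 4, odd columns stay 0
        return 0;
    }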
composable_kernel/include/tensor_operation/blockwise_generic_tensor_slice_copy.hpp (modified)
@@ -10,13 +10,15 @@
 #define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_BLOCKWISE_GENERIC_SLICE_COPY_V1 1
 #endif

+#define JOINTCAT(x, y) x##y
+#define ASSERT_MSG_ARG1(msg, var1) JOINTCAT(msg, var1)
+#define ASSERT_MSG_ARG2(msg, var1, va2) ASSERT_MSG_ARG1(JOINTCAT(msg, var1), var2)
+
 namespace ck {

 // slice a (normal or merged) tensor, and copy it into another (normal or merged) tensor
-// memory layout (ordering of dimensions) can be different between src and dst.
-// on a merged dimension that constains multiple original dimensions,
-// its sub-length need to evenly divide the length of the last original dimension
-// so each thread is effectively reading a normal (not merged) tensor
+// memory layout (ordering of dimensions) can be different between src and dst
+// For now, only support SubLengths[...] == 1 on a merged dimension
 template <index_t BlockSize,
           class Float,
           class SrcDesc,
@@ -77,15 +79,18 @@ struct BlockwiseGenericTensorSliceCopy_v1
         // thread cluster
         constexpr auto thread_cluster_desc = make_ConstantTensorDescriptor_packed(
-            DataClusterLengths::ReorderGivenNew2Old(ThreadClusterArrangeOrder{}));
+            DataClusterLengths{}.ReorderGivenNew2Old(ThreadClusterArrangeOrder{}));

         // BlockSize
-        static_assert(BlockSize == thread_cluster_desc.GetElementSize(), "wrong! BlockSize");
+        static_assert(BlockSize == thread_cluster_desc.GetElementSize(),
+                      "wrong! block size doesn't match with thread cluster size.");

         // divide work
         constexpr auto data_per_cluster_per_dims = SubLengths{} * DataClusterLengths{};

-        static_for<0, nDim, 1>{}([&](auto IDim) {
+        static_for<0, nDim, 1>{}([&](auto IDim_) {
+            constexpr auto IDim = decltype(IDim_){};
+
             static_assert(SliceLengths::Get(IDim) % SubLengths::Get(IDim) == 0,
                           "wrong! cannot evenly divide sliced tensor into sub-tensor");
@@ -93,23 +98,15 @@ struct BlockwiseGenericTensorSliceCopy_v1
                           "wrong! cannot evenly divide sliced tensor into cluster");
         });

-        // on a merged dimension that constains multiple original dimensions,
-        // its sub-length need to evenly divide the length of the last original dimension,
-        // so each thread is effectively reading a normal (not merged) tensor
-        static_for<0, nDim, 1>{}([&](auto IDim) {
-            constexpr auto sub_length = SubLengths::Get(IDim);
-
-            constexpr auto idim_original_src = SrcDesc::GetContainedOriginalDimensions(IDim).Back();
-            static_assert(SrcDesc::GetOriginalTensorDescriptor().GetLength(idim_original_src) %
-                                  sub_length ==
-                              0,
-                          "wrong!");
-
-            constexpr auto idim_original_dst = DstDesc::GetContainedOriginalDimensions(IDim).Back();
-            static_assert(DstDesc::GetOriginalTensorDescriptor().GetLength(idim_original_dst) %
-                                  sub_length ==
-                              0,
-                          "wrong!");
+        // for now, only support SubLengths == 1 on a merged dimension that constains
+        // multiple original dimensions
+        static_for<0, nDim, 1>{}([&](auto IDim_) {
+            constexpr auto IDim = decltype(IDim_){};
+
+            static_assert(SubLengths::Get(IDim) == 1 ||
+                              (!SrcDesc::ContainMultipleOriginalDimensions(IDim) &&
+                               !DstDesc::ContainMultipleOriginalDimensions(IDim)),
+                          "wrong! only support Sub-Length == 1 on a merged dimension");
         });

         // calculate mThreadSrcOffset, mThreadDstOffset
@@ -129,25 +126,31 @@ struct BlockwiseGenericTensorSliceCopy_v1
             dst_block_data_multi_id_begin + thread_data_multi_id_begin);

         // partial offset on each dimension
-        static_for<0, nDim, 1>{}([&](auto IDim) {
+        static_for<0, nDim, 1>{}([&](auto IDim_) {
+            constexpr auto IDim   = decltype(IDim_){};
+            constexpr index_t idim = IDim;
+
             constexpr auto src_partial_original_dims =
                 SrcDesc::GetContainedOriginalDimensions(IDim);

             constexpr auto src_partial_original_desc =
                 SrcDesc::GetOriginalTensorDescriptor().Extract(src_partial_original_dims);

-            mThreadSrcPartialOffsets(IDim) = src_partial_original_desc.GetOffsetFromMultiIndex(
+            mThreadSrcPartialOffsets(idim) = src_partial_original_desc.GetOffsetFromMultiIndex(
                 extract_array(mThreadSrcOriginalMultiId, src_partial_original_dims));
         });

-        static_for<0, nDim, 1>{}([&](auto IDim) {
+        static_for<0, nDim, 1>{}([&](auto IDim_) {
+            constexpr auto IDim   = decltype(IDim_){};
+            constexpr index_t idim = IDim;
+
             constexpr auto dst_partial_original_dims =
                 DstDesc::GetContainedOriginalDimensions(IDim);

             constexpr auto dst_partial_original_desc =
                 DstDesc::GetOriginalTensorDescriptor().Extract(dst_partial_original_dims);

-            mThreadDstPartialOffsets(IDim) = dst_partial_original_desc.GetOffsetFromMultiIndex(
+            mThreadDstPartialOffsets(idim) = dst_partial_original_desc.GetOffsetFromMultiIndex(
                 extract_array(mThreadDstOriginalMultiId, dst_partial_original_dims));
         });
@@ -181,8 +184,10 @@ struct BlockwiseGenericTensorSliceCopy_v1
         constexpr auto thread_tensor_desc =
             make_ConstantTensorDescriptor_packed(thread_sub_tensor_lengths * repeat_lengths);

+        static_ford<decltype(repeat_lengths)>{}([&](auto repeat_multi_id_) {
 #if CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_BLOCKWISE_GENERIC_SLICE_COPY_V1
-        static_ford<decltype(repeat_lengths)>{}([&](auto repeat_multi_id) {
+            constexpr auto repeat_multi_id = decltype(repeat_multi_id_){};
+
             constexpr auto src_thread_data_multi_id_begin =
                 repeat_multi_id * data_per_cluster_per_dims;
@@ -195,25 +200,19 @@ struct BlockwiseGenericTensorSliceCopy_v1
             constexpr index_t clipboard_offset =
                 thread_tensor_desc.GetOffsetFromMultiIndex(clipboard_data_multi_id_begin);
 #else
-        ford<decltype(repeat_lengths)>{}([&](auto repeat_multi_id) {
+            constexpr auto repeat_multi_id = sequence2array(decltype(repeat_multi_id_){});
+
             const auto src_thread_data_multi_id_begin = repeat_multi_id * data_per_cluster_per_dims;

             const auto clipboard_data_multi_id_begin = repeat_multi_id * thread_sub_tensor_lengths;

             const index_t src_offset =
-                SrcDesc::GetOffsetFromMultiIndex(src_thread_data_multi_id_begin);
+                SrcDesc{}.GetOffsetFromMultiIndex(src_thread_data_multi_id_begin);

             const index_t clipboard_offset =
                 thread_tensor_desc.GetOffsetFromMultiIndex(clipboard_data_multi_id_begin);
 #endif
-            // By position the origin of the per-thread window at the point, where multi-index
-            // of the SrcDesc (might be a merged tensor) is all-zero. This threadwise slice copy
-            // is assuming each thread is copy a noraml (not merged) tensor.
-            // User need to guarantee this is true.
-            // By setting SubLengths = 1 at the merged dimension, this is always true;
-            // If in the future, you want to enable SubLengths > 1 at the merged dimension,
-            // special care in implementation is needed
             threadwise_generic_tensor_slice_copy_v1(SrcDesc{},
                                                     p_src + src_offset + mThreadSrcOffset,
                                                     make_zero_array<index_t, nDim>(),
@@ -238,8 +237,10 @@ struct BlockwiseGenericTensorSliceCopy_v1
         constexpr auto thread_tensor_desc =
             make_ConstantTensorDescriptor_packed(thread_sub_tensor_lengths * repeat_lengths);

+        static_ford<decltype(repeat_lengths)>{}([&](auto repeat_multi_id_) {
 #if CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_BLOCKWISE_GENERIC_SLICE_COPY_V1
-        static_ford<decltype(repeat_lengths)>{}([&](auto repeat_multi_id) {
+            constexpr auto repeat_multi_id = decltype(repeat_multi_id_){};
+
             constexpr auto clipboard_data_multi_id_begin =
                 repeat_multi_id * thread_sub_tensor_lengths;
@@ -249,9 +250,10 @@ struct BlockwiseGenericTensorSliceCopy_v1
                 thread_tensor_desc.GetOffsetFromMultiIndex(clipboard_data_multi_id_begin);

             constexpr index_t dst_offset =
-                DstDesc::GetOffsetFromMultiIndex(dst_data_multi_id_begin);
+                DstDesc{}.GetOffsetFromMultiIndex(dst_data_multi_id_begin);
 #else
-        ford<decltype(repeat_lengths)>{}([&](auto repeat_multi_id) {
+            constexpr auto repeat_multi_id = sequence2array(decltype(repeat_multi_id_){});
+
             const auto clipboard_data_multi_id_begin = repeat_multi_id * thread_sub_tensor_lengths;

             const auto dst_data_multi_id_begin = repeat_multi_id * data_per_cluster_per_dims;
@@ -259,16 +261,9 @@ struct BlockwiseGenericTensorSliceCopy_v1
             const index_t clipboard_offset =
                 thread_tensor_desc.GetOffsetFromMultiIndex(clipboard_data_multi_id_begin);

-            const index_t dst_offset =
-                DstDesc::GetOffsetFromMultiIndex(dst_data_multi_id_begin);
+            const index_t dst_offset = DstDesc{}.GetOffsetFromMultiIndex(dst_data_multi_id_begin);
 #endif
-            // By position the origin of the per-thread window at the point, where multi-index
-            // of the SrcDesc (might be a merged tensor) is all-zero. This threadwise slice copy
-            // is assuming each thread is copy a noraml (not merged) tensor.
-            // User need to guarantee this is true.
-            // By setting SubLengths = 1 at the merged dimension, this is always true;
-            // If in the future, you want to enable SubLengths > 1 at the merged dimension,
-            // special care in implementation is needed
             threadwise_generic_tensor_slice_copy_v1(thread_tensor_desc,
                                                     p_clipboard + clipboard_offset,
                                                     make_zero_array<index_t, nDim>(),
@@ -303,6 +298,7 @@ struct BlockwiseGenericTensorSliceCopy_v1
                                              Number<IDim_>,
                                              Number<StepSize>,
                                              integral_constant<bool, PositiveDirection> direction)
     {
         constexpr auto IDim = Number<IDim_>{};
+        constexpr index_t idim = IDim;

         static_if<SrcDesc::ContainMultipleOriginalDimensions(IDim)>{}([&](auto) {
             // logic for a merged dimension, also works for non-merged dimension, but its logic may
@@ -325,21 +321,22 @@ struct BlockwiseGenericTensorSliceCopy_v1
                     old_src_partial_original_multi_id, StepSize, direction);

                 // update "mThreadSrcOriginalMultiId"
-                static_for<0, decltype(src_partial_original_dims)::GetSize(), 1>{}([&](auto I) {
-                    constexpr auto IDimOriginal = src_partial_original_dims[I];
+                static_for<0, decltype(src_partial_original_dims)::GetSize(), 1>{}([&](auto I_) {
+                    constexpr auto I = decltype(I_){};
+
+                    constexpr index_t idim_original = src_partial_original_dims.Get(I);

-                    mThreadSrcOriginalMultiId(IDimOriginal) = new_src_partial_original_multi_id[I];
+                    mThreadSrcOriginalMultiId(idim_original) = new_src_partial_original_multi_id[I];
                 });

                 // calculate new partial offset on this merged dimension
-                const index_t old_src_partial_offset = mThreadSrcPartialOffsets[IDim];
+                const index_t old_src_partial_offset = mThreadSrcPartialOffsets[idim];

                 const index_t new_src_partial_offset =
                     src_partial_original_desc.GetOffsetFromMultiIndex(
                         new_src_partial_original_multi_id);

                 // update "mThreadSrcPartialOffsets"
-                mThreadSrcPartialOffsets(IDim) = new_src_partial_offset;
+                mThreadSrcPartialOffsets(idim) = new_src_partial_offset;

                 // update "mThreadSrcOffset", do "+" before "-" to avoid underflow
                 mThreadSrcOffset =
                     (mThreadSrcOffset + new_src_partial_offset) - old_src_partial_offset;
@@ -354,20 +351,20 @@ struct BlockwiseGenericTensorSliceCopy_v1
             // of the boundary of the tensor being sliced. Otherwise, there might be hazard like
             // unsigned integer underflow. That is NO runtime sanity check to prevent the hazard
-            constexpr auto IDimOriginal = SrcDesc::GetContainedOriginalDimensions(IDim).Front();
+            constexpr index_t idim_original = SrcDesc::GetContainedOriginalDimensions(IDim).Front();

             static_if<PositiveDirection>{}([&](auto fwd) {
                 mThreadSrcOffset += StepSize * fwd(SrcDesc{}).GetStride(IDim);

-                mThreadSrcOriginalMultiId(IDimOriginal) += StepSize;
+                mThreadSrcOriginalMultiId(idim_original) += StepSize;

                 mThreadSrcPartialOffsets
(
ID
im
)
+=
StepSize
*
fwd
(
SrcDesc
{}).
GetStride
(
IDim
);
mThreadSrcPartialOffsets
(
id
im
)
+=
StepSize
*
fwd
(
SrcDesc
{}).
GetStride
(
IDim
);
}).
Else
([
&
](
auto
fwd
)
{
mThreadSrcOffset
-=
StepSize
*
fwd
(
SrcDesc
{}).
GetStride
(
IDim
);
mThreadSrcOriginalMultiId
(
IDimO
riginal
)
-=
StepSize
;
mThreadSrcOriginalMultiId
(
idim_o
riginal
)
-=
StepSize
;
mThreadSrcPartialOffsets
(
ID
im
)
-=
StepSize
*
fwd
(
SrcDesc
{}).
GetStride
(
IDim
);
mThreadSrcPartialOffsets
(
id
im
)
-=
StepSize
*
fwd
(
SrcDesc
{}).
GetStride
(
IDim
);
});
});
}
...
...
composable_kernel/include/tensor_operation/threadwise_gemm.hpp
View file @ 2185affb
...
@@ -3,6 +3,7 @@
#include "common_header.hpp"
#include "ConstantMatrixDescriptor.hpp"
+#include "float_types.h"

namespace ck {
...
@@ -34,11 +35,15 @@ __device__ void threadwise_matrix_copy(SrcMatrix,
{
    static_assert(NCol % DataPerRead == 0, "wrong! should be NCol % == DataPerRead == 0");

-   using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;

    constexpr auto src_mtx = SrcMatrix{};
    constexpr auto dst_mtx = DstMatrix{};

+   // Depending upon datatype i.e float/half/bfloat16, carry out data movement
+   // in appropriate vectorized form
+   // float - 4, half - 4, bfloat16 - 2
+   static_if<std::is_same<Float, float>::value>{}([&](auto) {
+       using vector_t = typename vector_type<float, DataPerRead>::MemoryType;
        for(index_t i = 0; i < NRow; ++i)
        {
            for(index_t j = 0; j < NCol; j += DataPerRead)
...
@@ -50,6 +55,40 @@ __device__ void threadwise_matrix_copy(SrcMatrix,
                    *reinterpret_cast<const vector_t*>(&p_src[src_index]);
            }
        }
+   }).Else([&](auto) {
+       static_if<std::is_same<Float, half>::value>{}([&](auto) {
+           // If src/dst matrix datatype is bfloat16/float16 (vector size 2/4 respectively)
+           using vector_t = typename vector_type<Float, 4>::MemoryType;
+           for(index_t i = 0; i < NRow; ++i)
+           {
+               for(index_t j = 0; j < NCol; ++j)
+               {
+                   const index_t src_index = src_mtx.GetOffsetFromMultiIndex(i, j);
+                   const index_t dst_index = dst_mtx.GetOffsetFromMultiIndex(i, j);
+                   *reinterpret_cast<vector_t*>(&p_dst[dst_index * 4]) =
+                       *reinterpret_cast<const vector_t*>(&p_src[src_index * 4]);
+               }
+           }
+       }).Else([&](auto) {
+           using vector_t = typename vector_type<Float, 2>::MemoryType;
+           for(index_t i = 0; i < NRow; ++i)
+           {
+               for(index_t j = 0; j < NCol; ++j)
+               {
+                   const index_t src_index = src_mtx.GetOffsetFromMultiIndex(i, j);
+                   const index_t dst_index = dst_mtx.GetOffsetFromMultiIndex(i, j);
+                   *reinterpret_cast<vector_t*>(&p_dst[dst_index * 2]) =
+                       *reinterpret_cast<const vector_t*>(&p_src[src_index * 2]);
+               }
+           }
+       });
+   });
}

template <class MatrixA,
...
@@ -90,7 +129,32 @@ __device__ void threadwise_gemm(MatrixA,
                const index_t bindex = b_mtx.GetOffsetFromMultiIndex(k, j);
                const index_t cindex = c_mtx.GetOffsetFromMultiIndex(i, j);

-               p_c_thread[cindex] += p_a_thread[aindex] * p_b_thread[bindex];
+               static_if<std::is_same<FloatA, float>::value>{}([&](auto) {
+                   p_c_thread[cindex] += CVT_FLOAT2ACCUM(p_a_thread[aindex]) *
+                                         CVT_FLOAT2ACCUM(p_b_thread[bindex]);
+               }).Else([&](auto) {
+                   static_if<std::is_same<FloatA, half>::value>{}([&](auto) {
+                       // If src/dst matrix datatype is bfloat16/float16 (vector size 2/4
+                       // respectively)
+                       float acc = 0.0;
+                       for(index_t v = 0; v < 4; ++v)
+                       {
+                           acc += CVT_FLOAT2ACCUM(p_a_thread[aindex * 4 + v]) *
+                                  CVT_FLOAT2ACCUM(p_b_thread[bindex * 4 + v]);
+                       }
+                       p_c_thread[cindex] = acc;
+                   }).Else([&](auto) {
+                       // If src/dst matrix datatype is bfloat16/float16 (vector size 2/4
+                       // respectively)
+                       float acc = 0.0;
+                       for(index_t v = 0; v < 2; ++v)
+                       {
+                           acc += CVT_FLOAT2ACCUM(p_a_thread[aindex * 2 + v]) *
+                                  CVT_FLOAT2ACCUM(p_b_thread[bindex * 2 + v]);
+                       }
+                       p_c_thread[cindex] += acc;
+                   });
+               });
            }
        }
    }
...
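For orientation, a minimal host-side sketch of the accumulation pattern the fp16 branch above follows: both operands are widened to float (the role CVT_FLOAT2ACCUM plays) and a short 4-wide dot product is accumulated per output element. The function name and plain-array form are illustrative, not part of the commit.

#include <cstddef>

// Accumulate a 4-wide sub-dot-product in float, as the FloatA == half branch does.
float dot4_accum(const float* a4, const float* b4) // stand-ins for widened half values
{
    float acc = 0.0f;
    for(std::size_t v = 0; v < 4; ++v)
    {
        acc += a4[v] * b4[v];
    }
    return acc;
}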
composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy.hpp
View file @ 2185affb
...
@@ -4,6 +4,7 @@
#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "ConstantMergedTensorDescriptor.hpp"
+#include "float_types.h"

#ifndef CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1 0
...
@@ -12,7 +13,8 @@
namespace ck {

// user need to make sure alignment requirement is satisfied when setting DataPerAccesss > 1
-template <class Float,
+template <class SrcFloat,
+          class DesFloat,
          class SrcDesc,
          class DstDesc,
          class SliceLengths,
...
@@ -20,10 +22,10 @@ template <class Float,
          index_t DataPerAccess>
__device__ void threadwise_generic_tensor_slice_copy_v1(
    SrcDesc,
-   const Float* __restrict__ p_src,
+   const SrcFloat* __restrict__ p_src,
    Array<index_t, SrcDesc::GetNumOfDimension()> src_multi_id_begin,
    DstDesc,
-   Float* __restrict__ p_dst,
+   DesFloat* __restrict__ p_dst,
    Array<index_t, DstDesc::GetNumOfDimension()> dst_multi_id_begin,
    SliceLengths,
    DimAccessOrder,
...
@@ -64,7 +66,8 @@ __device__ void threadwise_generic_tensor_slice_copy_v1(
    constexpr auto access_lengths = slice_lengths_in_access_order.Modify(
        Number<nDim - 1>{}, Number<num_access_on_lowest_access_dimension>{});

-   using vector_t = typename vector_type<Float, DataPerAccess>::MemoryType;
+   using vector_src_t  = typename vector_type<SrcFloat, DataPerAccess>::MemoryType;
+   using vector_dest_t = typename vector_type<DesFloat, DataPerAccess>::MemoryType;

#if CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1
    static_ford<decltype(access_lengths)>{}([&](auto access_multi_id) {
...
@@ -82,8 +85,15 @@ __device__ void threadwise_generic_tensor_slice_copy_v1(
        const index_t dst_index =
            DstDesc::GetOffsetFromMultiIndex(dst_multi_id_begin + data_multi_id);

-       *reinterpret_cast<vector_t*>(&p_dst[dst_index]) =
-           *reinterpret_cast<const vector_t*>(&p_src[src_index]);
+       static_if<std::is_same<vector_src_t, vector_dest_t>::value>{}([&](auto) {
+           *reinterpret_cast<vector_dest_t*>(&p_dst[dst_index]) =
+               *reinterpret_cast<const vector_src_t*>(&p_src[src_index]);
+       }).Else([&](auto) {
+           for(unsigned int data_idx = 0; data_idx < DataPerAccess; ++data_idx)
+           {
+               p_dst[dst_index + data_idx] = CVT_ACCUM2FLOAT(p_src[src_index + data_idx]);
+           }
+       });
    });
#else
    ford<decltype(access_lengths)>{}([&](auto access_multi_id) {
...
@@ -99,8 +109,16 @@ __device__ void threadwise_generic_tensor_slice_copy_v1(
        const index_t dst_index =
            DstDesc::GetOffsetFromMultiIndex(dst_multi_id_begin + data_multi_id);

-       *reinterpret_cast<vector_t*>(&p_dst[dst_index]) =
-           *reinterpret_cast<const vector_t*>(&p_src[src_index]);
+       static_if<std::is_same<vector_src_t, vector_dest_t>::value>{}([&](auto) {
+           *reinterpret_cast<vector_dest_t*>(&p_dst[dst_index]) =
+               *reinterpret_cast<const vector_src_t*>(&p_src[src_index]);
+           // printf("%f ", static_cast<float>(p_dst[dst_index]));
+       }).Else([&](auto) {
+           for(unsigned int data_idx = 0; data_idx < DataPerAccess; ++data_idx)
+           {
+               p_dst[dst_index + data_idx] = CVT_ACCUM2FLOAT(p_src[src_index + data_idx]);
+           }
+       });
    });
#endif
}
...
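The copy above now chooses between a vectorized store (same source/destination element type) and a scalar convert loop (mixed types, e.g. a float accumulator written out as half). A compact sketch of that policy, with std::memcpy and static_cast standing in for the vector_t store and CVT_ACCUM2FLOAT; C++17 if constexpr is used here only for brevity.

#include <cstring>
#include <type_traits>

template <class SrcT, class DstT, int DataPerAccess>
void copy_access(const SrcT* src, DstT* dst)
{
    if constexpr(std::is_same<SrcT, DstT>::value)
    {
        std::memcpy(dst, src, sizeof(SrcT) * DataPerAccess); // one vector-wide copy
    }
    else
    {
        for(int i = 0; i < DataPerAccess; ++i)
        {
            dst[i] = static_cast<DstT>(src[i]); // element-wise conversion fallback
        }
    }
}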
composable_kernel/include/utility/Sequence.hpp
View file @ 2185affb
...
@@ -16,32 +16,31 @@ struct Sequence
    static constexpr index_t mSize = sizeof...(Is);

-   __host__ __device__ static constexpr auto GetSize() { return Number<mSize>{}; }
+   __host__ __device__ static constexpr index_t GetSize() { return mSize; }

-   __host__ __device__ static constexpr index_t GetImpl(index_t I)
+   template <index_t I>
+   __host__ __device__ static constexpr index_t Get(Number<I>)
    {
        static_assert(I < mSize, "wrong! I too large");

        // the last dummy element is to prevent compiler complain about empty array, when mSize = 0
        const index_t mData[mSize + 1] = {Is..., 0};
        return mData[I];
    }

    template <index_t I>
-   __host__ __device__ static constexpr auto Get(Number<I>)
+   __host__ __device__ constexpr auto operator[](Number<I>) const
    {
        static_assert(I < mSize, "wrong! I too large");

-       return Number<GetImpl(Number<I>{})>{};
+       return Number<Get(Number<I>{})>{};
    }

-   template <index_t I>
-   __host__ __device__ constexpr auto operator[](Number<I>) const
+   // make sure I is constepxr
+   __host__ __device__ constexpr index_t operator[](index_t I) const
    {
-       return Get(Number<I>{});
+       const index_t mData[mSize + 1] = {Is..., 0};
+       return mData[I];
    }

-   // make sure I is constepxr if you want a constexpr return type
-   __host__ __device__ constexpr index_t operator[](index_t I) const
-   {
-       return GetImpl(I);
-   }

    template <index_t... IRs>
    __host__ __device__ static constexpr auto ReorderGivenNew2Old(Sequence<IRs...> /*new2old*/)
    {
...
@@ -55,16 +54,16 @@ struct Sequence
    __host__ __device__ static constexpr auto Reverse();

-   __host__ __device__ static constexpr auto Front()
+   __host__ __device__ static constexpr index_t Front()
    {
        static_assert(mSize > 0, "wrong!");
-       return Get(Number<0>{});
+       const index_t mData[mSize + 1] = {Is..., 0};
+       return mData[0];
    }

-   __host__ __device__ static constexpr auto Back()
+   __host__ __device__ static constexpr index_t Back()
    {
        static_assert(mSize > 0, "wrong!");
-       return Get(Number<mSize - 1>{});
+       const index_t mData[mSize + 1] = {Is..., 0};
+       return mData[mSize - 1];
    }

    __host__ __device__ static constexpr auto PopFront();
...
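A small usage sketch of the reworked Sequence accessors (illustrative only; assumes the ck namespace from this header):

using S = ck::Sequence<8, 2, 16, 1>;
constexpr auto s  = S{};
constexpr auto n1 = s[ck::Number<1>{}];              // Number<2>, still a compile-time constant
static_assert(decltype(n1)::value == 2, "operator[](Number<I>) keeps the value in the type");
static_assert(S::GetSize() == 4, "GetSize() now returns a plain index_t");
static_assert(S::Back() == 1, "Back() now returns a plain index_t");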
composable_kernel/include/utility/amd_inline_asm.hpp
View file @ 2185affb
...
@@ -118,6 +118,58 @@ __device__ void outerProduct4x4(const vector_type<float, 4>::MemoryType& a,
    outerProduct1x4(a.w, b, c3);
}

+__device__ void outerProduct1x4(const half2* a, const half2* b, float* c)
+{
+    asm volatile("\n \
+                  v_dot2_f32_f16 %0, %4, %6, %0 \n \
+                  v_dot2_f32_f16 %1, %4, %8, %1 \n \
+                  v_dot2_f32_f16 %2, %4, %10, %2 \n \
+                  v_dot2_f32_f16 %3, %4, %12, %3 \n \
+                  v_dot2_f32_f16 %0, %5, %7, %0 \n \
+                  v_dot2_f32_f16 %1, %5, %9, %1 \n \
+                  v_dot2_f32_f16 %2, %5, %11, %2 \n \
+                  v_dot2_f32_f16 %3, %5, %13, %3 \n \
+                  "
+                 : "=v"(c[0]), "=v"(c[1]), "=v"(c[2]), "=v"(c[3])  // Dest registers
+                 : "v"(a[0]), "v"(a[1]),                           // 1st Src registers for 2 half2 registers
+                   "v"(b[0]), "v"(b[1]), "v"(b[2]), "v"(b[3]),     // 2nd Src registers for 2 half2 registers
+                   "v"(b[4]), "v"(b[5]), "v"(b[6]), "v"(b[7]),     // 2nd Src registers for 2 half2 registers
+                   "0"(c[0]), "1"(c[1]), "2"(c[2]), "3"(c[3]));    // 3rd Src Acc registers for 2 half2 registers
+}
+
+__device__ void outerProduct1x4Half(const vector_type<half, 4>& a,
+                                    const vector_type<vector_type<half, 4>, 4>& b,
+                                    vector_type<float, 4>::MemoryType& c)
+{
+    outerProduct1x4(reinterpret_cast<const half2*>(&a),
+                    reinterpret_cast<const half2*>(&b),
+                    reinterpret_cast<float*>(&c));
+}
+
+__device__ void outerProduct4x4(const vector_type<vector_type<half, 4>, 4>& a,
+                                const vector_type<vector_type<half, 4>, 4>& b,
+                                vector_type<float, 4>::MemoryType& c0,
+                                vector_type<float, 4>::MemoryType& c1,
+                                vector_type<float, 4>::MemoryType& c2,
+                                vector_type<float, 4>::MemoryType& c3)
+{
+    const vector_type<half, 4>* reg_a = reinterpret_cast<const vector_type<half, 4>*>(&a);
+
+    outerProduct1x4Half(reg_a[0], b, c0);
+    outerProduct1x4Half(reg_a[1], b, c1);
+    outerProduct1x4Half(reg_a[2], b, c2);
+    outerProduct1x4Half(reg_a[3], b, c3);
+}
+
__device__ void outerProduct8x8(const vector_type<float, 4>::MemoryType* a,
                                const vector_type<float, 4>::MemoryType* b,
                                vector_type<float, 4>::MemoryType* c)
...
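For readers unfamiliar with v_dot2_f32_f16: each instruction above accumulates a 2-wide half dot product into a float accumulator, so the new 1x4 outer product is equivalent to the scalar reference below (float stands in for half; the name and types are illustrative only).

// c[j] += sum_k a[k] * b[j][k], k = 0..3 (two half2 dot products per output)
void outerProduct1x4_reference(const float a[4], const float b[4][4], float c[4])
{
    for(int j = 0; j < 4; ++j)
    {
        for(int k = 0; k < 4; ++k)
        {
            c[j] += a[k] * b[j][k];
        }
    }
}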
composable_kernel/include/utility/bfloat16_dev.hpp
0 → 100644
View file @ 2185affb
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2019 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#ifndef BFLOAT16_DEVICE_HPP
#define BFLOAT16_DEVICE_HPP
#ifdef __cplusplus
extern "C" {
#endif
#ifdef __HIP_PLATFORM_HCC__
#define EXECUTION_SPECIFIER __device__
#else
#define EXECUTION_SPECIFIER
#endif // MIOPEN_BACKEND_HIP
typedef union
{
    uint u32;
    ushort2 ushortx2;
    // Composable kernels are written in HIP language. The language doesnt support
    // ushort2.hi or ushort2.low.
#ifdef __HIP_PLATFORM_HCC__
    ushort ushortvec[2];
#endif // MIOPEN_BACKEND_HIP
    float f32;
} cvt_bf16_fp32_t;
EXECUTION_SPECIFIER float bfloat16_to_float(ushort src_val)
{
    cvt_bf16_fp32_t target_val;
#ifdef __HIP_PLATFORM_HCC__
    target_val.ushortx2 = make_ushort2(0, src_val);
#else
    target_val.ushortx2 = (ushort2)(0, src_val);
#endif
    return target_val.f32;
}

EXECUTION_SPECIFIER ushort float_to_bfloat16(float src_val)
{
    cvt_bf16_fp32_t target_val;
    target_val.f32 = src_val;
// BF16 round and NaN preservation code matches
// https://github.com/ROCmSoftwarePlatform/rocBLAS/blob/develop/library/include/rocblas_bfloat16.h
    if((~target_val.u32 & 0x7f800000) == 0) // Inf or NaN
    {
// When all of the exponent bits are 1, the value is Inf or NaN.
// Inf is indicated by a zero mantissa. NaN is indicated by any nonzero
// mantissa bit. Quiet NaN is indicated by the most significant mantissa
// bit being 1. Signaling NaN is indicated by the most significant
// mantissa bit being 0 but some other bit(s) being 1. If any of the
// lower 16 bits of the mantissa are 1, we set the least significant bit
// of the bfloat16 mantissa, in order to preserve signaling NaN in case
// the bloat16's mantissa bits are all 0.
        if((target_val.u32 & 0xffff) != 0)
        {
            target_val.u32 |= 0x10000; // Preserve signaling NaN
        }
    }
else
{
#ifdef MIOPEN_USE_RNE_BFLOAT16
// When the exponent bits are not all 1s, then the value is zero, normal,
// or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus
// 1 if the least significant bit of the bfloat16 mantissa is 1 (odd).
// This causes the bfloat16's mantissa to be incremented by 1 if the 16
// least significant bits of the float mantissa are greater than 0x8000,
// or if they are equal to 0x8000 and the least significant bit of the
// bfloat16 mantissa is 1 (odd). This causes it to be rounded to even when
// the lower 16 bits are exactly 0x8000. If the bfloat16 mantissa already
// has the value 0x7f, then incrementing it causes it to become 0x00 and
// the exponent is incremented by one, which is the next higher FP value
// to the unrounded bfloat16 value. When the bfloat16 value is subnormal
// with an exponent of 0x00 and a mantissa of 0x7F, it may be rounded up
// to a normal value with an exponent of 0x01 and a mantissa of 0x00.
// When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F,
// incrementing it causes it to become an exponent of 0xFF and a mantissa
// of 0x00, which is Inf, the next higher value to the unrounded value.
#ifdef __HIP_PLATFORM_HCC__
        target_val.u32 += (0x7fff + (target_val.ushortvec[0] & 1));
#else
        target_val.u32 += (0x7fff + (target_val.ushortx2.hi & 1)); // Round to nearest, round to even
#endif // MIOPEN_BACKEND_HIP
#endif // MIOPEN_USE_RNE_BFLOAT16
}
#ifdef __HIP_PLATFORM_HCC__
    return target_val.ushortvec[0];
#else
    return target_val.ushortx2.hi;
#endif // MIOPEN_BACKEND_HIP
}
#ifdef __cplusplus
}
#endif
#endif // BFLOAT16_DEVICE_HPP
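A worked example of the round-to-nearest-even arithmetic above, as a small host program (illustrative, not part of the header): 1.00390625f has bit pattern 0x3F808000, its low 16 bits are exactly 0x8000, and the candidate bfloat16 mantissa bit is 0, so the tie rounds down to bfloat16 1.0.

#include <cstdint>
#include <cstdio>

int main()
{
    std::uint32_t u = 0x3F808000u;               // bit pattern of 1.00390625f
    u += 0x7FFFu + ((u >> 16) & 1u);             // same arithmetic as float_to_bfloat16
    std::printf("bf16 bits: 0x%04X\n", u >> 16); // prints 0x3F80, i.e. bfloat16 1.0
    return 0;
}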
composable_kernel/include/utility/common_header.hpp
View file @ 2185affb
#ifndef CK_COMMON_HEADER_HPP
#define CK_COMMON_HEADER_HPP
#define MIOPEN_USE_FP16 1
#define MIOPEN_USE_BFP16 0
#define MIOPEN_USE_FP32 0
#define __HIP_PLATFORM_HCC__ 1
#include "config.hpp"
#include "utility.hpp"
#include "integral_constant.hpp"
...
...
composable_kernel/include/utility/float_types.h
0 → 100644
View file @ 2185affb
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2019 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#ifndef FLOAT_TYPES_HPP
#define FLOAT_TYPES_HPP
#include "bfloat16_dev.hpp"
#define PPCAT_NX(A, B) A##B
#define PPCAT(A, B) PPCAT_NX(A, B)
#define TWO 2
#define FOUR 4
#define EIGHT 8
#if MIOPEN_USE_FP16 == 1
#ifdef __HIP_PLATFORM_HCC__
#define FLOAT half
#define FLOAT_ACCUM float
#else
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#define _FLOAT half
#define _FLOAT_ACCUM float
#endif // __HIP_PLATFORM_HCC__
#define SIZEOF_FLOAT 2
/* sizeof is unavailable for preprocessor */
#ifndef HALF_MAX
#define MAX_VAL 65504
/* max value */
#else
#define MAX_VAL HALF_MAX
#endif // HALF_MAX
#endif // MIOPEN_USE_FP16
#if MIOPEN_USE_FP32 == 1
#ifdef __HIP_PLATFORM_HCC__
#define FLOAT float
#define FLOAT_ACCUM float
#else
#define _FLOAT float
#define _FLOAT_ACCUM float
#endif // __HIP_PLATFORM_HCC__
#define SIZEOF_FLOAT 4
/* sizeof is unavailable for preprocessor */
#ifndef FLT_MAX
#define MAX_VAL 3.402823466e+38F
/* max value */
#else
#define MAX_VAL FLT_MAX
#endif // FLT_MAX
#endif // MIOPEN_USE_FP32
#if MIOPEN_USE_BFP16 == 1
#ifdef __HIP_PLATFORM_HCC__
#define FLOAT ushort
#define FLOAT_ACCUM float
#else
#define _FLOAT ushort
#define _FLOAT_ACCUM float
#endif //
#define SIZEOF_FLOAT 2
/* sizeof is unavailable for preprocessor */
#define MAX_VAL 0x7F7F
/* max value */
#endif // MIOPEN_USE_BFP16
#if MIOPEN_USE_FP16 == 1
#ifdef __HIP_PLATFORM_HCC__
#define CVT_FLOAT2ACCUM(x) (static_cast<FLOAT_ACCUM>(x))
#define CVT_ACCUM2FLOAT(x) (static_cast<FLOAT>(x))
#else
#define CVT_FLOAT2ACCUM(x) ((_FLOAT_ACCUM)(x))
#define CVT_ACCUM2FLOAT(x) ((_FLOAT)(x))
#endif // MIOPEN_BACKEND_HIP
#endif // MIOPEN_USE_FP16
#if MIOPEN_USE_FP32 == 1
#ifdef __HIP_PLATFORM_HCC__
#define CVT_FLOAT2ACCUM(x) (static_cast<FLOAT_ACCUM>(x))
#define CVT_ACCUM2FLOAT(x) (static_cast<FLOAT>(x))
#else
#define CVT_FLOAT2ACCUM(x) ((_FLOAT_ACCUM)(x))
#define CVT_ACCUM2FLOAT(x) ((_FLOAT)(x))
#endif
#endif // MIOPEN_USE_FP32
#if MIOPEN_USE_BFP16 == 1
#define CVT_FLOAT2ACCUM(x) bfloat16_to_float(x)
#define CVT_ACCUM2FLOAT(x) float_to_bfloat16(x)
#endif
#ifndef __HIP_PLATFORM_HCC__
#define _FLOAT2 PPCAT(_FLOAT, TWO)
#endif
#endif // FLOAT_TYPES_HPP
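Usage sketch for the macros above, assuming this header is included with MIOPEN_USE_FP16 == 1 on the HIP path (so FLOAT is half and FLOAT_ACCUM is float); the function itself is illustrative only.

__device__ void fma_example(const FLOAT* a, const FLOAT* b, FLOAT* out)
{
    FLOAT_ACCUM c = 0;
    c += CVT_FLOAT2ACCUM(a[0]) * CVT_FLOAT2ACCUM(b[0]); // multiply-accumulate in float
    out[0] = CVT_ACCUM2FLOAT(c);                        // narrow back to half on store
}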
composable_kernel/include/utility/integral_constant.hpp
View file @ 2185affb
...
@@ -13,64 +13,30 @@ struct integral_constant
    __host__ __device__ constexpr value_type operator()() const noexcept { return value; }
};

-template <class X, class Y>
-struct is_same : public integral_constant<bool, false>
-{
-};
-
-template <class X>
-struct is_same<X, X> : public integral_constant<bool, true>
-{
-};
-
-template <index_t N>
-using Number = integral_constant<index_t, N>;
-
-template <index_t X, index_t Y>
-__host__ __device__ constexpr auto operator+(Number<X>, Number<Y>)
+template <class T, T X, T Y>
+__host__ __device__ constexpr auto operator+(integral_constant<T, X>, integral_constant<T, Y>)
{
-    return Number<X + Y>{};
+    return integral_constant<T, X + Y>{};
}

-template <index_t X, index_t Y>
-__host__ __device__ constexpr auto operator-(Number<X>, Number<Y>)
-{
-    static_assert(Y <= X, "wrong!");
-    return Number<X - Y>{};
-}
-
-template <index_t X, index_t Y>
-__host__ __device__ constexpr auto operator*(Number<X>, Number<Y>)
+template <class T, T X, T Y>
+__host__ __device__ constexpr auto operator*(integral_constant<T, X>, integral_constant<T, Y>)
{
-    return Number<X * Y>{};
+    return integral_constant<T, X * Y>{};
}

-template <index_t X, index_t Y>
-__host__ __device__ constexpr auto operator/(Number<X>, Number<Y>)
-{
-    static_assert(Y > 0, "wrong!");
-    return Number<X / Y>{};
-}
+template <index_t N>
+using Number = integral_constant<index_t, N>;

-template <index_t X, index_t Y>
-__host__ __device__ constexpr auto operator%(Number<X>, Number<Y>)
-{
-    static_assert(Y > 0, "wrong!");
-    return Number<X % Y>{};
-}
+template <class X, class Y>
+struct is_same : public integral_constant<bool, false>
+{
+};
+
+template <class X>
+struct is_same<X, X> : public integral_constant<bool, true>
+{
+};

#if 0
static constexpr Number<0> 0_c;
static constexpr Number<1> 1_c;
static constexpr Number<2> 2_c;
static constexpr Number<3> 3_c;
static constexpr Number<4> 4_c;
static constexpr Number<5> 5_c;
static constexpr Number<6> 6_c;
static constexpr Number<7> 7_c;
static constexpr Number<8> 8_c;
static constexpr Number<9> 9_c;
#endif

} // namespace ck
#endif
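Quick check of the generalized compile-time operators above (illustrative): operator+ and operator* now act on any integral_constant<T, V>, and Number<N> remains the index_t alias.

using ck::Number;

constexpr auto three = Number<1>{} + Number<2>{}; // Number<3>
constexpr auto six   = three * Number<2>{};       // Number<6>
static_assert(decltype(six)::value == 6, "compile-time arithmetic on Number<>");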
composable_kernel/include/utility/math.hpp
View file @ 2185affb
...
@@ -42,16 +42,20 @@ struct integer_divide_ceiler
    }
};

-template <class X, class Y>
-__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
+template <class T>
+__host__ __device__ constexpr T integer_divide_ceil(T a, T b)
{
-    return (x + y - 1) / y;
+    static_assert(is_same<T, index_t>{} || is_same<T, int>{}, "wrong type");
+
+    return (a + b - 1) / b;
}

-template <class X, class Y>
-__host__ __device__ constexpr auto integer_least_multiple(X x, Y y)
+template <class T>
+__host__ __device__ constexpr T integer_least_multiple(T a, T b)
{
-    return y * integer_divide_ceil(x, y);
+    static_assert(is_same<T, index_t>{} || is_same<T, int>{}, "wrong type");
+
+    return b * integer_divide_ceil(a, b);
}

template <class T>
...
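Example of the tightened integer helpers above; the enclosing ck namespace qualification is omitted here and assumed to be in scope.

static_assert(integer_divide_ceil(17, 4) == 5, "ceil(17 / 4) == 5");
static_assert(integer_least_multiple(17, 4) == 20, "smallest multiple of 4 that is >= 17");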
composable_kernel/include/utility/vector_type.hpp
View file @ 2185affb
#ifndef CK_VECTOR_TYPE_HPP
#define CK_VECTOR_TYPE_HPP

+#include "cuda_fp16.h"
#include "config.hpp"
#include "integral_constant.hpp"
...
@@ -9,12 +10,15 @@ namespace ck {
template <class T, index_t N>
struct vector_type
{
    T vector[N];
};

template <>
struct vector_type<float, 1>
{
-   typedef float MemoryType;
+   using MemoryType = float;

    __host__ __device__ static constexpr index_t GetSize() { return 1; }

    template <index_t I>
    __host__ __device__ static void SetScalar(MemoryType& v, float s, Number<I>)
...
@@ -29,6 +33,8 @@ struct vector_type<float, 2>
{
    using MemoryType = float2_t;

+   __host__ __device__ static constexpr index_t GetSize() { return 2; }
+
    union Data
    {
        MemoryType vector;
...
@@ -42,13 +48,6 @@ struct vector_type<float, 2>
        *(reinterpret_cast<float*>(&v) + I) = s;
    }

-   __host__ __device__ static MemoryType Pack(float s0, float s1)
-   {
-       Data data;
-       data.scalar[0] = s0;
-       data.scalar[1] = s1;
-       return data.vector;
-   }
};

template <>
...
@@ -56,6 +55,8 @@ struct vector_type<float, 4>
{
    using MemoryType = float4_t;

+   __host__ __device__ static constexpr index_t GetSize() { return 4; }
+
    template <index_t I>
    __host__ __device__ static void SetScalar(MemoryType& v, float s, Number<I>)
    {
...
@@ -64,6 +65,116 @@ struct vector_type<float, 4>
    }
};

+template <>
+struct vector_type<half, 1>
+{
+    using MemoryType = half;
+
+    __host__ __device__ static constexpr index_t GetSize() { return 1; }
+
+    template <index_t I>
+    __host__ __device__ static void SetScalar(MemoryType& v, half s, Number<I>)
+    {
+        static_assert(I < 1, "wrong");
+        *(reinterpret_cast<half*>(&v) + I) = s;
+    }
+};
+
+template <>
+struct vector_type<half, 2>
+{
+    using MemoryType = half2;
+
+    union Data
+    {
+        MemoryType vector;
+        half scalar[2];
+    };
+
+    __host__ __device__ static constexpr index_t GetSize() { return 2; }
+
+    template <index_t I>
+    __host__ __device__ static void SetScalar(MemoryType& v, half s, Number<I>)
+    {
+        static_assert(I < 2, "wrong");
+        *(reinterpret_cast<half*>(&v) + I) = s;
+    }
+};
+
+template <>
+struct vector_type<half, 4>
+{
+    typedef struct MemoryType
+    {
+        half2 vector[2];
+    } MemoryType;
+
+    __host__ __device__ static constexpr index_t GetSize() { return 4; }
+
+    template <index_t I>
+    __host__ __device__ static void SetScalar(MemoryType& v, half s, Number<I>)
+    {
+        static_assert(I < 4, "wrong");
+        *(reinterpret_cast<half*>(&v) + I) = s;
+    }
+};
+
+template <>
+struct vector_type<ushort, 1>
+{
+    using MemoryType = ushort;
+
+    __host__ __device__ static constexpr index_t GetSize() { return 1; }
+
+    template <index_t I>
+    __host__ __device__ static void SetScalar(MemoryType& v, ushort s, Number<I>)
+    {
+        static_assert(I < 1, "wrong");
+        *(reinterpret_cast<ushort*>(&v) + I) = s;
+    }
+};
+
+template <>
+struct vector_type<ushort, 2>
+{
+    using MemoryType = ushort2;
+
+    union Data
+    {
+        MemoryType vector;
+        half scalar[2];
+    };
+
+    __host__ __device__ static constexpr index_t GetSize() { return 2; }
+
+    template <index_t I>
+    __host__ __device__ static void SetScalar(MemoryType& v, ushort s, Number<I>)
+    {
+        static_assert(I < 2, "wrong");
+        *(reinterpret_cast<ushort*>(&v) + I) = s;
+    }
+};
+
+template <>
+struct vector_type<ushort, 4>
+{
+    typedef struct MemoryType
+    {
+        ushort2 vector[2];
+    } MemoryType;
+
+    __host__ __device__ static constexpr index_t GetSize() { return 4; }
+
+    template <index_t I>
+    __host__ __device__ static void SetScalar(MemoryType& v, ushort s, Number<I>)
+    {
+        static_assert(I < 4, "wrong");
+        *(reinterpret_cast<ushort*>(&v) + I) = s;
+    }
+};
+
} // namespace ck
#endif
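Usage sketch for the new fp16 specializations above (illustrative; assumes the HIP half type is available): vector_type<half, 4> stores its payload as two packed half2 values, and SetScalar writes one lane at a compile-time index.

__device__ void fill_half4(ck::vector_type<half, 4>::MemoryType& v, half x)
{
    ck::vector_type<half, 4>::SetScalar(v, x, ck::Number<0>{});
    ck::vector_type<half, 4>::SetScalar(v, x, ck::Number<1>{});
    ck::vector_type<half, 4>::SetScalar(v, x, ck::Number<2>{});
    ck::vector_type<half, 4>::SetScalar(v, x, ck::Number<3>{});
}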
driver/include/device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw.hpp
View file @ 2185affb
#pragma once
#include <unistd.h>
#define MIOPEN_USE_FP16 1
#define MIOPEN_USE_BFP16 0
#define MIOPEN_USE_FP32 0
#define __HIP_PLATFORM_HCC__ 1
#include "float_types.h"
#include "device.hpp"
#include "tensor.hpp"
#include "gridwise_convolution_kernel_wrapper.hpp"
#include "gridwise_convolution_implicit_gemm_v4_nchw_kcyx_nkhw.hpp"
#include "gridwise_convolution_implicit_gemm_v4_nchw_kcyx_nkhw_lds_double_buffer.hpp"
#include "gridwise_convolution_implicit_gemm_v4_fp16_bfp16_nchw_kcyx_nkhw_lds_double_buffer.hpp"

#define CK_PARAM_TUNABLE_K_PER_BLOCK 64

using namespace ck;
...
@@ -24,6 +35,10 @@ void device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw(InDesc,
                                                         ConvDilations,
                                                         index_t nrepeat)
{
    // read params: problem decription
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
...
@@ -59,16 +74,22 @@ void device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw(InDesc,
    constexpr index_t B = (N * Ho * Wo) / (N1 * N2);
#if 1
    constexpr index_t BlockSize = 256;
    constexpr index_t BPerBlock = 16;
    constexpr index_t KPerBlock = 128;
    constexpr index_t EPerBlock = 8;
    constexpr index_t KPerBlock = K % 128 == 0 ? 128 : (K % 64 == 0 ? 64 : 32);
    constexpr index_t BlockSize = K % 128 == 0 ? 256 : (K % 64 == 0 ? 128 : 64);
#if MIOPEN_USE_FP16 == 1
    // ES set to 4 as dot4 operator is supported on fp16 in MI100
    constexpr index_t ES = 4;
#elif MIOPEN_USE_BFP16 == 1
    // ES set to 2 as dot2 operator is supported on bfp16 in MI100
    constexpr index_t ES = 2;
#else
    // do nothing
#endif
    constexpr index_t GemmMPerThreadSubC = 4;
    constexpr index_t GemmNPerThreadSubC = 4;
    constexpr index_t GemmMLevel0Cluster = 4;
    constexpr index_t GemmNLevel0Cluster = 4;
    constexpr index_t GemmMLevel1Cluster = 4;
    constexpr index_t GemmNLevel1Cluster = 4;
...
@@ -76,92 +97,103 @@ void device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw(InDesc,
    constexpr index_t GemmDataPerReadA = 4;
    constexpr index_t GemmDataPerReadB = 4;
    using InBlockCopySubLengths_E_N1_B_N2     = Sequence<1, 1, 1, 4>;
    using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<8, 2, 16, 1>;
#if MIOPEN_USE_FP32 == 1
    using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
    using InBlockCopySrcAccessOrder            = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
    using InBlockCopyDstAccessOrder            = Sequence<0, 1, 2, 3>; // [E, N1, B, N2]
    constexpr index_t InBlockCopySrcDataPerRead_B   = 1;
    constexpr index_t InBlockCopyDstDataPerWrite_N2 = 4;
    using WeiBlockCopySubLengths_E_K     = Sequence<4, 1>;
    using WeiBlockCopyClusterLengths_E_K = Sequence<2, 128>;
    using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
    using WeiBlockCopySrcAccessOrder            = Sequence<1, 0>; // [K, E]
    using WeiBlockCopyDstAccessOrder            = Sequence<0, 1>; // [E, K]
#elif MIOPEN_USE_FP16 == 1 || MIOPEN_USE_BFP16 == 1
    // ES - E dimension is folded into 2 dimensions E and ES
    using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2, 4>; // [E, N1, N2, B, ES]
    using InBlockCopySrcAccessOrder            = Sequence<0, 1, 3, 2, 4>; // [E, N1, N2, B, ES]
    using InBlockCopyDstAccessOrder            = Sequence<0, 1, 2, 3, 4>; // [E, N1, B, N2, ES]
    using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0, 2>; // [K, E, ES]
    using WeiBlockCopySrcAccessOrder            = Sequence<1, 0, 2>; // [K, E, ES]
    using WeiBlockCopyDstAccessOrder            = Sequence<0, 1, 2>; // [E, K, ES]
#endif
#if CK_PARAM_TUNABLE_K_PER_BLOCK == 32
    constexpr index_t WeiBlockCopySrcDataPerRead_E = 4;
    constexpr index_t EPerBlock          = 4;
    constexpr index_t GemmMLevel0Cluster = 1;
    constexpr index_t InBlockCopySrcDataPerRead_B   = 1;
    constexpr index_t InBlockCopyDstDataPerWrite_N2 = 1;
#if MIOPEN_USE_FP32 == 1
    // all_of(X_Per_Block % (X_Sub_Length * X_Cluster_Length) == 0)
    // accumulate(X_Cluster_Lengths, multiply) == BlockSize
    using InBlockCopySubLengths_E_N1_B_N2     = Sequence<1, 2, 1, 4>;
    using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<4, 1, 16, 1>;
    using WeiBlockCopySubLengths_E_K     = Sequence<2, 1>;
    using WeiBlockCopyClusterLengths_E_K = Sequence<2, 32>;
#elif MIOPEN_USE_FP16 == 1 || MIOPEN_USE_BFP16 == 1
    using InBlockCopySubLengths_E_N1_B_N2_ES     = Sequence<1, 2, 1, 4, ES>;
    using InBlockCopyClusterLengths_E_N1_B_N2_ES = Sequence<4, 1, 16, 1, 1>;
    using WeiBlockCopySubLengths_E_K_ES     = Sequence<2, 1, ES>;
    using WeiBlockCopyClusterLengths_E_K_ES = Sequence<2, 32, 1>;
#endif // MIOPEN_USE_FP32 == 1
    constexpr index_t WeiBlockCopySrcDataPerRead_E  = 1;
    constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
#elif 0
    constexpr index_t BlockSize = 256;
    constexpr index_t BPerBlock = 16;
    constexpr index_t KPerBlock = 128;
    constexpr index_t EPerBlock = 8;
#elif CK_PARAM_TUNABLE_K_PER_BLOCK == 64
    constexpr index_t GemmMPerThreadSubC = 4;
    constexpr index_t GemmNPerThreadSubC = 4;
    constexpr index_t GemmMLevel0Cluster = 4;
    constexpr index_t GemmNLevel0Cluster = 4;
    constexpr index_t GemmMLevel1Cluster = 4;
    constexpr index_t GemmNLevel1Cluster = 4;
    constexpr index_t GemmKPerThreadLoop = 1;
    constexpr index_t GemmDataPerReadA   = 4;
    constexpr index_t GemmDataPerReadB   = 4;
    constexpr index_t EPerBlock = 8;
    using InBlockCopySubLengths_E_N1_B_N2     = Sequence<1, 1, 4, 1>;
    using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<8, 2, 4, 4>;
    using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
    using InBlockCopySrcAccessOrder            = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
    using InBlockCopyDstAccessOrder            = Sequence<0, 1, 2, 3>; // [E, N1, B, N2]
    constexpr index_t GemmMLevel0Cluster = 2;
    constexpr index_t InBlockCopySrcDataPerRead_B   = 4;
    constexpr index_t InBlockCopySrcDataPerRead_B   = 1;
    constexpr index_t InBlockCopyDstDataPerWrite_N2 = 1;
#if MIOPEN_USE_FP32 == 1
    using InBlockCopySubLengths_E_N1_B_N2     = Sequence<1, 2, 1, 4>;
    using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<8, 1, 16, 1>;
    using WeiBlockCopySubLengths_E_K     = Sequence<4, 1>;
    using WeiBlockCopyClusterLengths_E_K = Sequence<2, 128>;
    using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
    using WeiBlockCopySrcAccessOrder            = Sequence<1, 0>; // [K, E]
    using WeiBlockCopyDstAccessOrder            = Sequence<0, 1>; // [E, K]
    constexpr index_t WeiBlockCopySrcDataPerRead_E = 4;
    using WeiBlockCopyClusterLengths_E_K = Sequence<2, 64>;
#elif MIOPEN_USE_FP16 == 1 || MIOPEN_USE_BFP16 == 1
    // ES - E dimension is folded into 2 dimensions E and ES
    using InBlockCopySubLengths_E_N1_B_N2_ES     = Sequence<1, 2, 1, 4, ES>;
    using InBlockCopyClusterLengths_E_N1_B_N2_ES = Sequence<8, 1, 16, 1, 1>;
    using WeiBlockCopySubLengths_E_K_ES     = Sequence<4, 1, ES>;
    using WeiBlockCopyClusterLengths_E_K_ES = Sequence<2, 64, 1>;
#endif // MIOPEN_USE_FP32 == 1
    constexpr index_t WeiBlockCopySrcDataPerRead_E  = 1;
    constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
#elif 1
    constexpr index_t BlockSize = 256;
    constexpr index_t BPerBlock = 16;
    constexpr index_t KPerBlock = 128;
#elif CK_PARAM_TUNABLE_K_PER_BLOCK == 128
    constexpr index_t EPerBlock = 8;
    constexpr index_t GemmMPerThreadSubC = 4;
    constexpr index_t GemmNPerThreadSubC = 4;
    constexpr index_t GemmMLevel0Cluster = 4;
    constexpr index_t GemmNLevel0Cluster = 4;
    constexpr index_t GemmMLevel1Cluster = 4;
    constexpr index_t GemmNLevel1Cluster = 4;
    constexpr index_t GemmKPerThreadLoop = 1;
    constexpr index_t GemmDataPerReadA   = 4;
    constexpr index_t GemmDataPerReadB   = 4;
    using InBlockCopySubLengths_E_N1_B_N2     = Sequence<1, 1, 2, 2>;
    using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<8, 2, 8, 2>;
    using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
    using InBlockCopySrcAccessOrder            = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
    using InBlockCopyDstAccessOrder            = Sequence<0, 1, 2, 3>; // [E, N1, B, N2]
    constexpr index_t InBlockCopySrcDataPerRead_B   = 2;
    constexpr index_t InBlockCopyDstDataPerWrite_N2 = 2;
    constexpr index_t InBlockCopySrcDataPerRead_B   = 1;
    constexpr index_t InBlockCopyDstDataPerWrite_N2 = 1;
#if MIOPEN_USE_FP32 == 1
    using InBlockCopySubLengths_E_N1_B_N2     = Sequence<1, 1, 1, 4>;
    using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<8, 2, 16, 1>;
    using WeiBlockCopySubLengths_E_K     = Sequence<4, 1>;
    using WeiBlockCopyClusterLengths_E_K = Sequence<2, 128>;
    using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
    using WeiBlockCopySrcAccessOrder            = Sequence<1, 0>; // [K, E]
    using WeiBlockCopyDstAccessOrder            = Sequence<0, 1>; // [E, K]
    constexpr index_t WeiBlockCopySrcDataPerRead_E = 4;
#elif MIOPEN_USE_FP16 == 1 || MIOPEN_USE_BFP16 == 1
    // ES - E dimension is folded into 2 dimensions E and ES
    using InBlockCopySubLengths_E_N1_B_N2_ES     = Sequence<1, 1, 1, 4, ES>;
    using InBlockCopyClusterLengths_E_N1_B_N2_ES = Sequence<8, 2, 16, 1, 1>;
    using WeiBlockCopySubLengths_E_K_ES     = Sequence<4, 1, ES>;
    using WeiBlockCopyClusterLengths_E_K_ES = Sequence<2, 128, 1>;
#endif // MIOPEN_USE_FP32 == 1
    constexpr index_t WeiBlockCopySrcDataPerRead_E  = 1;
    constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
#endif
#else
    static_assert(false, "wrong! Only kperblock could be 32/64/128 not supported");
#endif // CK_PARAM_TUNABLE_K_PER_BLOCK == 32
    constexpr index_t GridSize =
        ((B + BPerBlock - 1) / BPerBlock) * ((K + KPerBlock - 1) / KPerBlock);
...
@@ -171,14 +203,12 @@ void device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw(InDesc,
    for(index_t i = 0; i < nrepeat; ++i)
    {
        constexpr auto gridwise_conv =
#if 0
            GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw
#else
            GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw_lds_double_buffer
#endif
            <GridSize,
#if MIOPEN_USE_FP32 == 1
            GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw_lds_double_buffer<
                GridSize, BlockSize, T, FLOAT, FLOAT_ACCUM,
                decltype(in_nchw_desc), decltype(wei_kcyx_desc), decltype(out_nkhw_desc),
...
@@ -212,6 +242,47 @@ void device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw(InDesc,
                WeiBlockCopyDstAccessOrder, WeiBlockCopySrcDataPerRead_E,
                WeiBlockCopyDstDataPerWrite_K>{};
#elif MIOPEN_USE_FP16 == 1 || MIOPEN_USE_BFP16 == 1
            GridwiseConvolutionImplicitGemm_v4_fp16_bfp16_nchw_kcyx_nkhw_lds_double_buffer<
                GridSize, BlockSize, half, float,
                decltype(in_nchw_desc), decltype(wei_kcyx_desc), decltype(out_nkhw_desc),
                ConvStrides, ConvDilations,
                BPerBlock, KPerBlock, EPerBlock, N1, N2, ES,
                GemmMPerThreadSubC, GemmNPerThreadSubC,
                GemmMLevel0Cluster, GemmNLevel0Cluster, GemmMLevel1Cluster, GemmNLevel1Cluster,
                GemmKPerThreadLoop, GemmDataPerReadA, GemmDataPerReadB,
                InBlockCopySubLengths_E_N1_B_N2_ES, InBlockCopyClusterLengths_E_N1_B_N2_ES,
                InBlockCopyThreadClusterArrangeOrder, InBlockCopySrcAccessOrder, InBlockCopyDstAccessOrder,
                InBlockCopySrcDataPerRead_B, InBlockCopyDstDataPerWrite_N2,
                WeiBlockCopySubLengths_E_K_ES, WeiBlockCopyClusterLengths_E_K_ES,
                WeiBlockCopyThreadClusterArrangeOrder, WeiBlockCopySrcAccessOrder, WeiBlockCopyDstAccessOrder,
                WeiBlockCopySrcDataPerRead_E, WeiBlockCopyDstDataPerWrite_K>{};
#endif

        float time = launch_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
                                   dim3(GridSize),
...
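Worked example of the K-dependent tunable selection above (plain int stands in for index_t; illustrative only):

// K = 128 -> K % 128 == 0                    -> KPerBlock = 128, BlockSize = 256
// K = 192 -> K % 128 != 0, K % 64 == 0       -> KPerBlock = 64,  BlockSize = 128
// K =  96 -> neither 128 nor 64 divides K    -> KPerBlock = 32,  BlockSize = 64
constexpr int K = 192;
constexpr int KPerBlock = K % 128 == 0 ? 128 : (K % 64 == 0 ? 64 : 32);
constexpr int BlockSize = K % 128 == 0 ? 256 : (K % 64 == 0 ? 128 : 64);
static_assert(KPerBlock == 64 && BlockSize == 128, "K = 192 picks the 64-per-block config");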
driver/src/driver.cpp
View file @ 2185affb
...
@@ -790,13 +790,13 @@ int main(int argc, char* argv[])
#elif 1
    // 1x1 filter, 7x7 image
    // cudnn@V100 49%, ck@V100 50%, ck@P100 61%, ck@VII 52%
-   constexpr index_t N  = 128;
-   constexpr index_t C  = 832;
-   constexpr index_t HI = 7;
-   constexpr index_t WI = 7;
-   constexpr index_t K  = 128;
-   constexpr index_t Y  = 1;
-   constexpr index_t X  = 1;
+   constexpr index_t N  = 32;
+   constexpr index_t C  = 128;
+   constexpr index_t HI = 28;
+   constexpr index_t WI = 28;
+   constexpr index_t K  = 192;
+   constexpr index_t Y  = 3;
+   constexpr index_t X  = 3;

    using ConvStrides   = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;
...
@@ -817,8 +817,8 @@ int main(int argc, char* argv[])
    ostream_ConstantTensorDescriptor(wei_kcyx_desc, std::cout << "wei_kcyx_desc: ");
    ostream_ConstantTensorDescriptor(out_nkhw_desc, std::cout << "out_nkhw_desc: ");

-   using in_data_t  = float;
-   using out_data_t = float;
+   using in_data_t  = half;
+   using out_data_t = half;

    Tensor<in_data_t> in_nchw(make_TensorDescriptor(in_nchw_desc));
    Tensor<in_data_t> wei_kcyx(make_TensorDescriptor(wei_kcyx_desc));
    Tensor<out_data_t> out_nkhw_host(make_TensorDescriptor(out_nkhw_desc));
...