Project: gaoqiong / composable_kernel

Commit 32850b93, authored Oct 09, 2019 by Wen-Heng (Jack) Chung
Parent: 583755a7

Commit message: Ported xdlops kernels to debug the bwdwrw fp32/fp16/bfp16 issue. Verified that at least fwd data fp32 works.

Changes: 37 files in the commit; 20 changed files are shown below, with 5066 additions and 1207 deletions (+5066 −1207).
Changed files shown:

  composable_kernel/include/implicitgemm_params.hpp (+43, −0)
  composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4_fp16_bfp16_nchw_kcyx_nkhw_lds_double_buffer.hpp (+496, −0)
  composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4_nchw_kc1x1_nkhw_lds_double_buffer.hpp (+454, −0)
  composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4_nchw_kcyx_nkhw_lds_double_buffer.hpp (+105, −58)
  composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_xdlops_fp16_bfp16_nchw_kcyx_nkhw_lds_double_buffer.hpp (+454, −0)
  composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_xdlops_nchw_kc1x1_nkhw_lds_double_buffer.hpp (+400, −0)
  composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw_lds_double_buffer.hpp (+409, −0)
  composable_kernel/include/tensor_description/ConstantMatrixDescriptor.hpp (+13, −1)
  composable_kernel/include/tensor_description/ConstantMergedTensorDescriptor.hpp (+22, −5)
  composable_kernel/include/tensor_description/ConstantTensorDescriptor.hpp (+45, −21)
  composable_kernel/include/tensor_description/tensor_coordinate.hpp (+301, −0)
  composable_kernel/include/tensor_operation/blockwise_gemm.hpp (+208, −228)
  composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp (+297, −0)
  composable_kernel/include/tensor_operation/blockwise_generic_tensor_slice_copy.hpp (+557, −26)
  composable_kernel/include/tensor_operation/threadwise_gemm.hpp (+12, −77)
  composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy.hpp (+761, −82)
  composable_kernel/include/utility/Array.hpp (+18, −2)
  composable_kernel/include/utility/Sequence.hpp (+133, −74)
  composable_kernel/include/utility/amd_inline_asm.hpp (+336, −633)
  composable_kernel/include/utility/bfloat16_dev.hpp (+2, −0)
composable_kernel/include/implicitgemm_params.hpp — new file (mode 100644)
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2019 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#ifndef GUARD_MIOPEN_IMPLICITGEMM_PARMS_HPP_
#define GUARD_MIOPEN_IMPLICITGEMM_PARMS_HPP_
enum struct ImplicitGemmDirection
{
    ForwardData,
    BackwardData,
    BackwardWeight
};

enum struct ImplicitGemmXdlopsKernel
{
    KernelFwdWrw = 0,
    Kernel1x1    = 1,
};

#endif
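
The two enums above are consumed as non-type template parameters, so a kernel variant is selected at compile time rather than branched on at run time. Below is a minimal, standalone sketch (not part of this commit) of that dispatch pattern; the DirectionName helper and the local re-declaration of ImplicitGemmDirection exist only so the sketch compiles on its own.

#include <iostream>

// Stand-in for the enum defined in implicitgemm_params.hpp, redeclared here so the
// sketch is self-contained.
enum struct ImplicitGemmDirection
{
    ForwardData,
    BackwardData,
    BackwardWeight
};

// Primary template is left empty; only supported directions get a specialization,
// mirroring how make_vectorized_WeiDesc / make_WeiDesc are specialized per direction
// in the gridwise kernels below.
template <ImplicitGemmDirection dir>
struct DirectionName;

template <>
struct DirectionName<ImplicitGemmDirection::ForwardData>
{
    static constexpr const char* value = "forward data";
};

template <>
struct DirectionName<ImplicitGemmDirection::BackwardWeight>
{
    static constexpr const char* value = "backward weight";
};

int main()
{
    // Each instantiation is resolved at compile time; an unsupported direction
    // (e.g. BackwardData here) would fail to compile instead of failing at run time.
    std::cout << DirectionName<ImplicitGemmDirection::ForwardData>::value << "\n";
    std::cout << DirectionName<ImplicitGemmDirection::BackwardWeight>::value << "\n";
    return 0;
}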
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4_fp16_bfp16_nchw_kcyx_nkhw_lds_double_buffer.hpp — new file (mode 100644)
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4_FP16_BFP16_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_HPP
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4_FP16_BFP16_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_HPP
#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "ConstantMergedTensorDescriptor.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_generic_tensor_slice_copy.hpp"
#include "blockwise_gemm.hpp"
#include "threadwise_generic_tensor_slice_copy.hpp"
#include "implicitgemm_params.hpp"
namespace ck {

template <ImplicitGemmDirection conv_dir, typename WeiDesc, index_t NonVectorizedC>
struct make_vectorized_WeiDesc
{
};

template <typename WeiDesc, index_t NonVectorizedC>
struct make_vectorized_WeiDesc<ImplicitGemmDirection::ForwardData, WeiDesc, NonVectorizedC>
{
    __device__ constexpr auto get(WeiDesc&)
    {
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};
        constexpr auto I4 = Number<4>{};

        return WeiDesc{}
            .Fold(I1, Number<NonVectorizedC>{})
            .Unfold(I2, I4)
            .ReorderGivenNew2Old(Sequence<2, 0, 1>{});
    }
};

template <typename WeiDesc, index_t NonVectorizedC>
struct make_vectorized_WeiDesc<ImplicitGemmDirection::BackwardWeight, WeiDesc, NonVectorizedC>
{
    __device__ constexpr auto get(WeiDesc& desc)
    {
        constexpr auto I1 = Number<1>{};
        constexpr auto I3 = Number<3>{};
        constexpr auto I4 = Number<4>{};

        return make_ConstantMergedTensorDescriptor(
            desc.Fold(I1, Number<NonVectorizedC>{}).Unfold(I3, I4),
            Sequence<2, 3>{}, Sequence<0>{}, Sequence<1>{});
    }
};

// define B = merge(N0, Ho, Wo)
template <index_t GridSize, index_t BlockSize,
          class Float, class AccDataType,
          class InGlobalDesc, class WeiGlobalDesc, class OutGlobalDesc,
          class ConvStrides, class ConvDilations,
          index_t BPerBlock, index_t KPerBlock, index_t EPerBlock,
          index_t GemmNRepeat, index_t EPACK,
          index_t GemmMPerThreadSubC, index_t GemmNPerThreadSubC,
          index_t GemmMLevel0Cluster, index_t GemmNLevel0Cluster,
          index_t GemmMLevel1Cluster, index_t GemmNLevel1Cluster,
          index_t GemmKPerThreadLoop, index_t GemmDataPerReadA, index_t GemmDataPerReadB,
          class InBlockCopySubLengths_E_N1_B_N2_EPACK,
          class InBlockCopyClusterLengths_E_N1_B_N2_EPACK,
          class InBlockCopyThreadClusterArrangeOrder,
          class InBlockCopySrcAccessOrder, class InBlockCopyDstAccessOrder,
          index_t InBlockCopySrcDataPerRead_B, index_t InBlockCopyDstDataPerWrite_N2,
          class WeiBlockCopySubLengths_E_K_EPACK,
          class WeiBlockCopyClusterLengths_E_K_EPACK,
          class WeiBlockCopyThreadClusterArrangeOrder,
          class WeiBlockCopySrcAccessOrder, class WeiBlockCopyDstAccessOrder,
          index_t WeiBlockCopySrcDataPerRead_E, index_t WeiBlockCopyDstDataPerWrite_K,
          ImplicitGemmDirection conv_dir>
struct GridwiseConvolutionImplicitGemm_v4_fp16_bfp16_nchw_kcyx_nkhw_lds_double_buffer
{
    __device__ void Run(const Float* const __restrict__ p_in_global,
                        const Float* const __restrict__ p_wei_global,
                        Float* const __restrict__ p_out_global) const
    {
        // this is a mess
        // TODO: find more elegent way of specifying (or calculating) performance parameters
        constexpr index_t N1 = GemmNRepeat;
        constexpr index_t N2 = GemmNPerThreadSubC;

        static_assert(N2 == GemmNPerThreadSubC, "wrong!");
        static_assert((N1 * N2 * BPerBlock) %
                              (GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster) ==
                          0,
                      "wrong!");

        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};
        constexpr auto I5 = Number<5>{};

        constexpr auto True = integral_constant<bool, true>{};

        constexpr auto in_n_c_h_w_global_desc  = InGlobalDesc{};
        constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{};
        constexpr auto out_n_k_h_w_global_desc = OutGlobalDesc{};

        constexpr index_t N  = in_n_c_h_w_global_desc.GetLength(I0);
        constexpr index_t C  = in_n_c_h_w_global_desc.GetLength(I1);
        constexpr index_t K  = out_n_k_h_w_global_desc.GetLength(I1);
        constexpr index_t Ho = out_n_k_h_w_global_desc.GetLength(I2);
        constexpr index_t Wo = out_n_k_h_w_global_desc.GetLength(I3);

        constexpr index_t Y = wei_k_c_y_x_global_desc.GetLength(I2);
        constexpr index_t X = wei_k_c_y_x_global_desc.GetLength(I3);

        static_assert(N % (N1 * N2) == 0, "wrong! cannot divice N evenly among thread");

        constexpr index_t N0 = N / (N1 * N2);
        constexpr index_t B  = N0 * Ho * Wo;

        // EPACK=1 for float32, =2 for bfloat16, =4 for float16
        static_assert(C % EPACK == 0, "C needs to be multiple of vectorized C (EPACK)");
        constexpr auto nonVectorizedC = C / EPACK;
        constexpr index_t E           = nonVectorizedC * Y * X;

        // divide block work by [K, B]
        static_assert(K % KPerBlock == 0 && B % BPerBlock == 0 && E % (2 * EPerBlock) == 0,
                      "wrong! cannot divide work evenly among block");

        constexpr index_t KBlockWork = K / KPerBlock;
        constexpr index_t BBlockWork = B / BPerBlock;

        constexpr index_t InBlockCopyDstDataPerWrite_EPACK  = EPACK;
        constexpr index_t WeiBlockCopyDstDataPerWrite_EPACK = EPACK;

        constexpr auto block_work_desc =
            make_ConstantTensorDescriptor_packed(Sequence<KBlockWork, BBlockWork>{});

        const auto block_work_multi_id =
            block_work_desc.GetMultiIndexFrom1dIndex(get_block_1d_id());

        const index_t k_block_data_on_global = block_work_multi_id[0] * KPerBlock;
        const index_t b_block_data_on_global = block_work_multi_id[1] * BPerBlock;

        // input tensor
        // tensor descriptor in device memory [N0, N1, N2, Ho, Wo, {2C/4C}]
        constexpr auto in_n0_n1_n2_h_w_2cor4c_global_desc =
            in_n_c_h_w_global_desc
                .StridedSlice(I2, Number<Ho>{}, Number<ConvStrides::Get(I0)>{})
                .StridedSlice(I3, Number<Wo>{}, Number<ConvStrides::Get(I1)>{})
                .Fold(I1, Number<nonVectorizedC>{})
                .Fold(I0, Number<N1>{}, Number<N2>{})
                .Extract(Sequence<0, 1, 2, 3, 5, 6>{})
                .ReorderGivenNew2Old(Sequence<0, 1, 2, 4, 5, 3>{});

        // batch descritpor for device memory
        constexpr auto in_c_y_x_global_desc =
            in_n_c_h_w_global_desc
                .StridedSlice(I2, Number<Y>{}, Number<ConvDilations::Get(I0)>{})
                .StridedSlice(I3, Number<X>{}, Number<ConvDilations::Get(I1)>{})
                .Fold(I1, Number<nonVectorizedC>{})
                .Extract(Sequence<2, 3, 4>{});

        // merged tensor descriptor in device memory [E, N1, B, N2, {2E/4E}], src of blockwise
        // copy
        constexpr auto in_e_n1_b_n2_2eor4e_global_merged_desc = make_ConstantMergedTensorDescriptor(
            in_c_y_x_global_desc.Embed(in_n0_n1_n2_h_w_2cor4c_global_desc),
            Sequence<0, 1, 2>{}, Sequence<4>{}, Sequence<3, 6, 7>{}, Sequence<5>{}, Sequence<8>{});

        // memory layout descriptor in LDS [E, N1, B, N2, {2C/4C}], dst of blockwise copy
        // be careful of LDS alignment
        constexpr auto in_e_n1_b_n2_2eor4e_block_desc = make_ConstantTensorDescriptor_aligned(
            Sequence<EPerBlock, N1, BPerBlock, N2, EPACK>{},
            Number<InBlockCopyDstDataPerWrite_EPACK>{});

        // this check for GEMM is ad-hoc
        // TODO: need to properly implement tensor descriptor with multiple alignment
        // requirements
        static_assert(in_e_n1_b_n2_2eor4e_block_desc.GetStride(I1) % (EPACK * GemmDataPerReadB) ==
                          0,
                      "GemmDataPerReadB alignment requirement is not satisfied");

        // input blockwise copy
        // slice a merged tensor, reorder and copy to a normal tensor
        // this copy operator already has blockwise offset built-in
        auto blockwise_in_copy = BlockwiseGenericTensorSliceCopy_v1<
            BlockSize,
            decltype(in_e_n1_b_n2_2eor4e_global_merged_desc),
            decltype(in_e_n1_b_n2_2eor4e_block_desc),
            decltype(in_e_n1_b_n2_2eor4e_block_desc.GetLengths()),
            InBlockCopySubLengths_E_N1_B_N2_EPACK,
            InBlockCopyClusterLengths_E_N1_B_N2_EPACK,
            InBlockCopyThreadClusterArrangeOrder,
            InBlockCopySrcAccessOrder,
            InBlockCopyDstAccessOrder,
            2,
            4,
            InBlockCopySrcDataPerRead_B,
            InBlockCopyDstDataPerWrite_EPACK>({0, 0, b_block_data_on_global, 0, 0},
                                              {0, 0, 0, 0, 0});

        // weight tensor
        // tensor descriptor in device memory, src of blockwise copy
        constexpr auto wei_e_k_2eor4e_global_desc =
            make_vectorized_WeiDesc<conv_dir, decltype(wei_k_c_y_x_global_desc), nonVectorizedC>{}
                .get(wei_k_c_y_x_global_desc);

        // tensor descriptor in LDS, dst of blockwise copy
        // be careful of LDS alignment
        constexpr auto wei_e_k_2eor4e_block_desc = make_ConstantTensorDescriptor_aligned(
            Sequence<EPerBlock, KPerBlock, EPACK>{}, Number<WeiBlockCopyDstDataPerWrite_EPACK>{});

        // this check for GEMM is ad-hoc
        // TODO: need to properly implement tensor descriptor with multiple alignment
        // requirements
        static_assert(wei_e_k_2eor4e_block_desc.GetStride(I1) % (EPACK * GemmDataPerReadA) == 0,
                      "GemmDataPerReadA alignment requirement is not satisfied");

        // operator for blockwise copy of weight into LDS
        // slice a tensor, and copy it into another tensor
        // this copy operator already have blockwise offset built-in
        auto blockwise_wei_copy = BlockwiseGenericTensorSliceCopy_v1<
            BlockSize,
            decltype(wei_e_k_2eor4e_global_desc),
            decltype(wei_e_k_2eor4e_block_desc),
            decltype(wei_e_k_2eor4e_block_desc.GetLengths()),
            WeiBlockCopySubLengths_E_K_EPACK,
            WeiBlockCopyClusterLengths_E_K_EPACK,
            WeiBlockCopyThreadClusterArrangeOrder,
            WeiBlockCopySrcAccessOrder,
            WeiBlockCopyDstAccessOrder,
            0,
            2,
            WeiBlockCopySrcDataPerRead_E,
            WeiBlockCopyDstDataPerWrite_EPACK>({0, k_block_data_on_global, 0}, {0, 0, 0});

        // GEMM definition
        // c_mtx += transpose(a_mtx) * b_mtx
        // a_mtx[EPerBlock, KPerBlock ] is in LDS of type float/bfloat16 vec2/ float16 vec4
        // b_mtx[EPerBlocl, N1 * BPerBlock * N2 ] is in LDS of type float/bfloat16 vec2/ float16
        // vec4
        // c_mtx[KPerBlock, N1 * BPerBlock * N2] is distributed among threads, and saved in
        // register
        constexpr auto a_e_k_block_mtx_desc =
            make_ConstantMatrixDescriptor_packed(Number<EPerBlock>{}, Number<KPerBlock>{});

        constexpr auto b_e_n1bn2_block_mtx_desc = make_ConstantMatrixDescriptor_packed(
            Number<EPerBlock>{}, Number<N1 * BPerBlock * N2>{});

        // sanity check
        static_assert(KPerBlock % (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster) ==
                          0,
                      "wrong!");

        constexpr index_t GemmMRepeat =
            KPerBlock / (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster);

        // c_thread_mtx definition: this is a mess
        // TODO:: more elegent way of defining c_thread_mtx
        constexpr auto c_k0k2_n1n2_thread_mtx_desc = make_ConstantMatrixDescriptor_packed(
            Number<GemmMRepeat * GemmMPerThreadSubC>{}, Number<N1 * N2>{});

        const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
            BlockSize,
            EPACK,
            decltype(a_e_k_block_mtx_desc),
            decltype(b_e_n1bn2_block_mtx_desc),
            decltype(c_k0k2_n1n2_thread_mtx_desc),
            GemmMPerThreadSubC,
            GemmNPerThreadSubC,
            GemmMLevel0Cluster,
            GemmNLevel0Cluster,
            GemmMLevel1Cluster,
            GemmNLevel1Cluster,
            GemmKPerThreadLoop,
            GemmDataPerReadA,
            GemmDataPerReadB>{};

        // LDS allocation for input and weight: be careful of alignment
        constexpr index_t lds_allocation_align = math::lcm(InBlockCopyDstDataPerWrite_EPACK,
                                                           WeiBlockCopyDstDataPerWrite_EPACK,
                                                           EPACK * GemmDataPerReadA,
                                                           EPACK * GemmDataPerReadB);

        constexpr index_t in_block_space = math::integer_least_multiple(
            in_e_n1_b_n2_2eor4e_block_desc.GetElementSpace(), lds_allocation_align);

        constexpr index_t wei_block_space = math::integer_least_multiple(
            wei_e_k_2eor4e_block_desc.GetElementSpace(), lds_allocation_align);

        __shared__ Float p_in_block_double[2 * in_block_space];
        __shared__ Float p_wei_block_double[2 * wei_block_space];

        // register allocation for output
        AccDataType p_out_thread[c_k0k2_n1n2_thread_mtx_desc.GetElementSpace()];

        // zero out threadwise output
        threadwise_matrix_set_zero(c_k0k2_n1n2_thread_mtx_desc, p_out_thread);

        const Float* p_wei_block_on_global = p_wei_global;

        // LDS double buffer: preload data into LDS
        {
            blockwise_in_copy.Run(p_in_global, p_in_block_double);
            blockwise_wei_copy.Run(p_wei_global, p_wei_block_double);
        }

        // LDS double buffer: main body
        for(index_t e_block_data_begin = 0; e_block_data_begin + 2 * EPerBlock < E;
            e_block_data_begin += 2 * EPerBlock)
        {
            // hcc compilation error: loop not unrolled: the optimizer was unable to perform the
            // requested transformation;
            // the transformation might be disabled or specified as part of an unsupported
            // transformation
            // ordering [-Werror,-Wpass-failed=transform-warning]
            //#pragma unroll
            for(index_t iloop = 0; iloop < 2; ++iloop)
            {
                const bool even_loop = (iloop % 2 == 0);

                Float* p_in_block_now =
                    even_loop ? p_in_block_double : p_in_block_double + in_block_space;
                Float* p_wei_block_now =
                    even_loop ? p_wei_block_double : p_wei_block_double + wei_block_space;

                Float* p_in_block_next =
                    even_loop ? p_in_block_double + in_block_space : p_in_block_double;
                Float* p_wei_block_next =
                    even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double;

                Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()];
                Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];

                blockwise_in_copy.MoveSlicingWindowOnSourceTensor(I0, Number<EPerBlock>{}, True);

                static_if<conv_dir == ImplicitGemmDirection::BackwardWeight>{}([&](auto fwd) {
                    fwd(blockwise_wei_copy)
                        .MoveSrcSlicingWindow(Sequence<EPerBlock, 0, 0>{}, True);
                }).Else([&](auto fwd) {
                    p_wei_block_on_global +=
                        EPerBlock * fwd(wei_e_k_2eor4e_global_desc).GetStride(I0);
                });

                __syncthreads();

                // LDS doubel buffer: load next data from device mem
                blockwise_in_copy.RunLoadRegisterBuffer(p_in_global, p_in_register_buffer);
                blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_block_on_global,
                                                         p_wei_register_buffer);

                // LDS double buffer: GEMM on current data
                const typename vector_type<Float, EPACK>::MemoryType* p_a_block_vec =
                    reinterpret_cast<const typename vector_type<Float, EPACK>::MemoryType*>(
                        p_wei_block_now);
                const typename vector_type<Float, EPACK>::MemoryType* p_b_block_vec =
                    reinterpret_cast<const typename vector_type<Float, EPACK>::MemoryType*>(
                        p_in_block_now);

                blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, p_out_thread);

                // LDS double buffer: store next data to LDS
                blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer, p_in_block_next);
                blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer, p_wei_block_next);
            }
        }

        // LDS double buffer: tail
        {
            Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()];
            Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];

            // even iteration
            blockwise_in_copy.MoveSlicingWindowOnSourceTensor(I0, Number<EPerBlock>{}, True);

            static_if<conv_dir == ImplicitGemmDirection::BackwardWeight>{}([&](auto fwd) {
                fwd(blockwise_wei_copy).MoveSrcSlicingWindow(Sequence<EPerBlock, 0, 0>{}, True);
            }).Else([&](auto fwd) {
                p_wei_block_on_global += EPerBlock * fwd(wei_e_k_2eor4e_global_desc).GetStride(I0);
            });

            __syncthreads();

            // LDS doubel buffer: load next data from device mem
            blockwise_in_copy.RunLoadRegisterBuffer(p_in_global, p_in_register_buffer);
            blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_block_on_global, p_wei_register_buffer);

            // LDS double buffer: GEMM on current data
            // Vectorize the pointer to match with how half/bfloat16 datatypes are
            // processed in gemm operation. Half type packs 4 half values while
            // bfloat16 packs 2 bfloat16 values. Since gemm's matrix A and B
            // 2D indexes are computed with a single value in mind (e.g. float),
            // to retain the same 2D indexes for half/bfloat16, we recast datatype
            // from a single half to 4 packed half/2 packed bfloat16 respectively.
            const typename vector_type<Float, EPACK>::MemoryType* p_a_block_vec =
                reinterpret_cast<const typename vector_type<Float, EPACK>::MemoryType*>(
                    p_wei_block_double);
            const typename vector_type<Float, EPACK>::MemoryType* p_b_block_vec =
                reinterpret_cast<const typename vector_type<Float, EPACK>::MemoryType*>(
                    p_in_block_double);

            blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, p_out_thread);

            // LDS double buffer: store next data to LDS
            blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer,
                                                     p_in_block_double + in_block_space);
            blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer,
                                                      p_wei_block_double + wei_block_space);

            // odd iteration
            __syncthreads();

            p_a_block_vec = reinterpret_cast<const typename vector_type<Float, EPACK>::MemoryType*>(
                p_wei_block_double + wei_block_space);
            p_b_block_vec = reinterpret_cast<const typename vector_type<Float, EPACK>::MemoryType*>(
                p_in_block_double + in_block_space);

            // LDS double buffer: GEMM on current data
            blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, p_out_thread);
        }

        // copy output: register to global memory
        {
            constexpr index_t K2 = GemmMPerThreadSubC;
            constexpr index_t K1 = GemmMLevel0Cluster * GemmMLevel1Cluster;

            // define tensor descriptor for threadwise copy
            // output memory layout descriptor in register
            constexpr auto out_k0_k1_k2_n1_n0_h_w_n2_thread_mem_desc =
                make_ConstantTensorDescriptor_packed(
                    Sequence<KPerBlock / (K1 * K2), 1, K2, N1, 1, 1, 1, N2>{});

            // output tensor descriptor in register, src of threadwise copy
            constexpr auto out_n0_n1_n2_k0_k1_k2_h_w_thread_desc =
                out_k0_k1_k2_n1_n0_h_w_n2_thread_mem_desc.ReorderGivenNew2Old(
                    Sequence<4, 3, 7, 0, 1, 2, 5, 6>{});

            // output memory layout descriptor in device memory, dst of threadwise copy
            constexpr auto out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc =
                out_n_k_h_w_global_desc.Fold(I1, Number<K1>{}, Number<K2>{})
                    .Fold(I0, Number<N1>{}, Number<N2>{});

            // calculate origin of thread output tensor on global memory
            // blockwise GEMM c matrix starting index
            const auto c_thread_mtx_on_block =
                blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());

            const index_t k_thread_data_on_global =
                k_block_data_on_global + c_thread_mtx_on_block.row;
            const index_t b_thread_data_on_global =
                b_block_data_on_global + c_thread_mtx_on_block.col / N2;

            // output merged global tensor descriptor, for calculating origin of thread tensor
            // in global memory
            constexpr auto out_k_n1_b_n2_global_merged_desc = make_ConstantMergedTensorDescriptor(
                out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc.Unfold(I3, I5),
                Sequence<3>{}, Sequence<1>{}, Sequence<0, 4, 5>{}, Sequence<2>{});

            // origin of dst in device memory
            Float* p_out_thread_on_global =
                p_out_global + out_k_n1_b_n2_global_merged_desc.GetOffsetFromMultiIndex(
                                   k_thread_data_on_global, 0, b_thread_data_on_global, 0);

            ThreadwiseGenericTensorSliceCopy_v1r2<
                decltype(out_n0_n1_n2_k0_k1_k2_h_w_thread_desc),
                decltype(out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc),
                decltype(out_n0_n1_n2_k0_k1_k2_h_w_thread_desc.GetLengths()),
                arithmetic_sequence_gen<0, 8, 1>::type,
                7,
                1,
                1>(make_zero_array<index_t, 8>(), make_zero_array<index_t, 8>())
                .Run(p_out_thread, p_out_thread_on_global);
        }
    }
};

} // namespace ck
#endif // CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4_FP16_BFP16_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_HPP
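
The kernel above reinterprets its LDS buffers through vector_type<Float, EPACK>::MemoryType so that the blockwise GEMM can keep computing plain 2D (E, K) indices while each element actually carries EPACK packed values (4 for fp16, 2 for bfp16, 1 for fp32). The following standalone host-side sketch (not part of the commit) illustrates that idea; the half4 struct and the buffer sizes are made up for the example.

#include <array>
#include <cstdio>

// Hypothetical 4-wide fp16 stand-in; the kernel uses vector_type<Float, EPACK>::MemoryType.
struct half4
{
    unsigned short x, y, z, w;
};

int main()
{
    constexpr int EPACK = 4;    // 4 for fp16, 2 for bfp16, 1 for fp32
    constexpr int E = 8, K = 4; // per-block GEMM dimensions (E = (C / EPACK) * Y * X)

    // Scalar buffer standing in for LDS: E * K elements, each element EPACK scalars wide.
    static std::array<unsigned short, E * K * EPACK> lds_buf{};

    // Reinterpret as packed vectors: index arithmetic written for one value per element
    // stays valid, each element simply carries EPACK scalars.
    const half4* packed = reinterpret_cast<const half4*>(lds_buf.data());

    const int e = 3, k = 2;
    std::printf("packed element (e=%d, k=%d) lives at flat index %d of %d\n",
                e, k, e * K + k, E * K);
    (void)packed;
    return 0;
}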
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4_nchw_kc1x1_nkhw_lds_double_buffer.hpp — new file (mode 100644)
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4_NCHW_KC1x1_NKHW_LDS_DOUBLE_BUFFER_HPP
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4_NCHW_KC1x1_NKHW_LDS_DOUBLE_BUFFER_HPP
#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "ConstantMergedTensorDescriptor.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_generic_tensor_slice_copy.hpp"
#include "blockwise_gemm.hpp"
#include "threadwise_generic_tensor_slice_copy.hpp"
#include "implicitgemm_params.hpp"
namespace ck {

// define B = merge(N0, Ho, Wo)
template <index_t GridSize, index_t BlockSize,
          class Float, class AccDataType,
          class InGlobalDesc,  // exchanged outside for backward
          class WeiGlobalDesc,
          class OutGlobalDesc, // exchanged outside for backward
          class ConvStrides,
          ImplicitGemmDirection Direction,
          index_t BPerBlock, index_t KPerBlock, index_t EPerBlock,
          index_t N1, index_t N2,
          index_t GemmMPerThreadSubC, index_t GemmNPerThreadSubC,
          index_t GemmMLevel0Cluster, index_t GemmNLevel0Cluster,
          index_t GemmMLevel1Cluster, index_t GemmNLevel1Cluster,
          index_t GemmKPerThreadLoop, index_t GemmDataPerReadA, index_t GemmDataPerReadB,
          class InBlockCopySubLengths_E_N1_B_N2,
          class InBlockCopyClusterLengths_E_N1_B_N2,
          class InBlockCopyThreadClusterArrangeOrder,
          class InBlockCopySrcAccessOrder, class InBlockCopyDstAccessOrder,
          index_t InBlockCopySrcDataPerRead_B, index_t InBlockCopyDstDataPerWrite_N2,
          class WeiBlockCopySubLengths_E_K,
          class WeiBlockCopyClusterLengths_E_K,
          class WeiBlockCopyThreadClusterArrangeOrder,
          class WeiBlockCopySrcAccessOrder, class WeiBlockCopyDstAccessOrder,
          index_t WeiBlockCopySrcDataPerRead_E, index_t WeiBlockCopyDstDataPerWrite_K>
struct GridwiseConvolutionImplicitGemm_v4_nchw_kc1x1_nkhw_lds_double_buffer
{
    __device__ void Run(const Float* const __restrict__ p_in_global,
                        const Float* const __restrict__ p_wei_global,
                        Float* const __restrict__ p_out_global) const
    {
        constexpr bool isForward = Direction == ImplicitGemmDirection::ForwardData;

        // this is a mess
        // TODO: find more elegent way of specifying (or calculating) performance parameters
        static_assert(N2 == GemmNPerThreadSubC, "wrong!");
        static_assert((N1 * N2 * BPerBlock) %
                              (GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster) ==
                          0,
                      "wrong!");

        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};
        constexpr auto I4 = Number<4>{};
        constexpr auto I5 = Number<5>{};
        constexpr auto I6 = Number<6>{};
        constexpr auto I7 = Number<7>{};

        constexpr auto True = integral_constant<bool, true>{};

        constexpr auto in_n_c_h_w_global_desc  = InGlobalDesc{};
        constexpr auto wei_c_k_global_desc     = WeiGlobalDesc{};
        constexpr auto out_n_k_h_w_global_desc = OutGlobalDesc{};

        constexpr index_t N = in_n_c_h_w_global_desc.GetLength(I0);
        constexpr index_t C = in_n_c_h_w_global_desc.GetLength(I1);
        constexpr index_t K = out_n_k_h_w_global_desc.GetLength(I1);

        constexpr index_t Ho =
            std::conditional<isForward,
                             decltype(out_n_k_h_w_global_desc),
                             decltype(in_n_c_h_w_global_desc)>::type::GetLength(I2);
        constexpr index_t Wo =
            std::conditional<isForward,
                             decltype(out_n_k_h_w_global_desc),
                             decltype(in_n_c_h_w_global_desc)>::type::GetLength(I3);

        static_assert(N % (N1 * N2) == 0, "wrong! cannot divice N evenly among thread");

        constexpr index_t N0 = N / (N1 * N2);
        constexpr index_t B  = N0 * Ho * Wo;
        constexpr index_t E  = C;

        // divide block work by [K, B]
        static_assert(K % KPerBlock == 0 && B % BPerBlock == 0 && E % (2 * EPerBlock) == 0,
                      "wrong! cannot divide work evenly among block");

        constexpr index_t KBlockWork = K / KPerBlock;
        constexpr index_t BBlockWork = B / BPerBlock;

        constexpr auto block_work_desc =
            make_ConstantTensorDescriptor_packed(Sequence<KBlockWork, BBlockWork>{});

        const auto block_work_multi_id =
            block_work_desc.GetMultiIndexFrom1dIndex(get_block_1d_id());

        const index_t k_block_data_on_global = block_work_multi_id[0] * KPerBlock;
        const index_t b_block_data_on_global = block_work_multi_id[1] * BPerBlock;

        // input tensor
        // tensor descriptor in device memory [N0, N1, N2, Ho, Wo]
        constexpr auto in_n0_n1_n2_h_w_global_desc_forw =
            in_n_c_h_w_global_desc
                .StridedSlice(I2, Number<Ho>{}, Number<ConvStrides::Get(I0)>{})
                .StridedSlice(I3, Number<Wo>{}, Number<ConvStrides::Get(I1)>{})
                .Fold(I0, Number<N1>{}, Number<N2>{})
                .Extract(Sequence<0, 1, 2, 4, 5>{});

        constexpr auto in_n0_n1_n2_h_w_global_desc_back =
            in_n_c_h_w_global_desc.Fold(I0, Number<N1>{}, Number<N2>{})
                .Extract(Sequence<0, 1, 2, 4, 5>{});

        constexpr auto in_n0_n1_n2_h_w_global_desc =
            typename std::conditional<isForward,
                                      decltype(in_n0_n1_n2_h_w_global_desc_forw),
                                      decltype(in_n0_n1_n2_h_w_global_desc_back)>::type{};

        // batch descritpor for device memory
        constexpr auto in_c_global_desc = in_n_c_h_w_global_desc.Extract(I1);

        // merged tensor descriptor in device memory [E, N1, B, N2], src of blockwise copy
        constexpr auto in_e_n1_b_n2_global_merged_desc = make_ConstantMergedTensorDescriptor(
            in_c_global_desc.Embed(in_n0_n1_n2_h_w_global_desc),
            Sequence<0>{}, Sequence<2>{}, Sequence<1, 4, 5>{}, Sequence<3>{});

        // memory layout descriptor in LDS [E, N1, B, N2], dst of blockwise copy
        // be careful of LDS alignment
        constexpr auto in_e_n1_b_n2_block_desc = make_ConstantTensorDescriptor_aligned(
            Sequence<EPerBlock, N1, BPerBlock, N2>{}, Number<InBlockCopyDstDataPerWrite_N2>{});

        // this check is ad-hoc
        // TODO: need to properly implement tensor descriptor with multiple alignment
        // requirements
        static_assert(in_e_n1_b_n2_block_desc.GetStride(I1) % GemmDataPerReadB == 0,
                      "GemmDataPerReadB alignment requirement is not satisfied");

        // input blockwise copy
        // slice a merged tensor, reorder and copy to a normal tensor
        // this copy operator already has blockwise offset built-in
        auto blockwise_in_copy = BlockwiseGenericTensorSliceCopy_v1<
            BlockSize,
            decltype(in_e_n1_b_n2_global_merged_desc),
            decltype(in_e_n1_b_n2_block_desc),
            decltype(in_e_n1_b_n2_block_desc.GetLengths()),
            InBlockCopySubLengths_E_N1_B_N2,
            InBlockCopyClusterLengths_E_N1_B_N2,
            InBlockCopyThreadClusterArrangeOrder,
            InBlockCopySrcAccessOrder,
            InBlockCopyDstAccessOrder,
            2,
            3,
            InBlockCopySrcDataPerRead_B,
            InBlockCopyDstDataPerWrite_N2>({0, 0, b_block_data_on_global, 0}, {0, 0, 0, 0});

        // weight tensor
        // tensor descriptor in device memory, src of blockwise copy
        constexpr auto wei_e_k_global_desc_forw = wei_c_k_global_desc;

        constexpr auto wei_e_k_global_desc_back =
            make_ConstantTensorDescriptor_packed(Sequence<C, K>{});

        constexpr auto wei_e_k_global_desc =
            typename std::conditional<isForward,
                                      decltype(wei_e_k_global_desc_forw),
                                      decltype(wei_e_k_global_desc_back)>::type{};

        // tensor descriptor in LDS, dst of blockwise copy
        // be careful of LDS alignment
        constexpr auto wei_e_k_block_desc = make_ConstantTensorDescriptor_aligned(
            Sequence<EPerBlock, KPerBlock>{},
            Number<math::lcm(WeiBlockCopyDstDataPerWrite_K, GemmDataPerReadA)>{});

        // operator for blockwise copy of weight into LDS
        // slice a tensor, and copy it into another tensor
        // this copy operator already have blockwise offset built-in
        auto blockwise_wei_copy = BlockwiseGenericTensorSliceCopy_v1<
            BlockSize,
            decltype(wei_e_k_global_desc),
            decltype(wei_e_k_block_desc),
            decltype(wei_e_k_block_desc.GetLengths()),
            WeiBlockCopySubLengths_E_K,
            WeiBlockCopyClusterLengths_E_K,
            WeiBlockCopyThreadClusterArrangeOrder,
            WeiBlockCopySrcAccessOrder,
            WeiBlockCopyDstAccessOrder,
            0,
            1,
            WeiBlockCopySrcDataPerRead_E,
            WeiBlockCopyDstDataPerWrite_K>({0, k_block_data_on_global}, {0, 0});

        // GEMM definition
        // c_mtx += transpose(a_mtx) * b_mtx
        // a_mtx[EPerBlock, KPerBlock] is in LDS
        // b_mtx[EPerBlocl, N1 * BPerBlock * N2] is in LDS
        // c_mtx[KPerBlock, N1 * BPerBlock * N2] is distributed among threads, and saved in
        // register
        constexpr auto a_e_k_block_mtx_desc = make_ConstantMatrixDescriptor(wei_e_k_block_desc);

        constexpr auto b_e_n1bn2_block_mtx_desc =
            make_ConstantMatrixDescriptor(in_e_n1_b_n2_block_desc.Unfold(I1, I3));

        // sanity check
        static_assert(KPerBlock % (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster) ==
                          0,
                      "wrong!");

        constexpr index_t GemmMRepeat =
            KPerBlock / (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster);

        // c_thread_mtx definition: this is a mess
        // TODO:: more elegent way of defining c_thread_mtx
        constexpr auto c_k0k2_n1n2_thread_mtx_desc = make_ConstantMatrixDescriptor_packed(
            Number<GemmMRepeat * GemmMPerThreadSubC>{}, Number<N1 * N2>{});

        const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
            BlockSize,
            1, // EPACK = 1
            decltype(a_e_k_block_mtx_desc),
            decltype(b_e_n1bn2_block_mtx_desc),
            decltype(c_k0k2_n1n2_thread_mtx_desc),
            GemmMPerThreadSubC,
            GemmNPerThreadSubC,
            GemmMLevel0Cluster,
            GemmNLevel0Cluster,
            GemmMLevel1Cluster,
            GemmNLevel1Cluster,
            GemmKPerThreadLoop,
            GemmDataPerReadA,
            GemmDataPerReadB>{};

        // LDS allocation for input and weight: be careful of alignment
        constexpr index_t max_align = math::lcm(InBlockCopyDstDataPerWrite_N2,
                                                WeiBlockCopyDstDataPerWrite_K,
                                                GemmDataPerReadA,
                                                GemmDataPerReadB);

        constexpr index_t in_block_space =
            math::integer_least_multiple(in_e_n1_b_n2_block_desc.GetElementSpace(), max_align);

        constexpr index_t wei_block_space =
            math::integer_least_multiple(wei_e_k_block_desc.GetElementSpace(), max_align);

        __shared__ Float p_in_block_double[2 * in_block_space];
        __shared__ Float p_wei_block_double[2 * wei_block_space];

        // register allocation for output
        AccDataType p_out_thread[c_k0k2_n1n2_thread_mtx_desc.GetElementSpace()];

        // zero out threadwise output
        threadwise_matrix_set_zero(c_k0k2_n1n2_thread_mtx_desc, p_out_thread);

        const Float* p_wei_block_on_global = p_wei_global;

        // LDS double buffer: preload data into LDS
        {
            blockwise_in_copy.Run(p_in_global, p_in_block_double);
            blockwise_wei_copy.Run(p_wei_global, p_wei_block_double);
        }

        // LDS double buffer: main body
        for(index_t e_block_data_begin = 0; e_block_data_begin + 2 * EPerBlock < E;
            e_block_data_begin += 2 * EPerBlock)
        {
#pragma unroll
            for(index_t iloop = 0; iloop < 2; ++iloop)
            {
                const bool even_loop = (iloop % 2 == 0);

                Float* p_in_block_now =
                    even_loop ? p_in_block_double : p_in_block_double + in_block_space;
                Float* p_wei_block_now =
                    even_loop ? p_wei_block_double : p_wei_block_double + wei_block_space;

                Float* p_in_block_next =
                    even_loop ? p_in_block_double + in_block_space : p_in_block_double;
                Float* p_wei_block_next =
                    even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double;

                Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()];
                Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];

                blockwise_in_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0, 0, 0>{}, True);
                p_wei_block_on_global += EPerBlock * wei_e_k_global_desc.GetStride(I0);

                __syncthreads();

                // LDS doubel buffer: load next data from device mem
                blockwise_in_copy.RunLoadRegisterBuffer(p_in_global, p_in_register_buffer);
                blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_block_on_global,
                                                         p_wei_register_buffer);

                // LDS double buffer: GEMM on current data
                blockwise_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread);

                // LDS double buffer: store next data to LDS
                blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer, p_in_block_next);
                blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer, p_wei_block_next);
            }
        }

        // LDS double buffer: tail
        {
            Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()];
            Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];

            // even iteration
            blockwise_in_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0, 0, 0>{}, True);
            p_wei_block_on_global += EPerBlock * wei_e_k_global_desc.GetStride(I0);

            __syncthreads();

            // LDS doubel buffer: load next data from device mem
            blockwise_in_copy.RunLoadRegisterBuffer(p_in_global, p_in_register_buffer);
            blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_block_on_global, p_wei_register_buffer);

            // LDS double buffer: GEMM on current data
            blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);

            // LDS double buffer: store next data to LDS
            blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer,
                                                     p_in_block_double + in_block_space);
            blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer,
                                                      p_wei_block_double + wei_block_space);

            // odd iteration
            __syncthreads();

            // LDS double buffer: GEMM on current data
            blockwise_gemm.Run(p_wei_block_double + wei_block_space,
                               p_in_block_double + in_block_space,
                               p_out_thread);
        }

        // copy output: register to global memory
        {
            constexpr index_t K2 = GemmMPerThreadSubC;
            constexpr index_t K1 = GemmMLevel0Cluster * GemmMLevel1Cluster;

            // define tensor descriptor for threadwise copy
            // output memory layout descriptor in register
            constexpr auto out_k0_k1_k2_n1_n0_h_w_n2_thread_mem_desc =
                make_ConstantTensorDescriptor_packed(
                    Sequence<KPerBlock / (K1 * K2), 1, K2, N1, 1, 1, 1, N2>{});

            // output tensor descriptor in register, src of threadwise copy
            constexpr auto out_n0_n1_n2_k0_k1_k2_h_w_thread_desc =
                out_k0_k1_k2_n1_n0_h_w_n2_thread_mem_desc.ReorderGivenNew2Old(
                    Sequence<4, 3, 7, 0, 1, 2, 5, 6>{});

            // output memory layout descriptor in device memory, dst of threadwise copy
            constexpr auto out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc_forw =
                out_n_k_h_w_global_desc.Fold(I1, Number<K1>{}, Number<K2>{})
                    .Fold(I0, Number<N1>{}, Number<N2>{});

            constexpr auto out_lengths_new = Sequence<
                out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc_forw.GetLength(I0),
                out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc_forw.GetLength(I1),
                out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc_forw.GetLength(I2),
                out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc_forw.GetLength(I3),
                out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc_forw.GetLength(I4),
                out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc_forw.GetLength(I5),
                math::integer_divide_ceil(
                    out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc_forw.GetLength(I6),
                    ConvStrides{}.Get(I0)),
                math::integer_divide_ceil(
                    out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc_forw.GetLength(I7),
                    ConvStrides{}.Get(I1))>{};

            constexpr auto out_strides_new = Sequence<
                out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc_forw.GetStride(I0),
                out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc_forw.GetStride(I1),
                out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc_forw.GetStride(I2),
                out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc_forw.GetStride(I3),
                out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc_forw.GetStride(I4),
                out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc_forw.GetStride(I5),
                out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc_forw.GetStride(I6) *
                    ConvStrides{}.Get(I0),
                out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc_forw.GetStride(I7) *
                    ConvStrides{}.Get(I1)>{};

            constexpr auto out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc_back =
                make_ConstantTensorDescriptor(out_lengths_new, out_strides_new);

            constexpr auto out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc = typename std::conditional<
                isForward,
                decltype(out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc_forw),
                decltype(out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc_back)>::type{};

            // calculate origin of thread output tensor on global memory
            // blockwise GEMM c matrix starting index
            const auto c_thread_mtx_on_block =
                blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());

            const index_t k_thread_data_on_global =
                k_block_data_on_global + c_thread_mtx_on_block.row;
            const index_t b_thread_data_on_global =
                b_block_data_on_global + c_thread_mtx_on_block.col / N2;

            // output merged global tensor descriptor, for calculating origin of thread tensor
            // in global memory
            constexpr auto out_k_n1_b_n2_global_merged_desc = make_ConstantMergedTensorDescriptor(
                out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc.Unfold(I3, I5),
                Sequence<3>{}, Sequence<1>{}, Sequence<0, 4, 5>{}, Sequence<2>{});

            // origin of dst in device memory
            Float* p_out_thread_on_global =
                p_out_global + out_k_n1_b_n2_global_merged_desc.GetOffsetFromMultiIndex(
                                   k_thread_data_on_global, 0, b_thread_data_on_global, 0);

            ThreadwiseGenericTensorSliceCopy_v1r2<
                decltype(out_n0_n1_n2_k0_k1_k2_h_w_thread_desc),
                decltype(out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc),
                decltype(out_n0_n1_n2_k0_k1_k2_h_w_thread_desc.GetLengths()),
                arithmetic_sequence_gen<0, 8, 1>::type,
                7,
                1,
                1>(make_zero_array<index_t, 8>(), make_zero_array<index_t, 8>())
                .Run(p_out_thread, p_out_thread_on_global);
        }
    }
};

} // namespace ck
#endif // CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4_NCHW_KC1x1_NKHW_LDS_DOUBLE_BUFFER_HPP
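
Both gridwise kernels above use the same "LDS double buffer" schedule: preload the first chunk, then on each step issue the load for the next chunk into one half of LDS while the GEMM consumes the other half, and finish with a tail that drains the last chunk. A minimal host-side analogue of that ping-pong schedule is sketched below (not part of the commit; the chunk count and the load/compute lambdas are placeholders).

#include <cstdio>
#include <vector>

int main()
{
    constexpr int num_chunks = 6; // even, mirroring the E % (2 * EPerBlock) == 0 requirement
    std::vector<int> buf[2] = {std::vector<int>(1), std::vector<int>(1)};

    auto load    = [](std::vector<int>& b, int chunk) { b[0] = chunk; };
    auto compute = [](const std::vector<int>& b) { std::printf("compute on chunk %d\n", b[0]); };

    load(buf[0], 0); // preload first chunk

    for(int chunk = 1; chunk < num_chunks; ++chunk)
    {
        const int cur = (chunk - 1) % 2;
        const int nxt = chunk % 2;
        load(buf[nxt], chunk); // fetch the next chunk into the idle buffer ...
        compute(buf[cur]);     // ... while consuming the current one
    }

    compute(buf[(num_chunks - 1) % 2]); // tail: last chunk
    return 0;
}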
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4_nchw_kcyx_nkhw_lds_double_buffer.hpp — modified (+105, −58)
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER
_HPP
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER
_HPP
#include "common_header.hpp"
#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "ConstantTensorDescriptor.hpp"
...
@@ -8,9 +8,35 @@
...
@@ -8,9 +8,35 @@
#include "blockwise_generic_tensor_slice_copy.hpp"
#include "blockwise_generic_tensor_slice_copy.hpp"
#include "blockwise_gemm.hpp"
#include "blockwise_gemm.hpp"
#include "threadwise_generic_tensor_slice_copy.hpp"
#include "threadwise_generic_tensor_slice_copy.hpp"
#include "implicitgemm_params.hpp"
namespace
ck
{
namespace
ck
{
template
<
ImplicitGemmDirection
conv_dir
,
typename
WeiDesc
>
struct
make_WeiDesc
{
};
template
<
typename
WeiDesc
>
struct
make_WeiDesc
<
ImplicitGemmDirection
::
ForwardData
,
WeiDesc
>
{
__device__
constexpr
auto
get
(
WeiDesc
&
)
{
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
return
WeiDesc
{}.
Unfold
(
I1
,
I3
).
ReorderGivenNew2Old
(
Sequence
<
1
,
0
>
{});
}
};
template
<
typename
WeiDesc
>
struct
make_WeiDesc
<
ImplicitGemmDirection
::
BackwardWeight
,
WeiDesc
>
{
__device__
constexpr
auto
get
(
WeiDesc
&
desc
)
{
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
return
make_ConstantMergedTensorDescriptor
(
desc
.
Unfold
(
I2
,
I3
),
Sequence
<
1
,
2
>
{},
Sequence
<
0
>
{});
}
};
// define B = merge(N0, Ho, Wo)
// define B = merge(N0, Ho, Wo)
template
<
index_t
GridSize
,
template
<
index_t
GridSize
,
index_t
BlockSize
,
index_t
BlockSize
,
...
@@ -24,8 +50,7 @@ template <index_t GridSize,
...
@@ -24,8 +50,7 @@ template <index_t GridSize,
index_t
BPerBlock
,
index_t
BPerBlock
,
index_t
KPerBlock
,
index_t
KPerBlock
,
index_t
EPerBlock
,
index_t
EPerBlock
,
index_t
N1
,
index_t
GemmNRepeat
,
index_t
N2
,
index_t
GemmMPerThreadSubC
,
index_t
GemmMPerThreadSubC
,
index_t
GemmNPerThreadSubC
,
index_t
GemmNPerThreadSubC
,
index_t
GemmMLevel0Cluster
,
index_t
GemmMLevel0Cluster
,
...
@@ -48,17 +73,19 @@ template <index_t GridSize,
...
@@ -48,17 +73,19 @@ template <index_t GridSize,
class
WeiBlockCopySrcAccessOrder
,
class
WeiBlockCopySrcAccessOrder
,
class
WeiBlockCopyDstAccessOrder
,
class
WeiBlockCopyDstAccessOrder
,
index_t
WeiBlockCopySrcDataPerRead_E
,
index_t
WeiBlockCopySrcDataPerRead_E
,
index_t
WeiBlockCopyDstDataPerWrite_K
>
index_t
WeiBlockCopyDstDataPerWrite_K
,
ImplicitGemmDirection
conv_dir
>
struct
GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw_lds_double_buffer
struct
GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw_lds_double_buffer
{
{
__device__
void
__device__
void
Run
(
const
Float
*
const
__restrict__
p_in_global
,
Run
(
const
Float
*
const
__restrict__
p_in_global
,
const
Float
*
const
__restrict__
p_wei_global
,
const
Float
*
const
__restrict__
p_wei_global
,
Float
*
const
__restrict__
p_out_global
)
const
Float
*
const
__restrict__
p_out_global
)
const
{
{
// this is a mess
// this is a mess
// TODO: find more elegent way of specifying (or calculating) performance parameters
// TODO: find more elegent way of specifying (or calculating) performance parameters
static_assert
(
N2
==
GemmNPerThreadSubC
,
"wrong!"
);
constexpr
index_t
N1
=
GemmNRepeat
;
constexpr
index_t
N2
=
GemmNPerThreadSubC
;
static_assert
((
N1
*
N2
*
BPerBlock
)
%
static_assert
((
N1
*
N2
*
BPerBlock
)
%
(
GemmNPerThreadSubC
*
GemmNLevel0Cluster
*
GemmNLevel1Cluster
)
==
(
GemmNPerThreadSubC
*
GemmNLevel0Cluster
*
GemmNLevel1Cluster
)
==
0
,
0
,
...
@@ -86,6 +113,12 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw_lds_double_buffer
...
@@ -86,6 +113,12 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw_lds_double_buffer
constexpr
index_t
Y
=
wei_k_c_y_x_global_desc
.
GetLength
(
I2
);
constexpr
index_t
Y
=
wei_k_c_y_x_global_desc
.
GetLength
(
I2
);
constexpr
index_t
X
=
wei_k_c_y_x_global_desc
.
GetLength
(
I3
);
constexpr
index_t
X
=
wei_k_c_y_x_global_desc
.
GetLength
(
I3
);
constexpr
index_t
ConvStrideH
=
ConvStrides
{}[
0
];
constexpr
index_t
ConvStrideW
=
ConvStrides
{}[
1
];
constexpr
index_t
ConvDilationH
=
ConvDilations
{}[
0
];
constexpr
index_t
ConvDilationW
=
ConvDilations
{}[
1
];
static_assert
(
N
%
(
N1
*
N2
)
==
0
,
"wrong! cannot divice N evenly among thread"
);
static_assert
(
N
%
(
N1
*
N2
)
==
0
,
"wrong! cannot divice N evenly among thread"
);
constexpr
index_t
N0
=
N
/
(
N1
*
N2
);
constexpr
index_t
N0
=
N
/
(
N1
*
N2
);
...
@@ -94,6 +127,14 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw_lds_double_buffer
...
@@ -94,6 +127,14 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw_lds_double_buffer
constexpr
index_t
E
=
C
*
Y
*
X
;
constexpr
index_t
E
=
C
*
Y
*
X
;
// sanity-check for vectorized memory load
static_assert
(
ConvStrideW
==
1
||
InBlockCopySrcDataPerRead_B
==
1
,
"wrong! global vector load of input tensor is wrong"
);
static_assert
((
X
==
1
||
ConvDilationW
%
InBlockCopySrcDataPerRead_B
==
0
),
"wrong! aligment requirement for vectorized global load of input tensor will "
"be violated"
);
// divide block work by [K, B]
// divide block work by [K, B]
static_assert
(
K
%
KPerBlock
==
0
&&
B
%
BPerBlock
==
0
&&
E
%
(
2
*
EPerBlock
)
==
0
,
static_assert
(
K
%
KPerBlock
==
0
&&
B
%
BPerBlock
==
0
&&
E
%
(
2
*
EPerBlock
)
==
0
,
"wrong! cannot divide work evenly among block"
);
"wrong! cannot divide work evenly among block"
);
...
@@ -113,15 +154,15 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw_lds_double_buffer
...
@@ -113,15 +154,15 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw_lds_double_buffer
// input tensor
// input tensor
// tensor descriptor in device memory [N0, N1, N2, Ho, Wo]
// tensor descriptor in device memory [N0, N1, N2, Ho, Wo]
constexpr
auto
in_n0_n1_n2_h_w_global_desc
=
constexpr
auto
in_n0_n1_n2_h_w_global_desc
=
in_n_c_h_w_global_desc
.
StridedSlice
(
I2
,
Number
<
Ho
>
{},
Number
<
ConvStride
s
::
Get
(
I0
)
>
{})
in_n_c_h_w_global_desc
.
StridedSlice
(
I2
,
Number
<
Ho
>
{},
Number
<
ConvStride
H
>
{})
.
StridedSlice
(
I3
,
Number
<
Wo
>
{},
Number
<
ConvStride
s
::
Get
(
I1
)
>
{})
.
StridedSlice
(
I3
,
Number
<
Wo
>
{},
Number
<
ConvStride
W
>
{})
.
Fold
(
I0
,
Number
<
N1
>
{},
Number
<
N2
>
{})
.
Fold
(
I0
,
Number
<
N1
>
{},
Number
<
N2
>
{})
.
Extract
(
Sequence
<
0
,
1
,
2
,
4
,
5
>
{});
.
Extract
(
Sequence
<
0
,
1
,
2
,
4
,
5
>
{});
// batch descritpor for device memory
// batch descritpor for device memory
constexpr
auto
in_c_y_x_global_desc
=
constexpr
auto
in_c_y_x_global_desc
=
in_n_c_h_w_global_desc
.
StridedSlice
(
I2
,
Number
<
Y
>
{},
Number
<
ConvDilation
s
::
Get
(
I0
)
>
{})
in_n_c_h_w_global_desc
.
StridedSlice
(
I2
,
Number
<
Y
>
{},
Number
<
ConvDilation
H
>
{})
.
StridedSlice
(
I3
,
Number
<
X
>
{},
Number
<
ConvDilation
s
::
Get
(
I1
)
>
{})
.
StridedSlice
(
I3
,
Number
<
X
>
{},
Number
<
ConvDilation
W
>
{})
.
Extract
(
Sequence
<
1
,
2
,
3
>
{});
.
Extract
(
Sequence
<
1
,
2
,
3
>
{});
// merged tensor descriptor in device memory [E, N1, B, N2], src of blockwise copy
// merged tensor descriptor in device memory [E, N1, B, N2], src of blockwise copy
...
@@ -148,7 +189,6 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw_lds_double_buffer
...
@@ -148,7 +189,6 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw_lds_double_buffer
// this copy operator already has blockwise offset built-in
// this copy operator already has blockwise offset built-in
auto
blockwise_in_copy
=
auto
blockwise_in_copy
=
BlockwiseGenericTensorSliceCopy_v1
<
BlockSize
,
BlockwiseGenericTensorSliceCopy_v1
<
BlockSize
,
Float
,
decltype
(
in_e_n1_b_n2_global_merged_desc
),
decltype
(
in_e_n1_b_n2_global_merged_desc
),
decltype
(
in_e_n1_b_n2_block_desc
),
decltype
(
in_e_n1_b_n2_block_desc
),
decltype
(
in_e_n1_b_n2_block_desc
.
GetLengths
()),
decltype
(
in_e_n1_b_n2_block_desc
.
GetLengths
()),
...
@@ -157,6 +197,8 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw_lds_double_buffer
...
@@ -157,6 +197,8 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw_lds_double_buffer
InBlockCopyThreadClusterArrangeOrder
,
InBlockCopyThreadClusterArrangeOrder
,
InBlockCopySrcAccessOrder
,
InBlockCopySrcAccessOrder
,
InBlockCopyDstAccessOrder
,
InBlockCopyDstAccessOrder
,
2
,
3
,
InBlockCopySrcDataPerRead_B
,
InBlockCopySrcDataPerRead_B
,
InBlockCopyDstDataPerWrite_N2
>
(
InBlockCopyDstDataPerWrite_N2
>
(
{
0
,
0
,
b_block_data_on_global
,
0
},
{
0
,
0
,
0
,
0
});
{
0
,
0
,
b_block_data_on_global
,
0
},
{
0
,
0
,
0
,
0
});
...
@@ -164,7 +206,8 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw_lds_double_buffer
...
@@ -164,7 +206,8 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw_lds_double_buffer
// weight tensor
// weight tensor
// tensor descriptor in device memory, src of blockwise copy
// tensor descriptor in device memory, src of blockwise copy
constexpr
auto
wei_e_k_global_desc
=
constexpr
auto
wei_e_k_global_desc
=
wei_k_c_y_x_global_desc
.
Unfold
(
I1
,
I3
).
ReorderGivenNew2Old
(
Sequence
<
1
,
0
>
{});
make_WeiDesc
<
conv_dir
,
decltype
(
wei_k_c_y_x_global_desc
)
>
{}.
get
(
wei_k_c_y_x_global_desc
);
// tensor descriptor in LDS, dst of blockwise copy
// tensor descriptor in LDS, dst of blockwise copy
// be careful of LDS alignment
// be careful of LDS alignment
...
@@ -177,7 +220,6 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw_lds_double_buffer
...
@@ -177,7 +220,6 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw_lds_double_buffer
// this copy operator already have blockwise offset built-in
// this copy operator already have blockwise offset built-in
auto
blockwise_wei_copy
=
auto
blockwise_wei_copy
=
BlockwiseGenericTensorSliceCopy_v1
<
BlockSize
,
BlockwiseGenericTensorSliceCopy_v1
<
BlockSize
,
Float
,
decltype
(
wei_e_k_global_desc
),
decltype
(
wei_e_k_global_desc
),
decltype
(
wei_e_k_block_desc
),
decltype
(
wei_e_k_block_desc
),
decltype
(
wei_e_k_block_desc
.
GetLengths
()),
decltype
(
wei_e_k_block_desc
.
GetLengths
()),
...
@@ -186,6 +228,8 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw_lds_double_buffer
...
@@ -186,6 +228,8 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw_lds_double_buffer
WeiBlockCopyThreadClusterArrangeOrder
,
WeiBlockCopyThreadClusterArrangeOrder
,
WeiBlockCopySrcAccessOrder
,
WeiBlockCopySrcAccessOrder
,
WeiBlockCopyDstAccessOrder
,
WeiBlockCopyDstAccessOrder
,
0
,
1
,
WeiBlockCopySrcDataPerRead_E
,
WeiBlockCopySrcDataPerRead_E
,
WeiBlockCopyDstDataPerWrite_K
>
(
WeiBlockCopyDstDataPerWrite_K
>
(
{
0
,
k_block_data_on_global
},
{
0
,
0
});
{
0
,
k_block_data_on_global
},
{
0
,
0
});
...
@@ -196,13 +240,11 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw_lds_double_buffer
...
@@ -196,13 +240,11 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw_lds_double_buffer
// b_mtx[EPerBlocl, N1 * BPerBlock * N2] is in LDS
// b_mtx[EPerBlocl, N1 * BPerBlock * N2] is in LDS
// c_mtx[KPerBlock, N1 * BPerBlock * N2] is distributed among threads, and saved in
// c_mtx[KPerBlock, N1 * BPerBlock * N2] is distributed among threads, and saved in
// register
// register
constexpr
auto
a_e_k_block_mtx_desc
=
make_ConstantMatrixDescriptor
(
Number
<
EPerBlock
>
{},
Number
<
KPerBlock
>
{},
Number
<
wei_e_k_block_desc
.
GetStride
(
I0
)
>
{}
);
constexpr
auto
a_e_k_block_mtx_desc
=
make_ConstantMatrixDescriptor
(
wei_e_k_block_desc
);
constexpr
auto
b_e_n1bn2_block_mtx_desc
=
constexpr
auto
b_e_n1bn2_block_mtx_desc
=
make_ConstantMatrixDescriptor
(
Number
<
EPerBlock
>
{},
make_ConstantMatrixDescriptor
(
in_e_n1_b_n2_block_desc
.
Unfold
(
I1
,
I3
));
Number
<
N1
*
BPerBlock
*
N2
>
{},
Number
<
in_e_n1_b_n2_block_desc
.
GetStride
(
I0
)
>
{});
// sanity check
// sanity check
static_assert
(
KPerBlock
%
(
GemmMPerThreadSubC
*
GemmMLevel0Cluster
*
GemmMLevel1Cluster
)
==
static_assert
(
KPerBlock
%
(
GemmMPerThreadSubC
*
GemmMLevel0Cluster
*
GemmMLevel1Cluster
)
==
...
@@ -214,11 +256,12 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw_lds_double_buffer
...
@@ -214,11 +256,12 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw_lds_double_buffer
// c_thread_mtx definition: this is a mess
// c_thread_mtx definition: this is a mess
// TODO:: more elegent way of defining c_thread_mtx
// TODO:: more elegent way of defining c_thread_mtx
constexpr
auto
c_k0k2_n1n2_thread_mtx_desc
=
make_ConstantMatrixDescriptor
(
constexpr
auto
c_k0k2_n1n2_thread_mtx_desc
=
make_ConstantMatrixDescriptor
_packed
(
Number
<
GemmMRepeat
*
GemmMPerThreadSubC
>
{},
Number
<
N1
*
N2
>
{});
Number
<
GemmMRepeat
*
GemmMPerThreadSubC
>
{},
Number
<
GemmNRepeat
*
GemmNPerThreadSubC
>
{});
const
auto
blockwise_gemm
=
BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
<
const
auto
blockwise_gemm
=
BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
<
BlockSize
,
BlockSize
,
1
,
// EPACK = 1
decltype
(
a_e_k_block_mtx_desc
),
decltype
(
a_e_k_block_mtx_desc
),
decltype
(
b_e_n1bn2_block_mtx_desc
),
decltype
(
b_e_n1bn2_block_mtx_desc
),
decltype
(
c_k0k2_n1n2_thread_mtx_desc
),
decltype
(
c_k0k2_n1n2_thread_mtx_desc
),
...
@@ -280,53 +323,58 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw_lds_double_buffer
...
@@ -280,53 +323,58 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw_lds_double_buffer
             Float* p_wei_block_next =
                 even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double;

-            Float p_in_register_clipboard[blockwise_in_copy.GetRegisterClipboardSize()];
-            Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];
+            Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()];
+            Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];

-            blockwise_in_copy.MoveSlicingWindowOnSourceTensor(I0, Number<EPerBlock>{}, True);
-            p_wei_block_on_global += EPerBlock * wei_e_k_global_desc.GetStride(I0);
+            blockwise_in_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0, 0, 0>{}, True);
+
+            static_if<conv_dir == ImplicitGemmDirection::BackwardWeight>{}([&](auto fwd) {
+                fwd(blockwise_wei_copy).MoveSrcSlicingWindow(Sequence<EPerBlock, 0>{}, True);
+            }).Else([&](auto fwd) {
+                p_wei_block_on_global += EPerBlock * fwd(wei_e_k_global_desc).GetStride(I0);
+            });

             __syncthreads();

             // LDS double buffer: load next data from device mem
-            blockwise_in_copy.RunLoadRegisterClipboard(p_in_global, p_in_register_clipboard);
-            blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_block_on_global,
-                                                        p_wei_register_clipboard);
+            blockwise_in_copy.RunLoadRegisterBuffer(p_in_global, p_in_register_buffer);
+            blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_block_on_global, p_wei_register_buffer);

             // LDS double buffer: GEMM on current data
             blockwise_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread);

             // LDS double buffer: store next data to LDS
-            blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard, p_in_block_next);
-            blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard, p_wei_block_next);
+            blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer, p_in_block_next);
+            blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer, p_wei_block_next);
         }
     }

     // LDS double buffer: tail
     {
-        Float p_in_register_clipboard[blockwise_in_copy.GetRegisterClipboardSize()];
-        Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];
+        Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()];
+        Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];

         // even iteration
-        blockwise_in_copy.MoveSlicingWindowOnSourceTensor(I0, Number<EPerBlock>{}, True);
-        p_wei_block_on_global += EPerBlock * wei_e_k_global_desc.GetStride(I0);
+        blockwise_in_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0, 0, 0>{}, True);
+
+        static_if<conv_dir == ImplicitGemmDirection::BackwardWeight>{}([&](auto) {
+            blockwise_wei_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0>{}, True);
+        }).Else([&](auto fwd) {
+            p_wei_block_on_global += EPerBlock * fwd(wei_e_k_global_desc).GetStride(I0);
+        });

         __syncthreads();

         // LDS double buffer: load next data from device mem
-        blockwise_in_copy.RunLoadRegisterClipboard(p_in_global, p_in_register_clipboard);
-        blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_block_on_global, p_wei_register_clipboard);
+        blockwise_in_copy.RunLoadRegisterBuffer(p_in_global, p_in_register_buffer);
+        blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_block_on_global, p_wei_register_buffer);

         // LDS double buffer: GEMM on current data
         blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);

         // LDS double buffer: store next data to LDS
-        blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard,
-                                                    p_in_block_double + in_block_space);
-        blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard,
-                                                     p_wei_block_double + wei_block_space);
+        blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer,
+                                                 p_in_block_double + in_block_space);
+        blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer,
+                                                  p_wei_block_double + wei_block_space);

         // odd iteration
         __syncthreads();
...
@@ -384,19 +432,18 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw_lds_double_buffer
             out_k_n1_b_n2_global_merged_desc.GetOffsetFromMultiIndex(
                 k_thread_data_on_global, 0, b_thread_data_on_global, 0);

-        threadwise_generic_tensor_slice_copy_v1(out_n0_n1_n2_k0_k1_k2_h_w_thread_desc,
-                                                p_out_thread,
-                                                {0, 0, 0, 0, 0, 0, 0, 0},
-                                                out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc,
-                                                p_out_thread_on_global,
-                                                {0, 0, 0, 0, 0, 0, 0, 0},
-                                                out_n0_n1_n2_k0_k1_k2_h_w_thread_desc.GetLengths(),
-                                                arithmetic_sequence_gen<0, 8, 1>::type{},
-                                                Number<1>{});
+        ThreadwiseGenericTensorSliceCopy_v1r2<
+            decltype(out_n0_n1_n2_k0_k1_k2_h_w_thread_desc),
+            decltype(out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc),
+            decltype(out_n0_n1_n2_k0_k1_k2_h_w_thread_desc.GetLengths()),
+            arithmetic_sequence_gen<0, 8, 1>::type,
+            7,
+            1,
+            1>(make_zero_array<index_t, 8>(), make_zero_array<index_t, 8>())
+            .Run(p_out_thread, p_out_thread_on_global);
     }
 }
};

} // namespace ck
#endif // CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_HPP
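Note on the LDS double-buffering pattern shared by this file and the new xdlops kernels below: the block preloads one tile into LDS, then alternates between two LDS halves so the GEMM on the current tile overlaps the global-memory load of the next one, and finishes with a two-step tail. The following host-side C++ sketch shows only that control flow; `Tile`, `load_tile`, and `gemm_tile` are hypothetical stand-ins (not the repository's helpers), and the real kernel additionally unrolls two iterations per loop trip and separates LDS producers and consumers with __syncthreads().

// build: g++ -std=c++14 double_buffer_sketch.cpp
#include <array>
#include <cstddef>
#include <cstdio>
#include <numeric>
#include <vector>

// One "tile" stands in for the slice of A/B that a block would stage in LDS.
struct Tile
{
    std::vector<float> data;
};

// Hypothetical stand-in for the blockwise copy: global memory -> staging tile.
Tile load_tile(const std::vector<float>& global, std::size_t e_begin, std::size_t e_per_block)
{
    return Tile{{global.begin() + e_begin, global.begin() + e_begin + e_per_block}};
}

// Hypothetical stand-in for the blockwise GEMM: consume the current tile.
void gemm_tile(const Tile& t, float& c_acc)
{
    c_acc = std::accumulate(t.data.begin(), t.data.end(), c_acc);
}

int main()
{
    const std::size_t E = 8, EPerBlock = 2;
    std::vector<float> global(E, 1.0f);
    float c_acc = 0.0f;

    std::array<Tile, 2> lds_double;                  // plays the role of p_*_block_double
    lds_double[0] = load_tile(global, 0, EPerBlock); // preload, like blockwise_*_copy.Run()

    std::size_t e = 0;
    for(; e + EPerBlock < E; e += EPerBlock)
    {
        const std::size_t now  = (e / EPerBlock) % 2; // buffer being consumed
        const std::size_t next = 1 - now;             // buffer being filled

        lds_double[next] = load_tile(global, e + EPerBlock, EPerBlock); // "load next data"
        gemm_tile(lds_double[now], c_acc);                              // "GEMM on current data"
        // in the real kernel a __syncthreads() separates producing and consuming LDS
    }

    gemm_tile(lds_double[(e / EPerBlock) % 2], c_acc); // tail iteration

    std::printf("accumulated %.1f (expect %zu)\n", c_acc, E);
    return 0;
}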
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_xdlops_fp16_bfp16_nchw_kcyx_nkhw_lds_double_buffer.hpp
new file (mode 100644)
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_FP16_BFP16_NCHW_KCYX_NKHW_HPP_LDS_DOUBLE_BUFFER_HPP
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_FP16_BFP16_NCHW_KCYX_NKHW_HPP_LDS_DOUBLE_BUFFER_HPP

#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "ConstantMergedTensorDescriptor.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_generic_tensor_slice_copy.hpp"
#include "blockwise_gemm_xdlops.hpp"
#include "threadwise_generic_tensor_slice_copy.hpp"
#include "implicitgemm_params.hpp"

namespace ck {

template <ImplicitGemmDirection conv_dir, typename WeiDesc, index_t NonVectorizedC>
struct make_vectorized_WeiDesc_Xdlops
{
};

template <typename WeiDesc, index_t NonVectorizedC>
struct make_vectorized_WeiDesc_Xdlops<ImplicitGemmDirection::ForwardData, WeiDesc, NonVectorizedC>
{
    __device__ constexpr auto get(WeiDesc&)
    {
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};
        constexpr auto I4 = Number<4>{};

        return WeiDesc{}
            .Fold(I1, Number<NonVectorizedC>{})
            .Unfold(I2, I4)
            .ReorderGivenNew2Old(Sequence<2, 0, 1>{});
    }
};

template <typename WeiDesc, index_t NonVectorizedC>
struct make_vectorized_WeiDesc_Xdlops<ImplicitGemmDirection::BackwardWeight, WeiDesc, NonVectorizedC>
{
    __device__ constexpr auto get(WeiDesc& desc)
    {
        constexpr auto I1 = Number<1>{};
        constexpr auto I3 = Number<3>{};
        constexpr auto I4 = Number<4>{};

        return make_ConstantMergedTensorDescriptor(
            desc.Fold(I1, Number<NonVectorizedC>{}).Unfold(I3, I4),
            Sequence<2, 3>{},
            Sequence<0>{},
            Sequence<1>{});
    }
};

// B = merge(N, Ho, Wo)
template <index_t GridSize, index_t BlockSize, class Float, class AccDataType,
          class InGlobalDesc, class WeiGlobalDesc, class OutGlobalDesc,
          class ConvStrides, class ConvDilations,
          index_t BPerBlock, index_t KPerBlock, index_t EPerBlock, index_t EPack,
          index_t GemmMPerWave, index_t GemmNPerWave, index_t GemmMWaves, index_t GemmNWaves,
          index_t GemmDataPerReadA, index_t GemmDataPerReadB, bool EnableXdlops,
          class InBlockCopySubLengths_E_B, class InBlockCopyClusterLengths_E_B,
          class InBlockCopyThreadClusterArrangeOrder,
          class InBlockCopySrcAccessOrder, class InBlockCopyDstAccessOrder,
          index_t InBlockCopyDataPerAccess_B,
          class WeiBlockCopySubLengths_E_K, class WeiBlockCopyClusterLengths_E_K,
          class WeiBlockCopyThreadClusterArrangeOrder,
          class WeiBlockCopySrcAccessOrder, class WeiBlockCopyDstAccessOrder,
          index_t WeiBlockCopySrcDataPerRead_E, index_t WeiBlockCopyDstDataPerWrite_K,
          index_t OutThreadCopyDataPerAccess_B, ImplicitGemmDirection conv_dir>
struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_fp16_bfp16_nchw_kcyx_nkhw_lds_double_buffer
{
    __device__ void Run(const Float* const __restrict__ p_in_global,
                        const Float* const __restrict__ p_wei_global,
                        Float* const __restrict__ p_out_global) const
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};

        constexpr auto True = integral_constant<bool, true>{};

        constexpr auto in_n_c_h_w_global_desc  = InGlobalDesc{};
        constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{};
        constexpr auto out_n_k_h_w_global_desc = OutGlobalDesc{};

        constexpr index_t N = in_n_c_h_w_global_desc.GetLengths()[0];
        constexpr index_t C = in_n_c_h_w_global_desc.GetLengths()[1];

        constexpr index_t K  = out_n_k_h_w_global_desc.GetLengths()[1];
        constexpr index_t Ho = out_n_k_h_w_global_desc.GetLengths()[2];
        constexpr index_t Wo = out_n_k_h_w_global_desc.GetLengths()[3];

        constexpr index_t Y = wei_k_c_y_x_global_desc.GetLengths()[2];
        constexpr index_t X = wei_k_c_y_x_global_desc.GetLengths()[3];

        constexpr index_t ConvStrideH = ConvStrides{}[0];
        constexpr index_t ConvStrideW = ConvStrides{}[1];

        constexpr index_t ConvDilationH = ConvDilations{}[0];
        constexpr index_t ConvDilationW = ConvDilations{}[1];

        constexpr index_t nonVectorizedC = C / EPack;
        constexpr index_t E              = nonVectorizedC * Y * X;
        constexpr index_t B              = N * Ho * Wo;

        static_assert((X == 1 || ConvDilationW % InBlockCopyDataPerAccess_B == 0),
                      "wrong! alignment requirement for vectorized global load of input tensor will "
                      "be violated");

        // divide block work by [K, B]
        static_assert(K % KPerBlock == 0 && B % BPerBlock == 0 && E % (2 * EPerBlock) == 0,
                      "wrong! cannot divide work evenly among block");

        constexpr index_t KBlockWork = K / KPerBlock;
        constexpr index_t BBlockWork = B / BPerBlock;

        constexpr auto block_work_desc =
            make_ConstantTensorDescriptor_packed(Sequence<KBlockWork, BBlockWork>{});

        const auto block_work_multi_id = block_work_desc.GetMultiIndexFrom1dIndex(get_block_1d_id());

        const index_t k_block_data_on_global = block_work_multi_id[0] * KPerBlock;
        const index_t b_block_data_on_global = block_work_multi_id[1] * BPerBlock;

        // input tensor
        // tensor descriptor in device memory [N, Ho, Wo, {2C/4C}]
        constexpr auto in_n_ho_wo_2cor4c_global_desc =
            in_n_c_h_w_global_desc.StridedSlice(I2, Number<Ho>{}, Number<ConvStrideH>{})
                .StridedSlice(I3, Number<Wo>{}, Number<ConvStrideW>{})
                .Fold(I1, Number<nonVectorizedC>{})
                .Extract(Sequence<0, 1, 3, 4>{})
                .ReorderGivenNew2Old(Sequence<0, 2, 3, 1>{});

        // batch descriptor for device memory
        constexpr auto in_c_y_x_global_desc =
            in_n_c_h_w_global_desc.StridedSlice(I2, Number<Y>{}, Number<ConvDilationH>{})
                .StridedSlice(I3, Number<X>{}, Number<ConvDilationW>{})
                .Fold(I1, Number<nonVectorizedC>{})
                .Extract(Sequence<2, 3, 4>{});

        // merged tensor descriptor in device memory [E, B], src of blockwise copy
        constexpr auto in_e_b_global_desc = make_ConstantMergedTensorDescriptor(
            in_c_y_x_global_desc.Embed(in_n_ho_wo_2cor4c_global_desc),
            Sequence<0, 1, 2>{},
            Sequence<3, 4, 5>{},
            Sequence<6>{});

        // memory layout descriptor in LDS [E, B, 2Cor4C], dst of blockwise copy
        // be careful of LDS alignment
        constexpr auto in_e_b_block_desc = make_ConstantTensorDescriptor_aligned(
            Sequence<EPerBlock, BPerBlock, EPack>{},
            Number<math::lcm(InBlockCopyDataPerAccess_B, GemmDataPerReadB, EPack)>{});

        // input blockwise copy
        // slice a merged tensor, reorder and copy to a normal tensor
        // this copy operator already has blockwise offset built-in
        auto blockwise_in_copy = BlockwiseGenericTensorSliceCopy_v2<
            BlockSize,
            decltype(in_e_b_global_desc),
            decltype(in_e_b_block_desc),
            MergedTensorCoordinate<decltype(in_e_b_global_desc)>,
            NormalTensorCoordinate<decltype(in_e_b_block_desc)>,
            decltype(in_e_b_block_desc.GetLengths()),
            InBlockCopySubLengths_E_B,
            InBlockCopyClusterLengths_E_B,
            InBlockCopyThreadClusterArrangeOrder,
            InBlockCopySrcAccessOrder,
            InBlockCopyDstAccessOrder,
            1,                          // Src dim to be read in vector form (B dimension)
            2,                          // Dst dim to be written in vector form (EPack dimension)
            InBlockCopyDataPerAccess_B, // Src dim vector len
            InBlockCopyDataPerAccess_B  // Dst dim vector len
            >({0, b_block_data_on_global, 0}, {0, 0, 0});

        // weight tensor
        // tensor descriptor in device memory, src of blockwise copy
        constexpr auto wei_e_k_global_desc =
            make_vectorized_WeiDesc_Xdlops<conv_dir, decltype(wei_k_c_y_x_global_desc), nonVectorizedC>{}
                .get(wei_k_c_y_x_global_desc);

        // tensor descriptor in LDS, dst of blockwise copy
        // be careful of LDS alignment
        constexpr auto wei_e_k_block_desc = make_ConstantTensorDescriptor_aligned(
            Sequence<EPerBlock, KPerBlock, EPack>{},
            Number<math::lcm(WeiBlockCopyDstDataPerWrite_K, GemmDataPerReadA, EPack)>{});

        // operator for blockwise copy of weight into LDS
        // slice a tensor, and copy it into another tensor
        // this copy operator already has blockwise offset built-in
        auto blockwise_wei_copy = BlockwiseGenericTensorSliceCopy_v2<
            BlockSize,
            decltype(wei_e_k_global_desc),
            decltype(wei_e_k_block_desc),
            NormalTensorCoordinate<decltype(wei_e_k_global_desc)>,
            NormalTensorCoordinate<decltype(wei_e_k_block_desc)>,
            decltype(wei_e_k_block_desc.GetLengths()),
            WeiBlockCopySubLengths_E_K,
            WeiBlockCopyClusterLengths_E_K,
            WeiBlockCopyThreadClusterArrangeOrder,
            WeiBlockCopySrcAccessOrder,
            WeiBlockCopyDstAccessOrder,
            0,                            // Src dim to be read in vector form (E dimension)
            2,                            // Dst dim to be written in vector form (EPack dimension)
            WeiBlockCopySrcDataPerRead_E, // Src dim vector len
            WeiBlockCopyDstDataPerWrite_K // Dst dim vector len
            >({0, k_block_data_on_global, 0}, {0, 0, 0});

        // GEMM definition
        constexpr auto a_e_k_block_mtx_desc =
            make_ConstantMatrixDescriptor_packed(Number<EPerBlock>{}, Number<KPerBlock>{});
        constexpr auto b_e_b_block_mtx_desc =
            make_ConstantMatrixDescriptor_packed(Number<EPerBlock>{}, Number<BPerBlock>{});

        const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_xdlops<
            BlockSize,
            decltype(a_e_k_block_mtx_desc),
            decltype(b_e_b_block_mtx_desc),
            decltype(mfma_info<Float>{}),
            EnableXdlops,
            GemmMPerWave,
            GemmNPerWave,
            GemmMWaves,
            GemmNWaves,
            GemmDataPerReadA,
            GemmDataPerReadB>{};

        constexpr auto c_k_thread_mtx_desc = blockwise_gemm.GetThreadMatrixCDescriptor();

        constexpr index_t max_align = math::lcm(InBlockCopyDataPerAccess_B,
                                                WeiBlockCopyDstDataPerWrite_K,
                                                GemmDataPerReadA,
                                                GemmDataPerReadB,
                                                EPack);

        constexpr index_t in_block_space =
            math::integer_least_multiple(in_e_b_block_desc.GetElementSpace(), max_align);
        constexpr index_t wei_block_space =
            math::integer_least_multiple(wei_e_k_block_desc.GetElementSpace(), max_align);

        __shared__ Float p_in_block_double[2 * in_block_space];
        __shared__ Float p_wei_block_double[2 * wei_block_space];

        // register allocation for output
        AccDataType p_out_thread[c_k_thread_mtx_desc.GetElementSpace()];

        // zero out threadwise output
        threadwise_matrix_set_zero(c_k_thread_mtx_desc, p_out_thread);

        // static_if<EnableXdlops>{}(
        //     [&](auto) { gcnasm_accvgpr_zero<c_k_thread_mtx_desc.GetElementSpace()>(); });

        const Float* p_wei_block_on_global = p_wei_global;

        // LDS double buffer: preload data into LDS
        {
            blockwise_in_copy.Run(p_in_global, p_in_block_double);
            blockwise_wei_copy.Run(p_wei_global, p_wei_block_double);
        }

        // LDS double buffer: main body
        for(index_t e_block_data_begin = 0; e_block_data_begin + 2 * EPerBlock < E;
            e_block_data_begin += 2 * EPerBlock)
        {
#pragma unroll
            for(index_t iloop = 0; iloop < 2; ++iloop)
            {
                const bool even_loop = (iloop % 2 == 0);

                Float* p_in_block_now =
                    even_loop ? p_in_block_double : p_in_block_double + in_block_space;
                Float* p_wei_block_now =
                    even_loop ? p_wei_block_double : p_wei_block_double + wei_block_space;

                Float* p_in_block_next =
                    even_loop ? p_in_block_double + in_block_space : p_in_block_double;
                Float* p_wei_block_next =
                    even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double;

                Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()];
                Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];

                blockwise_in_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0, 0>{}, True);

                static_if<conv_dir == ImplicitGemmDirection::BackwardWeight>{}([&](auto) {
                    blockwise_wei_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0, 0>{}, True);
                }).Else([&](auto fwd) {
                    p_wei_block_on_global += EPerBlock * fwd(wei_e_k_global_desc).GetStride(I0);
                });

                __syncthreads();

                // LDS double buffer: load next data from device mem
                blockwise_in_copy.RunLoadRegisterBuffer(p_in_global, p_in_register_buffer);
                blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_block_on_global, p_wei_register_buffer);

                // LDS double buffer: GEMM on current data
                const typename vector_type<Float, EPack>::MemoryType* p_a_block_vec =
                    reinterpret_cast<const typename vector_type<Float, EPack>::MemoryType*>(
                        p_wei_block_now);
                const typename vector_type<Float, EPack>::MemoryType* p_b_block_vec =
                    reinterpret_cast<const typename vector_type<Float, EPack>::MemoryType*>(
                        p_in_block_now);
                blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, p_out_thread);

                // LDS double buffer: store next data to LDS
                blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer, p_in_block_next);
                blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer, p_wei_block_next);
            }
        }

        // LDS double buffer: tail
        {
            Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()];
            Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];

            // even iteration
            blockwise_in_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0, 0>{}, True);

            static_if<conv_dir == ImplicitGemmDirection::BackwardWeight>{}([&](auto) {
                blockwise_wei_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0, 0>{}, True);
            }).Else([&](auto fwd) {
                p_wei_block_on_global += EPerBlock * fwd(wei_e_k_global_desc).GetStride(I0);
            });

            __syncthreads();

            // LDS double buffer: load next data from device mem
            blockwise_in_copy.RunLoadRegisterBuffer(p_in_global, p_in_register_buffer);
            blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_block_on_global, p_wei_register_buffer);

            // LDS double buffer: GEMM on current data
            // Vectorize the pointer to match with how half/bfloat16 datatypes are
            // processed in gemm operation. Half type packs 4 half values while
            // bfloat16 packs 2 bfloat16 values. Since gemm's matrix A and B
            // 2D indexes are computed with a single value in mind (e.g. float),
            // to retain the same 2D indexes for half/bfloat16, we recast datatype
            // from a single half to 4 packed half/2 packed bfloat16 respectively.
            const typename vector_type<Float, EPack>::MemoryType* p_a_block_vec =
                reinterpret_cast<const typename vector_type<Float, EPack>::MemoryType*>(
                    p_wei_block_double);
            const typename vector_type<Float, EPack>::MemoryType* p_b_block_vec =
                reinterpret_cast<const typename vector_type<Float, EPack>::MemoryType*>(
                    p_in_block_double);
            blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, p_out_thread);

            // LDS double buffer: store next data to LDS
            blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer,
                                                     p_in_block_double + in_block_space);
            blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer,
                                                      p_wei_block_double + wei_block_space);

            // odd iteration
            __syncthreads();

            p_a_block_vec = reinterpret_cast<const typename vector_type<Float, EPack>::MemoryType*>(
                p_wei_block_double + wei_block_space);
            p_b_block_vec = reinterpret_cast<const typename vector_type<Float, EPack>::MemoryType*>(
                p_in_block_double + in_block_space);

            // LDS double buffer: GEMM on current data
            blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, p_out_thread);
        }

        // load data from xldop_acc_regs
        // static_if<EnableXdlops>{}([&](auto) {
        //     gcnasm_accvgpr_read<c_k_thread_mtx_desc.GetElementSpace()>(p_out_thread);
        // });

        // copy output: register to global memory
        {
            constexpr index_t K2 = blockwise_gemm.OutputLayout.M2;
            constexpr index_t K1 = blockwise_gemm.OutputLayout.M1;
            constexpr index_t K0 = blockwise_gemm.OutputLayout.M0;

            // This is a hack, because slicing a merged dimension is not supported yet.
            // dst descriptor
            constexpr auto out_k0_k1_k2_b_global_desc = make_ConstantMergedTensorDescriptor(
                out_n_k_h_w_global_desc.Fold(I1, Number<K1>{}, Number<K2>{}),
                Sequence<1>{},
                Sequence<2>{},
                Sequence<3>{},
                Sequence<0, 4, 5>{});

            // src descriptor
            constexpr auto out_k0_k1_k2_b_thread_desc =
                make_ConstantTensorDescriptor_packed(Sequence<K2, 1, K0, 1>{});

            using OutThreadCopySliceLengths = Sequence<K2, 1, K0, 1>;

            constexpr index_t NumKPerBlk = out_k0_k1_k2_b_thread_desc.GetElementSpace();
            constexpr index_t NumBlks    = GemmMPerWave / NumKPerBlk;

            for(index_t i = 0; i < NumBlks; ++i)
            {
                // calculate origin of thread output tensor on global memory
                // blockwise GEMM c matrix starting index
                const auto c_thread_mtx_on_block = blockwise_gemm.GetBeginOfThreadMatrixC(i);

                const index_t k_thread_data_on_global =
                    k_block_data_on_global + c_thread_mtx_on_block.row;
                const index_t b_thread_data_on_global =
                    b_block_data_on_global + c_thread_mtx_on_block.col;

                auto threadwise_out_copy = ThreadwiseGenericTensorSliceCopy_v2r1<
                    decltype(out_k0_k1_k2_b_thread_desc),
                    decltype(out_k0_k1_k2_b_global_desc),
                    NormalTensorCoordinate<decltype(out_k0_k1_k2_b_thread_desc)>,
                    MergedTensorCoordinate<decltype(out_k0_k1_k2_b_global_desc)>,
                    OutThreadCopySliceLengths,
                    arithmetic_sequence_gen<0, 4, 1>::type,
                    arithmetic_sequence_gen<0, 4, 1>::type,
                    3,                            // Src dim to be read in vector form (B dimension)
                    3,                            // Dst dim to be written in vector form (B dimension)
                    OutThreadCopyDataPerAccess_B, // Src dim vector len
                    OutThreadCopyDataPerAccess_B  // Dst dim vector len
                    >({0, 0, 0, 0},
                      {k_thread_data_on_global / (K0 * K1),
                       k_thread_data_on_global % (K0 * K1) / K0,
                       k_thread_data_on_global % K0,
                       b_thread_data_on_global});

                threadwise_out_copy.Run(p_out_thread + i * NumKPerBlk, p_out_global);
            }
        }
    }
};

} // namespace ck
#endif
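The comment block in the tail above explains why the LDS pointers are recast before the GEMM: with EPack packed half/bfloat16 values per element, one "E" coordinate in the GEMM must cover EPack scalars. Below is a minimal standalone sketch of that reinterpretation with a hypothetical packed4_t struct (uint16_t placeholders, not the repository's vector_type); it only illustrates the indexing effect, and a host-side production version might prefer memcpy to sidestep strict-aliasing concerns that the device kernel tolerates.

// build: g++ -std=c++14 epack_reinterpret_sketch.cpp
#include <cstdint>
#include <cstdio>

// Hypothetical 4-wide pack standing in for vector_type<half_t, 4>::MemoryType.
struct packed4_t
{
    std::uint16_t v[4];
};

int main()
{
    // 2 (rows of E) x 4 (EPack) scalars laid out contiguously, as the LDS block
    // would be after the blockwise copy wrote it in [E, K, EPack] order.
    std::uint16_t lds[8] = {0, 1, 2, 3, 10, 11, 12, 13};

    // Reinterpreting the scalar pointer as a pointer to 4-wide packs lets the
    // GEMM keep indexing with a single "E" coordinate: element e of the packed
    // view covers scalars [4*e, 4*e + 3] of the original buffer.
    const packed4_t* packed = reinterpret_cast<const packed4_t*>(lds);

    for(int e = 0; e < 2; ++e)
        std::printf("pack %d: %u %u %u %u\n",
                    e,
                    (unsigned)packed[e].v[0],
                    (unsigned)packed[e].v[1],
                    (unsigned)packed[e].v[2],
                    (unsigned)packed[e].v[3]);
    return 0;
}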
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_xdlops_nchw_kc1x1_nkhw_lds_double_buffer.hpp
new file (mode 100644)
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_NCHW_KC1X1_NKHW_HPP_LDS_DOUBLE_BUFFER_HPP
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_NCHW_KC1X1_NKHW_HPP_LDS_DOUBLE_BUFFER_HPP

#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "ConstantMergedTensorDescriptor.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_generic_tensor_slice_copy.hpp"
#include "blockwise_gemm_xdlops.hpp"
#include "threadwise_generic_tensor_slice_copy.hpp"
#include "implicitgemm_params.hpp"

namespace ck {

// B = merge(N, Ho, Wo)
template <index_t GridSize, index_t BlockSize, class Float,
          class InGlobalDesc, class WeiGlobalDesc, class OutGlobalDesc,
          class ConvStrides, ImplicitGemmDirection Direction,
          index_t BPerBlock, index_t KPerBlock, index_t EPerBlock,
          index_t GemmMPerWave, index_t GemmNPerWave, index_t GemmMWaves, index_t GemmNWaves,
          index_t GemmDataPerReadA, index_t GemmDataPerReadB, bool EnableXdlops,
          class InBlockCopySubLengths_E_B, class InBlockCopyClusterLengths_E_B,
          class InBlockCopyThreadClusterArrangeOrder,
          class InBlockCopySrcAccessOrder, class InBlockCopyDstAccessOrder,
          index_t InBlockCopyDataPerAccess_B,
          class WeiBlockCopySubLengths_E_K, class WeiBlockCopyClusterLengths_E_K,
          class WeiBlockCopyThreadClusterArrangeOrder,
          class WeiBlockCopySrcAccessOrder, class WeiBlockCopyDstAccessOrder,
          index_t WeiBlockCopySrcDataPerRead_E, index_t WeiBlockCopyDstDataPerWrite_K,
          index_t OutThreadCopyDataPerAccess_B>
struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_nchw_kc1x1_nkhw_lds_double_buffer
{
    __device__ void Run(const Float* const __restrict__ p_in_global,
                        const Float* const __restrict__ p_wei_global,
                        Float* const __restrict__ p_out_global) const
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};

        constexpr auto True = integral_constant<bool, true>{};

        constexpr auto in_n_c_h_w_global_desc  = InGlobalDesc{};
        constexpr auto wei_c_k_global_desc     = WeiGlobalDesc{};
        constexpr auto out_n_k_h_w_global_desc = OutGlobalDesc{};

        constexpr index_t N = in_n_c_h_w_global_desc.GetLengths()[0];
        constexpr index_t C = in_n_c_h_w_global_desc.GetLengths()[1];

        constexpr bool isForward = Direction == ImplicitGemmDirection::ForwardData;

        constexpr index_t K = out_n_k_h_w_global_desc.GetLengths()[1];

        constexpr index_t Ho = std::conditional<isForward,
                                                decltype(out_n_k_h_w_global_desc),
                                                decltype(in_n_c_h_w_global_desc)>::type::GetLength(I2);
        constexpr index_t Wo = std::conditional<isForward,
                                                decltype(out_n_k_h_w_global_desc),
                                                decltype(in_n_c_h_w_global_desc)>::type::GetLength(I3);

        constexpr index_t ConvStrideH = ConvStrides{}[0];
        constexpr index_t ConvStrideW = ConvStrides{}[1];

        constexpr index_t E = C;
        constexpr index_t B = N * Ho * Wo;

        // divide block work by [K, B]
        static_assert(K % KPerBlock == 0 && B % BPerBlock == 0 && E % (2 * EPerBlock) == 0,
                      "wrong! cannot divide work evenly among block");

        constexpr index_t KBlockWork = K / KPerBlock;
        constexpr index_t BBlockWork = B / BPerBlock;

        constexpr auto block_work_desc =
            make_ConstantTensorDescriptor_packed(Sequence<KBlockWork, BBlockWork>{});

        const auto block_work_multi_id = block_work_desc.GetMultiIndexFrom1dIndex(get_block_1d_id());

        const index_t k_block_data_on_global = block_work_multi_id[0] * KPerBlock;
        const index_t b_block_data_on_global = block_work_multi_id[1] * BPerBlock;

        // input tensor
        // tensor descriptor in device memory [N, Ho, Wo]
        constexpr auto in_n_ho_wo_global_desc_forw =
            in_n_c_h_w_global_desc.Extract(I0, I2, I3)
                .StridedSlice(I1, Number<Ho>{}, Number<ConvStrideH>{})
                .StridedSlice(I2, Number<Wo>{}, Number<ConvStrideW>{});

        constexpr auto in_n_ho_wo_global_desc_back = in_n_c_h_w_global_desc.Extract(I0, I2, I3);

        constexpr auto in_n_ho_wo_global_desc =
            typename std::conditional<isForward,
                                      decltype(in_n_ho_wo_global_desc_forw),
                                      decltype(in_n_ho_wo_global_desc_back)>::type{};

        // batch descriptor for device memory
        constexpr auto in_c_global_desc = in_n_c_h_w_global_desc.Extract(I1);

        // merged tensor descriptor in device memory [E, B], src of blockwise copy
        constexpr auto in_e_b_global_desc = make_ConstantMergedTensorDescriptor(
            in_c_global_desc.Embed(in_n_ho_wo_global_desc), Sequence<0>{}, Sequence<1, 2, 3>{});

        // memory layout descriptor in LDS [E, B], dst of blockwise copy
        // be careful of LDS alignment
        constexpr auto in_e_b_block_desc = make_ConstantTensorDescriptor_aligned(
            Sequence<EPerBlock, BPerBlock>{},
            Number<math::lcm(InBlockCopyDataPerAccess_B, GemmDataPerReadB)>{});

        // input blockwise copy
        // slice a merged tensor, reorder and copy to a normal tensor
        // this copy operator already has blockwise offset built-in
        auto blockwise_in_copy = BlockwiseGenericTensorSliceCopy_v2<
            BlockSize,
            decltype(in_e_b_global_desc),
            decltype(in_e_b_block_desc),
            MergedTensorCoordinate<decltype(in_e_b_global_desc)>,
            NormalTensorCoordinate<decltype(in_e_b_block_desc)>,
            decltype(in_e_b_block_desc.GetLengths()),
            InBlockCopySubLengths_E_B,
            InBlockCopyClusterLengths_E_B,
            InBlockCopyThreadClusterArrangeOrder,
            InBlockCopySrcAccessOrder,
            InBlockCopyDstAccessOrder,
            1,
            1,
            InBlockCopyDataPerAccess_B,
            InBlockCopyDataPerAccess_B>({0, b_block_data_on_global}, {0, 0});

        // weight tensor
        // tensor descriptor in device memory, src of blockwise copy
        constexpr auto wei_e_k_global_desc_forw = wei_c_k_global_desc;
        constexpr auto wei_e_k_global_desc_back =
            make_ConstantTensorDescriptor_packed(Sequence<C, K>{});

        constexpr auto wei_e_k_global_desc =
            typename std::conditional<isForward,
                                      decltype(wei_e_k_global_desc_forw),
                                      decltype(wei_e_k_global_desc_back)>::type{};

        // tensor descriptor in LDS, dst of blockwise copy
        // be careful of LDS alignment
        constexpr auto wei_e_k_block_desc = make_ConstantTensorDescriptor_aligned(
            Sequence<EPerBlock, KPerBlock>{},
            Number<math::lcm(WeiBlockCopyDstDataPerWrite_K, GemmDataPerReadA)>{});

        // operator for blockwise copy of weight into LDS
        // slice a tensor, and copy it into another tensor
        // this copy operator already has blockwise offset built-in
        auto blockwise_wei_copy = BlockwiseGenericTensorSliceCopy_v2<
            BlockSize,
            decltype(wei_e_k_global_desc),
            decltype(wei_e_k_block_desc),
            NormalTensorCoordinate<decltype(wei_e_k_global_desc)>,
            NormalTensorCoordinate<decltype(wei_e_k_block_desc)>,
            decltype(wei_e_k_block_desc.GetLengths()),
            WeiBlockCopySubLengths_E_K,
            WeiBlockCopyClusterLengths_E_K,
            WeiBlockCopyThreadClusterArrangeOrder,
            WeiBlockCopySrcAccessOrder,
            WeiBlockCopyDstAccessOrder,
            0,
            1,
            WeiBlockCopySrcDataPerRead_E,
            WeiBlockCopyDstDataPerWrite_K>({0, k_block_data_on_global}, {0, 0});

        // GEMM definition
        constexpr auto a_e_k_block_mtx_desc = make_ConstantMatrixDescriptor(wei_e_k_block_desc);
        constexpr auto b_e_b_block_mtx_desc = make_ConstantMatrixDescriptor(in_e_b_block_desc);

        const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_xdlops<
            BlockSize,
            decltype(a_e_k_block_mtx_desc),
            decltype(b_e_b_block_mtx_desc),
            decltype(mfma_info<float>{}),
            EnableXdlops,
            GemmMPerWave,
            GemmNPerWave,
            GemmMWaves,
            GemmNWaves,
            GemmDataPerReadA,
            GemmDataPerReadB>{};

        constexpr auto c_k_thread_mtx_desc = blockwise_gemm.GetThreadMatrixCDescriptor();

        constexpr index_t max_align = math::lcm(InBlockCopyDataPerAccess_B,
                                                WeiBlockCopyDstDataPerWrite_K,
                                                GemmDataPerReadA,
                                                GemmDataPerReadB);

        constexpr index_t in_block_space =
            math::integer_least_multiple(in_e_b_block_desc.GetElementSpace(), max_align);
        constexpr index_t wei_block_space =
            math::integer_least_multiple(wei_e_k_block_desc.GetElementSpace(), max_align);

        __shared__ Float p_in_block_double[2 * in_block_space];
        __shared__ Float p_wei_block_double[2 * wei_block_space];

        // register allocation for output
        Float p_out_thread[c_k_thread_mtx_desc.GetElementSpace()];

        // zero out threadwise output
        threadwise_matrix_set_zero(c_k_thread_mtx_desc, p_out_thread);

        static_if<EnableXdlops>{}(
            [&](auto) { gcnasm_accvgpr_zero<c_k_thread_mtx_desc.GetElementSpace()>(); });

        const Float* p_wei_block_on_global = p_wei_global;

        // LDS double buffer: preload data into LDS
        {
            blockwise_in_copy.Run(p_in_global, p_in_block_double);
            blockwise_wei_copy.Run(p_wei_global, p_wei_block_double);
        }

        // LDS double buffer: main body
        for(index_t e_block_data_begin = 0; e_block_data_begin + 2 * EPerBlock < E;
            e_block_data_begin += 2 * EPerBlock)
        {
#pragma unroll
            for(index_t iloop = 0; iloop < 2; ++iloop)
            {
                const bool even_loop = (iloop % 2 == 0);

                Float* p_in_block_now =
                    even_loop ? p_in_block_double : p_in_block_double + in_block_space;
                Float* p_wei_block_now =
                    even_loop ? p_wei_block_double : p_wei_block_double + wei_block_space;

                Float* p_in_block_next =
                    even_loop ? p_in_block_double + in_block_space : p_in_block_double;
                Float* p_wei_block_next =
                    even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double;

                Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()];
                Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];

                blockwise_in_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0>{}, True);
                p_wei_block_on_global += EPerBlock * wei_e_k_global_desc.GetStrides()[0];

                __syncthreads();

                // LDS double buffer: load next data from device mem
                blockwise_in_copy.RunLoadRegisterBuffer(p_in_global, p_in_register_buffer);
                blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_block_on_global, p_wei_register_buffer);

                // LDS double buffer: GEMM on current data
                blockwise_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread);

                // LDS double buffer: store next data to LDS
                blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer, p_in_block_next);
                blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer, p_wei_block_next);
            }
        }

        // LDS double buffer: tail
        {
            Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()];
            Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];

            // even iteration
            blockwise_in_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0>{}, True);
            p_wei_block_on_global += EPerBlock * wei_e_k_global_desc.GetStrides()[0];

            __syncthreads();

            // LDS double buffer: load next data from device mem
            blockwise_in_copy.RunLoadRegisterBuffer(p_in_global, p_in_register_buffer);
            blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_block_on_global, p_wei_register_buffer);

            // LDS double buffer: GEMM on current data
            blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);

            // LDS double buffer: store next data to LDS
            blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer,
                                                     p_in_block_double + in_block_space);
            blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer,
                                                      p_wei_block_double + wei_block_space);

            // odd iteration
            __syncthreads();

            // LDS double buffer: GEMM on current data
            blockwise_gemm.Run(p_wei_block_double + wei_block_space,
                               p_in_block_double + in_block_space,
                               p_out_thread);
        }

        // load data from xldop_acc_regs
        static_if<EnableXdlops>{}([&](auto) {
            gcnasm_accvgpr_read<c_k_thread_mtx_desc.GetElementSpace()>(p_out_thread);
        });

        // copy output: register to global memory
        {
            constexpr index_t K2 = blockwise_gemm.OutputLayout.M2;
            constexpr index_t K1 = blockwise_gemm.OutputLayout.M1;
            constexpr index_t K0 = blockwise_gemm.OutputLayout.M0;

            constexpr auto out_n_k_h_w_global_desc_forw = out_n_k_h_w_global_desc;

            constexpr auto out_lengths_back = Sequence<
                out_n_k_h_w_global_desc.GetLength(I0),
                out_n_k_h_w_global_desc.GetLength(I1),
                math::integer_divide_ceil(out_n_k_h_w_global_desc.GetLength(I2), ConvStrides{}.Get(I0)),
                math::integer_divide_ceil(out_n_k_h_w_global_desc.GetLength(I3),
                                          ConvStrides{}.Get(I1))>{};

            constexpr auto out_strides_back =
                Sequence<out_n_k_h_w_global_desc.GetStride(I0),
                         out_n_k_h_w_global_desc.GetStride(I1),
                         out_n_k_h_w_global_desc.GetStride(I2) * ConvStrides{}.Get(I0),
                         out_n_k_h_w_global_desc.GetStride(I3) * ConvStrides{}.Get(I1)>{};

            constexpr auto out_n_k_h_w_global_desc_back =
                make_ConstantTensorDescriptor(out_lengths_back, out_strides_back);

            constexpr auto out_n_k_h_w_global_desc_new =
                typename std::conditional<isForward,
                                          decltype(out_n_k_h_w_global_desc_forw),
                                          decltype(out_n_k_h_w_global_desc_back)>::type{};

            // This is a hack, because slicing a merged dimension is not supported yet.
            // dst descriptor
            constexpr auto out_k0_k1_k2_b_global_desc = make_ConstantMergedTensorDescriptor(
                out_n_k_h_w_global_desc_new.Fold(I1, Number<K1>{}, Number<K2>{}),
                Sequence<1>{},
                Sequence<2>{},
                Sequence<3>{},
                Sequence<0, 4, 5>{});

            // src descriptor
            constexpr auto out_k0_k1_k2_b_thread_desc =
                make_ConstantTensorDescriptor_packed(Sequence<K2, 1, K0, 1>{});

            using OutThreadCopySliceLengths = Sequence<K2, 1, K0, 1>;

            constexpr index_t NumKPerBlk = out_k0_k1_k2_b_thread_desc.GetElementSpace();
            constexpr index_t NumBlks    = GemmMPerWave / NumKPerBlk;

            for(index_t i = 0; i < NumBlks; ++i)
            {
                // calculate origin of thread output tensor on global memory
                // blockwise GEMM c matrix starting index
                const auto c_thread_mtx_on_block = blockwise_gemm.GetBeginOfThreadMatrixC(i);

                const index_t k_thread_data_on_global =
                    k_block_data_on_global + c_thread_mtx_on_block.row;
                const index_t b_thread_data_on_global =
                    b_block_data_on_global + c_thread_mtx_on_block.col;

                auto threadwise_out_copy = ThreadwiseGenericTensorSliceCopy_v2r1<
                    decltype(out_k0_k1_k2_b_thread_desc),
                    decltype(out_k0_k1_k2_b_global_desc),
                    NormalTensorCoordinate<decltype(out_k0_k1_k2_b_thread_desc)>,
                    MergedTensorCoordinate<decltype(out_k0_k1_k2_b_global_desc)>,
                    OutThreadCopySliceLengths,
                    arithmetic_sequence_gen<0, 4, 1>::type,
                    arithmetic_sequence_gen<0, 4, 1>::type,
                    3,
                    3,
                    OutThreadCopyDataPerAccess_B,
                    OutThreadCopyDataPerAccess_B>(
                    {0, 0, 0, 0},
                    {k_thread_data_on_global / (K0 * K1),
                     k_thread_data_on_global % (K0 * K1) / K0,
                     k_thread_data_on_global % K0,
                     b_thread_data_on_global});

                threadwise_out_copy.Run(p_out_thread + i * NumKPerBlk, p_out_global);
            }
        }
    }
};

} // namespace ck
#endif
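All of the new kernels place each thread's output rows with the same index arithmetic used in the threadwise copy origin above: a flat K coordinate is split into three digits against the folded output descriptor via k/(K0*K1), k%(K0*K1)/K0, and k%K0. A short self-contained check of that decomposition is below; K0 and K1 are illustrative values only, since the kernel takes them from blockwise_gemm.OutputLayout.

// build: g++ -std=c++14 k_split_sketch.cpp
#include <cassert>
#include <cstdio>

int main()
{
    // Illustrative sizes only; the kernel reads K0/K1/K2 from the xdlops output layout.
    const int K0 = 4, K1 = 4;

    for(int k = 0; k < 64; ++k)
    {
        // Same arithmetic as the threadwise copy origin:
        const int d0 = k / (K0 * K1);      // slowest-varying digit
        const int d1 = k % (K0 * K1) / K0; // middle digit
        const int d2 = k % K0;             // fastest-varying digit

        // The three digits recombine to the original flat K index.
        assert(d0 * (K0 * K1) + d1 * K0 + d2 == k);
    }
    std::printf("k -> (k/(K0*K1), k%%(K0*K1)/K0, k%%K0) round-trips for all k\n");
    return 0;
}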
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw_lds_double_buffer.hpp
new file (mode 100644)
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_XDLOPS_NCHW_KCYX_NKHW_HPP_LDS_DOUBLE_BUFFER_HPP
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_XDLOPS_NCHW_KCYX_NKHW_HPP_LDS_DOUBLE_BUFFER_HPP

#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "ConstantMergedTensorDescriptor.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_generic_tensor_slice_copy.hpp"
#include "blockwise_gemm_xdlops.hpp"
#include "threadwise_generic_tensor_slice_copy.hpp"
#include "implicitgemm_params.hpp"

namespace ck {

template <ImplicitGemmDirection conv_dir, typename WeiDesc>
struct make_WeiDesc_Xdlops
{
};

template <typename WeiDesc>
struct make_WeiDesc_Xdlops<ImplicitGemmDirection::ForwardData, WeiDesc>
{
    __device__ constexpr auto get(WeiDesc&)
    {
        constexpr auto I1 = Number<1>{};
        constexpr auto I3 = Number<3>{};

        return WeiDesc{}.Unfold(I1, I3).ReorderGivenNew2Old(Sequence<1, 0>{});
    }
};

template <typename WeiDesc>
struct make_WeiDesc_Xdlops<ImplicitGemmDirection::BackwardWeight, WeiDesc>
{
    __device__ constexpr auto get(WeiDesc& desc)
    {
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};

        return make_ConstantMergedTensorDescriptor(
            desc.Unfold(I2, I3), Sequence<1, 2>{}, Sequence<0>{});
    }
};

// B = merge(N, Ho, Wo)
template <index_t GridSize, index_t BlockSize, class Float, class AccDataType,
          class InGlobalDesc, class WeiGlobalDesc, class OutGlobalDesc,
          class ConvStrides, class ConvDilations,
          index_t BPerBlock, index_t KPerBlock, index_t EPerBlock, index_t EPack,
          index_t GemmMPerWave, index_t GemmNPerWave, index_t GemmMWaves, index_t GemmNWaves,
          index_t GemmDataPerReadA, index_t GemmDataPerReadB, bool EnableXdlops,
          class InBlockCopySubLengths_E_B, class InBlockCopyClusterLengths_E_B,
          class InBlockCopyThreadClusterArrangeOrder,
          class InBlockCopySrcAccessOrder, class InBlockCopyDstAccessOrder,
          index_t InBlockCopyDataPerAccess_B,
          class WeiBlockCopySubLengths_E_K, class WeiBlockCopyClusterLengths_E_K,
          class WeiBlockCopyThreadClusterArrangeOrder,
          class WeiBlockCopySrcAccessOrder, class WeiBlockCopyDstAccessOrder,
          index_t WeiBlockCopySrcDataPerRead_E, index_t WeiBlockCopyDstDataPerWrite_K,
          index_t OutThreadCopyDataPerAccess_B, ImplicitGemmDirection conv_dir>
struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_nchw_kcyx_nkhw_lds_double_buffer
{
    __device__ void Run(const Float* const __restrict__ p_in_global,
                        const Float* const __restrict__ p_wei_global,
                        Float* const __restrict__ p_out_global) const
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};

        constexpr auto True = integral_constant<bool, true>{};

        constexpr auto in_n_c_h_w_global_desc  = InGlobalDesc{};
        constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{};
        constexpr auto out_n_k_h_w_global_desc = OutGlobalDesc{};

        constexpr index_t N = in_n_c_h_w_global_desc.GetLengths()[0];
        constexpr index_t C = in_n_c_h_w_global_desc.GetLengths()[1];

        constexpr index_t K  = out_n_k_h_w_global_desc.GetLengths()[1];
        constexpr index_t Ho = out_n_k_h_w_global_desc.GetLengths()[2];
        constexpr index_t Wo = out_n_k_h_w_global_desc.GetLengths()[3];

        constexpr index_t Y = wei_k_c_y_x_global_desc.GetLengths()[2];
        constexpr index_t X = wei_k_c_y_x_global_desc.GetLengths()[3];

        constexpr index_t ConvStrideH = ConvStrides{}[0];
        constexpr index_t ConvStrideW = ConvStrides{}[1];

        constexpr index_t ConvDilationH = ConvDilations{}[0];
        constexpr index_t ConvDilationW = ConvDilations{}[1];

        constexpr index_t nonVectorizedC = C / EPack;
        constexpr index_t E              = nonVectorizedC * Y * X;
        constexpr index_t B              = N * Ho * Wo;

        static_assert((X == 1 || ConvDilationW % InBlockCopyDataPerAccess_B == 0),
                      "wrong! alignment requirement for vectorized global load of input tensor will "
                      "be violated");

        // divide block work by [K, B]
        static_assert(K % KPerBlock == 0 && B % BPerBlock == 0 && E % (2 * EPerBlock) == 0,
                      "wrong! cannot divide work evenly among block");

        constexpr index_t KBlockWork = K / KPerBlock;
        constexpr index_t BBlockWork = B / BPerBlock;

        constexpr auto block_work_desc =
            make_ConstantTensorDescriptor_packed(Sequence<KBlockWork, BBlockWork>{});

        const auto block_work_multi_id = block_work_desc.GetMultiIndexFrom1dIndex(get_block_1d_id());

        const index_t k_block_data_on_global = block_work_multi_id[0] * KPerBlock;
        const index_t b_block_data_on_global = block_work_multi_id[1] * BPerBlock;

        // input tensor
        // tensor descriptor in device memory [N, Ho, Wo]
        constexpr auto in_n_ho_wo_global_desc =
            in_n_c_h_w_global_desc.Extract(I0, I2, I3)
                .StridedSlice(I1, Number<Ho>{}, Number<ConvStrideH>{})
                .StridedSlice(I2, Number<Wo>{}, Number<ConvStrideW>{});

        // batch descriptor for device memory
        constexpr auto in_c_y_x_global_desc =
            in_n_c_h_w_global_desc.StridedSlice(I2, Number<Y>{}, Number<ConvDilationH>{})
                .StridedSlice(I3, Number<X>{}, Number<ConvDilationW>{})
                .Extract(Sequence<1, 2, 3>{});

        // merged tensor descriptor in device memory [E, B], src of blockwise copy
        constexpr auto in_e_b_global_desc = make_ConstantMergedTensorDescriptor(
            in_c_y_x_global_desc.Embed(in_n_ho_wo_global_desc),
            Sequence<0, 1, 2>{},
            Sequence<3, 4, 5>{});

        // memory layout descriptor in LDS [E, B], dst of blockwise copy
        // be careful of LDS alignment
        constexpr auto in_e_b_block_desc = make_ConstantTensorDescriptor_aligned(
            Sequence<EPerBlock, BPerBlock>{},
            Number<math::lcm(InBlockCopyDataPerAccess_B, GemmDataPerReadB)>{});

        // input blockwise copy
        // slice a merged tensor, reorder and copy to a normal tensor
        // this copy operator already has blockwise offset built-in
        auto blockwise_in_copy = BlockwiseGenericTensorSliceCopy_v2<
            BlockSize,
            decltype(in_e_b_global_desc),
            decltype(in_e_b_block_desc),
            MergedTensorCoordinate<decltype(in_e_b_global_desc)>,
            NormalTensorCoordinate<decltype(in_e_b_block_desc)>,
            decltype(in_e_b_block_desc.GetLengths()),
            InBlockCopySubLengths_E_B,
            InBlockCopyClusterLengths_E_B,
            InBlockCopyThreadClusterArrangeOrder,
            InBlockCopySrcAccessOrder,
            InBlockCopyDstAccessOrder,
            1,
            1,
            InBlockCopyDataPerAccess_B,
            InBlockCopyDataPerAccess_B>({0, b_block_data_on_global}, {0, 0});

        // weight tensor
        // tensor descriptor in device memory, src of blockwise copy
        constexpr auto wei_e_k_global_desc =
            make_WeiDesc_Xdlops<conv_dir, decltype(wei_k_c_y_x_global_desc)>{}.get(
                wei_k_c_y_x_global_desc);

        // tensor descriptor in LDS, dst of blockwise copy
        // be careful of LDS alignment
        constexpr auto wei_e_k_block_desc = make_ConstantTensorDescriptor_aligned(
            Sequence<EPerBlock, KPerBlock>{},
            Number<math::lcm(WeiBlockCopyDstDataPerWrite_K, GemmDataPerReadA)>{});

        // operator for blockwise copy of weight into LDS
        // slice a tensor, and copy it into another tensor
        // this copy operator already has blockwise offset built-in
        auto blockwise_wei_copy = BlockwiseGenericTensorSliceCopy_v2<
            BlockSize,
            decltype(wei_e_k_global_desc),
            decltype(wei_e_k_block_desc),
            NormalTensorCoordinate<decltype(wei_e_k_global_desc)>,
            NormalTensorCoordinate<decltype(wei_e_k_block_desc)>,
            decltype(wei_e_k_block_desc.GetLengths()),
            WeiBlockCopySubLengths_E_K,
            WeiBlockCopyClusterLengths_E_K,
            WeiBlockCopyThreadClusterArrangeOrder,
            WeiBlockCopySrcAccessOrder,
            WeiBlockCopyDstAccessOrder,
            0,
            1,
            WeiBlockCopySrcDataPerRead_E,
            WeiBlockCopyDstDataPerWrite_K>({0, k_block_data_on_global}, {0, 0});

        // GEMM definition
        constexpr auto a_e_k_block_mtx_desc = make_ConstantMatrixDescriptor(wei_e_k_block_desc);
        constexpr auto b_e_b_block_mtx_desc = make_ConstantMatrixDescriptor(in_e_b_block_desc);

        const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_xdlops<
            BlockSize,
            decltype(a_e_k_block_mtx_desc),
            decltype(b_e_b_block_mtx_desc),
            decltype(mfma_info<float>{}),
            EnableXdlops,
            GemmMPerWave,
            GemmNPerWave,
            GemmMWaves,
            GemmNWaves,
            GemmDataPerReadA,
            GemmDataPerReadB>{};

        constexpr auto c_k_thread_mtx_desc = blockwise_gemm.GetThreadMatrixCDescriptor();

        constexpr index_t max_align = math::lcm(InBlockCopyDataPerAccess_B,
                                                WeiBlockCopyDstDataPerWrite_K,
                                                GemmDataPerReadA,
                                                GemmDataPerReadB);

        constexpr index_t in_block_space =
            math::integer_least_multiple(in_e_b_block_desc.GetElementSpace(), max_align);
        constexpr index_t wei_block_space =
            math::integer_least_multiple(wei_e_k_block_desc.GetElementSpace(), max_align);

        __shared__ Float p_in_block_double[2 * in_block_space];
        __shared__ Float p_wei_block_double[2 * wei_block_space];

        // register allocation for output
        AccDataType p_out_thread[c_k_thread_mtx_desc.GetElementSpace()];

        // zero out threadwise output
        threadwise_matrix_set_zero(c_k_thread_mtx_desc, p_out_thread);

        // static_if<EnableXdlops>{}(
        //     [&](auto) { gcnasm_accvgpr_zero<c_k_thread_mtx_desc.GetElementSpace()>(); });

        const Float* p_wei_block_on_global = p_wei_global;

        // LDS double buffer: preload data into LDS
        {
            blockwise_in_copy.Run(p_in_global, p_in_block_double);
            blockwise_wei_copy.Run(p_wei_global, p_wei_block_double);
        }

        // LDS double buffer: main body
        for(index_t e_block_data_begin = 0; e_block_data_begin + 2 * EPerBlock < E;
            e_block_data_begin += 2 * EPerBlock)
        {
#pragma unroll
            for(index_t iloop = 0; iloop < 2; ++iloop)
            {
                const bool even_loop = (iloop % 2 == 0);

                Float* p_in_block_now =
                    even_loop ? p_in_block_double : p_in_block_double + in_block_space;
                Float* p_wei_block_now =
                    even_loop ? p_wei_block_double : p_wei_block_double + wei_block_space;

                Float* p_in_block_next =
                    even_loop ? p_in_block_double + in_block_space : p_in_block_double;
                Float* p_wei_block_next =
                    even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double;

                Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()];
                Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];

                blockwise_in_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0>{}, True);

                static_if<conv_dir == ImplicitGemmDirection::BackwardWeight>{}([&](auto) {
                    blockwise_wei_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0>{}, True);
                }).Else([&](auto fwd) {
                    p_wei_block_on_global += EPerBlock * fwd(wei_e_k_global_desc).GetStride(I0);
                });

                __syncthreads();

                // LDS double buffer: load next data from device mem
                blockwise_in_copy.RunLoadRegisterBuffer(p_in_global, p_in_register_buffer);
                blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_block_on_global, p_wei_register_buffer);

                // LDS double buffer: GEMM on current data
                blockwise_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread);

                // LDS double buffer: store next data to LDS
                blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer, p_in_block_next);
                blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer, p_wei_block_next);
            }
        }

        // LDS double buffer: tail
        {
            Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()];
            Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];

            // even iteration
            blockwise_in_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0>{}, True);

            static_if<conv_dir == ImplicitGemmDirection::BackwardWeight>{}([&](auto) {
                blockwise_wei_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0>{}, True);
            }).Else([&](auto fwd) {
                p_wei_block_on_global += EPerBlock * fwd(wei_e_k_global_desc).GetStride(I0);
            });

            __syncthreads();

            // LDS double buffer: load next data from device mem
            blockwise_in_copy.RunLoadRegisterBuffer(p_in_global, p_in_register_buffer);
            blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_block_on_global, p_wei_register_buffer);

            // LDS double buffer: GEMM on current data
            blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);

            // LDS double buffer: store next data to LDS
            blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer,
                                                     p_in_block_double + in_block_space);
            blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer,
                                                      p_wei_block_double + wei_block_space);

            // odd iteration
            __syncthreads();

            // LDS double buffer: GEMM on current data
            blockwise_gemm.Run(p_wei_block_double + wei_block_space,
                               p_in_block_double + in_block_space,
                               p_out_thread);
        }

        // load data from xldop_acc_regs
        // static_if<EnableXdlops>{}([&](auto) {
        //     gcnasm_accvgpr_read<c_k_thread_mtx_desc.GetElementSpace()>(p_out_thread);
        // });

        // copy output: register to global memory
        {
            constexpr index_t K2 = blockwise_gemm.OutputLayout.M2;
            constexpr index_t K1 = blockwise_gemm.OutputLayout.M1;
            constexpr index_t K0 = blockwise_gemm.OutputLayout.M0;

            // This is a hack, because slicing a merged dimension is not supported yet.
            // dst descriptor
            constexpr auto out_k0_k1_k2_b_global_desc = make_ConstantMergedTensorDescriptor(
                out_n_k_h_w_global_desc.Fold(I1, Number<K1>{}, Number<K2>{}),
                Sequence<1>{},
                Sequence<2>{},
                Sequence<3>{},
                Sequence<0, 4, 5>{});

            // src descriptor
            constexpr auto out_k0_k1_k2_b_thread_desc =
                make_ConstantTensorDescriptor_packed(Sequence<K2, 1, K0, 1>{});

            using OutThreadCopySliceLengths = Sequence<K2, 1, K0, 1>;

            constexpr index_t NumKPerBlk = out_k0_k1_k2_b_thread_desc.GetElementSpace();
            constexpr index_t NumBlks    = GemmMPerWave / NumKPerBlk;

            for(index_t i = 0; i < NumBlks; ++i)
            {
                // calculate origin of thread output tensor on global memory
                // blockwise GEMM c matrix starting index
                const auto c_thread_mtx_on_block = blockwise_gemm.GetBeginOfThreadMatrixC(i);

                const index_t k_thread_data_on_global =
                    k_block_data_on_global + c_thread_mtx_on_block.row;
                const index_t b_thread_data_on_global =
                    b_block_data_on_global + c_thread_mtx_on_block.col;

                auto threadwise_out_copy = ThreadwiseGenericTensorSliceCopy_v2r1<
                    decltype(out_k0_k1_k2_b_thread_desc),
                    decltype(out_k0_k1_k2_b_global_desc),
                    NormalTensorCoordinate<decltype(out_k0_k1_k2_b_thread_desc)>,
                    MergedTensorCoordinate<decltype(out_k0_k1_k2_b_global_desc)>,
                    OutThreadCopySliceLengths,
                    arithmetic_sequence_gen<0, 4, 1>::type,
                    arithmetic_sequence_gen<0, 4, 1>::type,
                    3,
                    3,
                    OutThreadCopyDataPerAccess_B,
                    OutThreadCopyDataPerAccess_B>(
                    {0, 0, 0, 0},
                    {k_thread_data_on_global / (K0 * K1),
                     k_thread_data_on_global % (K0 * K1) / K0,
                     k_thread_data_on_global % K0,
                     b_thread_data_on_global});

                threadwise_out_copy.Run(p_out_thread + i * NumKPerBlk, p_out_global);
            }
        }
    }
};

} // namespace ck
#endif
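For orientation: all three new gridwise kernels view the convolution as an implicit GEMM with A of shape [E, K], B of shape [E, B], and C of shape [K, B], where E = (C / EPack) * Y * X (just C for the 1x1 case) and B = N * Ho * Wo, and the grid is tiled as (K / KPerBlock) x (B / BPerBlock) blocks. The sketch below just evaluates those sizes for one made-up problem; every number in it is illustrative and is not taken from the commit.

// build: g++ -std=c++14 implicit_gemm_sizes_sketch.cpp
#include <cassert>
#include <cstdio>

int main()
{
    // Illustrative convolution problem (assumed values).
    const int N = 128, C = 64, K = 256; // batch, input channels, output channels
    const int Ho = 28, Wo = 28;         // output spatial size
    const int Y = 3, X = 3;             // filter size
    const int EPack = 1;                // 1 for fp32; 4 for fp16, 2 for bfp16 in the xdlops path

    // GEMM view used by these kernels.
    const int E = (C / EPack) * Y * X;
    const int B = N * Ho * Wo;

    // Block tile sizes (typical tuning values, also assumed).
    const int KPerBlock = 128, BPerBlock = 128, EPerBlock = 8;

    // Same divisibility the kernels assert at compile time.
    assert(K % KPerBlock == 0 && B % BPerBlock == 0 && E % (2 * EPerBlock) == 0);

    const int grid_size = (K / KPerBlock) * (B / BPerBlock);
    std::printf("E=%d B=%d -> grid of %d blocks (%d x %d)\n",
                E, B, grid_size, K / KPerBlock, B / BPerBlock);
    return 0;
}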
composable_kernel/include/tensor_description/ConstantMatrixDescriptor.hpp
...
@@ -2,6 +2,7 @@
 #define CK_CONSTANT_MATRIX_DESCRIPTOR_HPP

 #include "common_header.hpp"
+#include "ConstantTensorDescriptor.hpp"

 namespace ck {
...
@@ -39,7 +40,7 @@ struct ConstantMatrixDescriptor
 };

 template <index_t NRow, index_t NCol>
-__host__ __device__ constexpr auto make_ConstantMatrixDescriptor(Number<NRow>, Number<NCol>)
+__host__ __device__ constexpr auto make_ConstantMatrixDescriptor_packed(Number<NRow>, Number<NCol>)
 {
     return ConstantMatrixDescriptor<NRow, NCol, NCol>{};
 }
...
@@ -51,6 +52,17 @@ __host__ __device__ constexpr auto
     return ConstantMatrixDescriptor<NRow, NCol, RowStride>{};
 }

+template <class... Ts>
+__host__ __device__ constexpr auto make_ConstantMatrixDescriptor(ConstantTensorDescriptor<Ts...>)
+{
+    using TDesc = ConstantTensorDescriptor<Ts...>;
+
+    static_assert(TDesc::GetNumOfDimension() == 2, "wrong");
+    static_assert(TDesc::GetStrides()[1] == 1, "wrong");
+
+    return ConstantMatrixDescriptor<TDesc::GetLengths()[0],
+                                    TDesc::GetLengths()[1],
+                                    TDesc::GetStrides()[0]>{};
+}
+
 template <class TDesc>
 __host__ __device__ void print_ConstantMatrixDescriptor(TDesc, const char* s)
 {
...
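The added overload is what lets the gridwise kernels write make_ConstantMatrixDescriptor(wei_e_k_block_desc) instead of spelling out lengths and the leading-dimension stride: a 2-D tensor descriptor with a unit column stride maps directly onto a matrix descriptor that reuses the row stride as its leading dimension. A minimal standalone sketch of that same mapping, using simplified runtime stand-in types rather than the repository's compile-time templates, is below.

// build: g++ -std=c++14 matrix_from_tensor_sketch.cpp
#include <cassert>
#include <cstdio>

// Simplified stand-ins for the 2-D tensor / matrix descriptors (runtime values
// here; the real ones carry everything in template parameters).
struct TensorDesc2D
{
    int lengths[2];
    int strides[2];
};

struct MatrixDesc
{
    int nrow, ncol, row_stride;
};

// Same mapping as the new make_ConstantMatrixDescriptor(ConstantTensorDescriptor<...>)
// overload: require a unit column stride, reuse the row stride as the leading dimension.
MatrixDesc make_matrix_desc(const TensorDesc2D& t)
{
    assert(t.strides[1] == 1); // mirrors the static_assert in the overload
    return MatrixDesc{t.lengths[0], t.lengths[1], t.strides[0]};
}

int main()
{
    // e.g. an LDS block of shape [EPerBlock=8, KPerBlock=128] padded to a row stride of 132
    const TensorDesc2D wei_e_k_block{{8, 128}, {132, 1}};
    const MatrixDesc a = make_matrix_desc(wei_e_k_block);
    std::printf("A: %d x %d, row stride %d\n", a.nrow, a.ncol, a.row_stride);
    return 0;
}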
composable_kernel/include/tensor_description/ConstantMergedTensorDescriptor.hpp
...
@@ -37,7 +37,7 @@ struct ConstantMergedTensorDescriptor
         return OriginalTensorDesc{};
     }

-    __host__ __device__ static constexpr index_t GetNumOfDimension() { return nDim; }
+    __host__ __device__ static constexpr auto GetNumOfDimension() { return Number<nDim>{}; }

     template <index_t IDim>
     __host__ __device__ static constexpr auto GetContainedOriginalDimensions(Number<IDim>)
@@ -52,7 +52,7 @@ struct ConstantMergedTensorDescriptor
     }

     template <index_t IDim>
-    __host__ __device__ static constexpr index_t GetLength(Number<IDim>)
+    __host__ __device__ static constexpr auto GetLength(Number<IDim>)
     {
         constexpr auto original_dims_partial = std::get<IDim>(mOriginalDimMergeSeqs);
@@ -60,22 +60,32 @@ struct ConstantMergedTensorDescriptor
     }

     template <index_t IDim>
-    __host__ __device__ static constexpr index_t GetStride(Number<IDim>)
+    __host__ __device__ static constexpr auto GetStride(Number<IDim>)
     {
         static_assert(!ContainMultipleOriginalDimensions(Number<IDim>{}),
                       "wrong! stride of a merged dimension is undefined");

-        constexpr auto idim_original = std::get<IDim>(mOriginalDimMergeSeqs).Front();
+        constexpr auto idim_original = std::get<IDim>(mOriginalDimMergeSeqs).Back();

         return OriginalTensorDesc::GetStride(Number<idim_original>{});
     }

+    // this is a hack to return the stride of the last original dimension of a merged dimension
+    // TODO: refactor this once the concept of "dimension" is used
+    template <index_t IDim>
+    __host__ __device__ static constexpr auto GetLastOriginalDimensionStride(Number<IDim>)
+    {
+        constexpr auto idim_last_original = std::get<IDim>(mOriginalDimMergeSeqs).Back();
+
+        return OriginalTensorDesc::GetStride(Number<idim_last_original>{});
+    }
+
     __host__ __device__ static constexpr auto GetLengths()
     {
         return Sequence<OriginalTensorDesc::Extract(OriginalDimMergeSeqs{}).GetElementSize()...>{};
     }

-    __host__ __device__ static constexpr index_t GetElementSize()
+    __host__ __device__ static constexpr auto GetElementSize()
     {
         return OriginalTensorDesc::GetElementSize();
     }
@@ -174,6 +184,13 @@ struct ConstantMergedTensorDescriptor
         return packed_desc.GetMultiIndexFrom1dIndex(id);
     }

+    __host__ __device__ static constexpr auto Pack()
+    {
+        constexpr auto lengths = GetLengths();
+        constexpr auto strides = calculate_tensor_strides_packed(lengths);
+        return ConstantTensorDescriptor<decltype(lengths), decltype(strides)>{};
+    }
+
 };

 template <class OriginalTensorDesc, class... OriginalDimMergeSeqs>
...
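Note on the pattern in this hunk: accessors that used to return a plain index_t now return a Number<> wrapper, so their results stay type-level constants at the call site instead of ordinary runtime values. A minimal host-side sketch of the idea, using simplified stand-ins (Number, Desc3d, print_rank are illustrative here, not the real ck:: types):

    #include <cstdio>

    template <int N>
    struct Number { static constexpr int value = N; };

    struct Desc3d
    {
        // returning Number<3>{} instead of a plain int keeps the value a type-level constant
        static constexpr auto GetNumOfDimension() { return Number<3>{}; }
    };

    template <int N>
    void print_rank(Number<N>) { std::printf("rank = %d\n", N); }

    int main()
    {
        // the result can feed template machinery directly, with no constexpr tricks at the call site
        print_rank(Desc3d::GetNumOfDimension());
    }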
composable_kernel/include/tensor_description/ConstantTensorDescriptor.hpp
...
@@ -43,23 +43,15 @@ struct ConstantTensorDescriptor
         return Sequence<IDim>{};
     }

-    __host__ __device__ static constexpr index_t GetNumOfDimension() { return nDim; }
+    __host__ __device__ static constexpr auto GetNumOfDimension() { return Number<nDim>{}; }

     __host__ __device__ static constexpr auto GetLengths() { return Lengths{}; }

     __host__ __device__ static constexpr auto GetStrides() { return Strides{}; }

-    template <index_t I>
-    __host__ __device__ static constexpr index_t GetLength(Number<I>)
-    {
-        return Lengths::Get(Number<I>{});
-    }
+    __host__ __device__ static constexpr auto GetLength(index_t IDim) { return Lengths{}[IDim]; }

-    template <index_t I>
-    __host__ __device__ static constexpr index_t GetStride(Number<I>)
-    {
-        return Strides::Get(Number<I>{});
-    }
+    __host__ __device__ static constexpr auto GetStride(index_t IDim) { return Strides{}[IDim]; }

     struct lambda_AreDimensionsContinuous
     {
@@ -102,17 +94,18 @@ struct ConstantTensorDescriptor
         return false;
     }

-    __host__ __device__ static constexpr index_t GetElementSize()
+    __host__ __device__ static constexpr auto GetElementSize()
     {
-        return accumulate_on_sequence(Lengths{}, math::multiplies<index_t>{}, Number<1>{});
+        return Number<accumulate_on_sequence(
+            Lengths{}, math::multiplies<index_t>{}, Number<1>{})>{};
     }

-    __host__ __device__ static constexpr index_t GetElementSpace()
+    __host__ __device__ static constexpr auto GetElementSpace()
     {
         constexpr index_t element_space_unaligned = accumulate_on_sequence(
             (GetLengths() - Number<1>{}) * GetStrides(), math::plus<index_t>{}, Number<1>{});

-        return element_space_unaligned;
+        return Number<element_space_unaligned>{};
     }

     // emulate constexpr lambda
@@ -156,13 +149,14 @@ struct ConstantTensorDescriptor
     }

     template <index_t... Is>
-    __host__ __device__ static constexpr index_t GetOffsetFromMultiIndex(Sequence<Is...>)
+    __host__ __device__ static constexpr auto GetOffsetFromMultiIndex(Sequence<Is...>)
     {
         static_assert(sizeof...(Is) == nDim, "wrong! Dimension not consistent");

         constexpr auto multi_id = Sequence<Is...>{};

-        return accumulate_on_sequence(multi_id * GetStrides(), math::plus<index_t>{}, Number<0>{});
+        return Number<accumulate_on_sequence(
+            multi_id * GetStrides(), math::plus<index_t>{}, Number<0>{})>{};
     }

     // emulate constexpr lambda
@@ -369,6 +363,12 @@ struct ConstantTensorDescriptor
         return ConstantTensorDescriptor<decltype(new_lengths), decltype(new_strides)>{};
     }

+    template <index_t IDim, index_t... FoldIntervals>
+    __host__ __device__ static constexpr auto Fold(Number<IDim>, Sequence<FoldIntervals...>)
+    {
+        return Fold(Number<IDim>{}, Number<FoldIntervals>{}...);
+    }
+
     // this function unfold dimension [FirstUnfoldDim, ..., LastUnfoldDim] into 1 dimension
     template <index_t FirstUnfoldDim, index_t LastUnfoldDim>
     __host__ __device__ static constexpr auto Unfold(Number<FirstUnfoldDim>, Number<LastUnfoldDim>)
@@ -407,6 +407,12 @@ struct ConstantTensorDescriptor
         return ConstantTensorDescriptor<decltype(new_lengths), decltype(new_strides)>{};
     }

+    __host__ __device__ static constexpr auto Pack()
+    {
+        using packed_strides = decltype(calculate_tensor_strides_packed(Lengths{}));
+        return ConstantTensorDescriptor<Lengths, packed_strides>{};
+    }
+
     template <class MapNew2Old>
     __host__ __device__ static constexpr auto ReorderGivenNew2Old(MapNew2Old)
     {
@@ -414,14 +420,12 @@ struct ConstantTensorDescriptor
                                         decltype(Strides::ReorderGivenNew2Old(MapNew2Old{}))>{};
     }

-#if 0 // require sequence_sort, which is not implemented yet
     template <class MapOld2New>
     __host__ __device__ static constexpr auto ReorderGivenOld2New(MapOld2New)
     {
         return ConstantTensorDescriptor<decltype(Lengths::ReorderGivenOld2New(MapOld2New{})),
-                                        decltype(Strides::ReorderGivenOld2New(MapOld2New{}))>{}
+                                        decltype(Strides::ReorderGivenOld2New(MapOld2New{}))>{};
     }
-#endif
 };

 template <class Lengths>
...
@@ -451,7 +455,7 @@ print_ConstantTensorDescriptor(const char* s,
 {
     constexpr index_t ndim = sizeof...(Lengths);

-    static_assert(ndim > 0 && ndim <= 10, "wrong!");
+    static_assert(ndim > 0 && ndim <= 12, "wrong!");

     static_if<ndim == 1>{}([&](auto) {
         printf("%s dim %u, lengths {%u}, strides {%u}\n", s, ndim, Lengths..., Strides...);
@@ -523,6 +527,26 @@ print_ConstantTensorDescriptor(const char* s,
                Lengths...,
                Strides...);
     });
+
+    static_if<ndim == 11>{}([&](auto) {
+        printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u "
+               "%u %u "
+               "%u %u %u}\n",
+               s,
+               ndim,
+               Lengths...,
+               Strides...);
+    });
+
+    static_if<ndim == 12>{}([&](auto) {
+        printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u "
+               "%u %u %u %u "
+               "%u %u %u}\n",
+               s,
+               ndim,
+               Lengths...,
+               Strides...);
+    });
 }

 } // namespace ck
...
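For reference, GetElementSpace() in the hunk above folds the usual strided-storage formula over the descriptor's Sequences: space = 1 + sum_i (len[i] - 1) * stride[i]. A standalone host-side check of that formula with hand-picked numbers (the values below are illustrative, not taken from any kernel configuration):

    #include <cstdio>

    int main()
    {
        const int len[2]    = {2, 3};
        const int stride[2] = {4, 1}; // row stride padded from 3 to 4

        int space = 1;
        for(int i = 0; i < 2; ++i)
            space += (len[i] - 1) * stride[i];

        // a 2x3 view with padded rows needs 7 elements of backing storage, not 6
        std::printf("element space = %d\n", space); // prints 7
    }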
composable_kernel/include/tensor_description/tensor_coordinate.hpp (new file, 0 → 100644)
#ifndef CK_TENSOR_COORDINATE_HPP
#define CK_TENSOR_COORDINATE_HPP

#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "ConstantMergedTensorDescriptor.hpp"

namespace ck {

template <class TensorDesc>
struct NormalTensorCoordinate
{
    using type             = NormalTensorCoordinate;
    using tensor_desc_type = TensorDesc;

    static constexpr index_t nDim = tensor_desc_type::GetNumOfDimension();

    __host__ __device__ constexpr NormalTensorCoordinate(Array<index_t, nDim> tensor_index)
        : mOffset{tensor_desc_type::GetOffsetFromMultiIndex(tensor_index)}
    {
    }

    template <class... Xs>
    __host__ __device__ constexpr NormalTensorCoordinate(Xs... xs)
        : NormalTensorCoordinate(Array<index_t, nDim>{xs...})
    {
    }

    __host__ __device__ constexpr index_t GetOffset() const { return mOffset; }

    // T is Array or Sequence
    template <class T>
    __host__ __device__ type operator+=(T step_sizes)
    {
        static_assert(is_same<typename T::data_type, index_t>{} && T::GetSize() == nDim, "wrong!");

        mOffset += tensor_desc_type::GetOffsetFromMultiIndex(step_sizes);

        return *this;
    }

    template <class T>
    __host__ __device__ type operator-=(T step_sizes)
    {
        static_assert(is_same<typename T::data_type, index_t>{} && T::GetSize() == nDim, "wrong!");

        mOffset -= tensor_desc_type::GetOffsetFromMultiIndex(step_sizes);

        return *this;
    }

    template <class T>
    __host__ __device__ constexpr type operator+(T step_sizes) const
    {
        type coord = *this;
        coord += step_sizes;
        return coord;
    }

    template <class T>
    __host__ __device__ constexpr type operator-(T step_sizes) const
    {
        type coord = *this;
        coord -= step_sizes;
        return coord;
    }

    // reposition point of origin, and return compensated offset.
    // This is a hack to reduce index calculation during looping over
    // a tensor whose origin is this TensorCoordinate. It does so, by spitting
    // out the run-time offset to the pointer (to the tensor data) held by this
    // TensorCoordiante, so the caller can add the offset into the run-time pointer of
    // the data, so only 1 run-time variable (update pointer) is needed, instead
    // of 2 run-time variables (old pointer and this offset)
    // TODO: after introducing the concept of "run-time tensor view", which contains the
    // run-time pointer to the data, always keep track of the pointer, instead of both
    // offset and the pointer. This also bring additional benefit that we don't need to
    // worry the offset might underflow (because offset is unsigned integer) when updating it.
    __host__ __device__ constexpr index_t RepositOrigin()
    {
        index_t offset_diff = mOffset;
        mOffset             = 0;
        return offset_diff;
    }

    private:
    index_t mOffset;
};
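NormalTensorCoordinate above only ever tracks one linear offset and moves it by the dot product of a step with the descriptor's strides. A minimal host-side sketch of that bookkeeping for a fixed-rank case (NormalCoord2d and its fixed rank are illustrative assumptions, not the ck:: class):

    #include <array>
    #include <cstdio>

    struct NormalCoord2d
    {
        std::array<int, 2> strides;
        int offset = 0;

        NormalCoord2d(std::array<int, 2> s, std::array<int, 2> idx) : strides(s)
        {
            offset = idx[0] * strides[0] + idx[1] * strides[1];
        }

        // operator+= adds the offset contribution of the step on every dimension
        NormalCoord2d& operator+=(std::array<int, 2> step)
        {
            offset += step[0] * strides[0] + step[1] * strides[1];
            return *this;
        }
    };

    int main()
    {
        NormalCoord2d c({8, 1}, {2, 3});          // row-major 2D tensor, row stride 8, start at (2,3)
        c += {1, 0};                              // move one row down
        std::printf("offset = %d\n", c.offset);   // 2*8+3 = 19, then +8 -> 27
    }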
template <class TensorDesc>
struct MergedTensorCoordinate
{
    using type             = MergedTensorCoordinate;
    using tensor_desc_type = TensorDesc;

    static constexpr index_t nDim = tensor_desc_type::GetNumOfDimension();
    static constexpr index_t nOriginalDim =
        tensor_desc_type::GetOriginalTensorDescriptor().GetNumOfDimension();

    __host__ __device__ constexpr MergedTensorCoordinate(Array<index_t, nDim> tensor_index)
        : mOriginalIndex{tensor_desc_type::GetOriginalMultiIndexFromMultiIndex(tensor_index)}
    {
        // partial offset on each dimension
        static_for<0, nDim, 1>{}([&](auto idim) {
            constexpr auto partial_original_dims =
                tensor_desc_type::GetContainedOriginalDimensions(idim);

            constexpr auto partial_original_desc =
                tensor_desc_type::GetOriginalTensorDescriptor().Extract(partial_original_dims);

            mPartialOffsets(idim) = partial_original_desc.GetOffsetFromMultiIndex(
                extract_array(mOriginalIndex, partial_original_dims));
        });

        // complete offset
        mOffset =
            accumulate_on_array(mPartialOffsets, math::plus<index_t>{}, static_cast<index_t>(0));
    }

    template <class... Xs>
    __host__ __device__ constexpr MergedTensorCoordinate(Xs... xs)
        : MergedTensorCoordinate(Array<index_t, nDim>{xs...})
    {
    }

    __host__ __device__ constexpr index_t GetOffset() const { return mOffset; }

    template <class IDim, class T, bool PositiveDirection>
    __host__ __device__ void
    MoveOnDimension(IDim idim_, T step_size, integral_constant<bool, PositiveDirection>)
    {
        constexpr auto idim = idim_;

        // if step_size is known at compile time
        static_if<is_static<T>::value>{}(
            [&](auto) { static_if<T{} == 0>{}([&](auto) { return; }); });

        // update original index
        static_if<tensor_desc_type::ContainMultipleOriginalDimensions(idim)>{}([&](auto) {
            constexpr auto partial_original_dims =
                tensor_desc_type::GetContainedOriginalDimensions(idim);

            constexpr index_t ndim_partial_original = partial_original_dims.GetSize();

            constexpr auto partial_original_desc =
                tensor_desc_type::GetOriginalTensorDescriptor().Extract(partial_original_dims);

            const auto partial_original_step_sizes =
                partial_original_desc.GetMultiIndexFrom1dIndex(step_size);

            // update partial original multi-id
            auto partial_original_id = extract_array(mOriginalIndex, partial_original_dims);

            static_if<PositiveDirection>{}([&](auto) {
                partial_original_id += partial_original_step_sizes;

                bool carry = false;

                // do carry check in reversed order, starting from lowest dimension
                // don't check the highest dimension
                static_for<0, ndim_partial_original, 1>{}([&](auto IReverse) {
                    constexpr index_t i = ndim_partial_original - 1 - IReverse;

                    if(carry)
                    {
                        ++partial_original_id(i);
                    }

                    carry = false;

                    if(partial_original_id[i] >= partial_original_desc.GetLength(i))
                    {
                        partial_original_id(i) -= partial_original_desc.GetLength(i);
                        carry = true;
                    }
                });
            }).Else([&](auto) {
                // shift up multi-id to avoid unsigned integer underflow during intermediate
                // calculations. After the shift, should have new_multi_id[...] >= 1
                partial_original_id +=
                    partial_original_desc.GetLengths() - partial_original_step_sizes;

                bool borrow = false;

                // do borrow check in reversed order, starting from lowest dimension
                // don't check the highest dimension
                static_for<0, ndim_partial_original, 1>{}([&](auto IReverse) {
                    constexpr index_t i = ndim_partial_original - 1 - IReverse;

                    if(borrow)
                    {
                        --partial_original_id(i);
                    }

                    borrow = false;

                    if(partial_original_id[i] < partial_original_desc.GetLength(i))
                    {
                        partial_original_id(i) += partial_original_desc.GetLength(i);
                        borrow = true;
                    }
                });

                // shift back down multi-id
                // here, should have new_multi_id[...] >= GetLengths()
                partial_original_id = partial_original_id - partial_original_desc.GetLengths();
            });

            // update "mOriginalIndex"
            static_for<0, ndim_partial_original, 1>{}([&](auto I) {
                constexpr auto idim_original = partial_original_dims[I];

                mOriginalIndex(idim_original) = partial_original_id[I];
            });

            // calculate new partial offset on this merged dimension
            const index_t old_partial_offset = mPartialOffsets[idim];

            mPartialOffsets(idim) =
                partial_original_desc.GetOffsetFromMultiIndex(partial_original_id);

            // update "mThreadSrcOffset", do "+" before "-" to avoid underflow
            mOffset = (mOffset + mPartialOffsets[idim]) - old_partial_offset;
        }).Else([&](auto fwd) {
            static_if<PositiveDirection>{}([&](auto) {
                mOffset += step_size * fwd(tensor_desc_type{}).GetStride(idim);
            }).Else([&](auto) {
                mOffset -= step_size * fwd(tensor_desc_type{}).GetStride(idim);
            });
        });
    }

    // T is Array or Sequence
    template <class T>
    __host__ __device__ type operator+=(T step_sizes)
    {
        static_assert(is_same<typename T::data_type, index_t>{} && T::GetSize() == nDim,
                      "wrong! the rank of step size doesn't match with that of tensor coordinate");

        static_for<0, nDim, 1>{}([&](auto idim) {
            if(step_sizes[idim] != 0)
            {
                this->MoveOnDimension(idim, step_sizes[idim], integral_constant<bool, true>{});
            }
        });

        return *this;
    }

    template <class T>
    __host__ __device__ type operator-=(T step_sizes)
    {
        static_assert(is_same<typename T::data_type, index_t>{} && T::GetSize() == nDim,
                      "wrong! the rank of step size doesn't match with that of tensor coordinate");

        static_for<0, nDim, 1>{}([&](auto idim) {
            if(step_sizes[idim] != 0)
            {
                this->MoveOnDimension(idim, step_sizes[idim], integral_constant<bool, false>{});
            }
        });

        return *this;
    }

    template <class T>
    __host__ __device__ constexpr type operator+(T step_sizes) const
    {
        type coord = *this;
        coord += step_sizes;
        return coord;
    }

    template <class T>
    __host__ __device__ constexpr type operator-(T step_sizes) const
    {
        type coord = *this;
        coord -= step_sizes;
        return coord;
    }

    __host__ __device__ static constexpr index_t RepositOrigin() { return 0; }

    private:
    // Allocate register memory for all merged dimensions and normal dimensions.
    // However, only those merged dimensions, whose index will be involved in arithmetic
    // after the construction of this TensorCoordinate (e.g. when user move a slicing
    // window on the merged dimension), will use these register memory.
    // Let's hope compiler will optimize away those register memory allocated for normal
    // dimensions, and those merged dimensions, that would never be involved in index
    // arithmetic after construction of TensorCoordinate.
    // TODO: refactor TensorCoordinate, after introducing the concept of "dimensions"
    // and simplify implementation of ConstantMergedTensorDescriptor, so we don't need to
    // count on compiler to optimize way those register memory for us
    Array<index_t, nOriginalDim> mOriginalIndex;
    Array<index_t, nDim> mPartialOffsets;

    // complete offset
    index_t mOffset;
};

} // namespace ck
#endif
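The positive-direction branch of MoveOnDimension above is essentially a mixed-radix add with carry propagation over the original dimensions folded into one merged dimension. A plain-C++ sketch of that carry loop, with made-up lengths and a made-up step so it runs on the host (positive steps only; the borrow branch is symmetric):

    #include <array>
    #include <cstdio>

    int main()
    {
        std::array<int, 3> len  = {4, 3, 5}; // lengths of the original dims folded into one merged dim
        std::array<int, 3> id   = {1, 2, 4}; // current original multi-index
        std::array<int, 3> step = {0, 1, 2}; // a 1D step of 7 expressed as a multi-index (1*5 + 2)

        bool carry = false;
        for(int i = 2; i >= 0; --i) // lowest original dimension first, as in the static_for above
        {
            id[i] += step[i] + (carry ? 1 : 0);
            carry = false;
            if(id[i] >= len[i])
            {
                id[i] -= len[i];
                carry = true;
            }
        }

        std::printf("new multi-index = {%d, %d, %d}\n", id[0], id[1], id[2]); // {2, 1, 1}
    }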
composable_kernel/include/tensor_operation/blockwise_gemm.hpp
...
@@ -6,7 +6,7 @@
 #include "threadwise_gemm.hpp"

 #ifndef CK_BLOCKWISE_GEMM_USE_AMD_INLINE_ASM
-#define CK_BLOCKWISE_GEMM_USE_AMD_INLINE_ASM 1
+#define CK_BLOCKWISE_GEMM_USE_AMD_INLINE_ASM 0
 #endif

 namespace ck {
@@ -14,6 +14,7 @@ namespace ck {
 // if following number are power of 2, index calculation shall be greatly reduced:
 // MPerThreadSubC, NPerThreadSubC, MLevel0Cluster, NLevel0Cluster, MLevel1Cluster, NLevel1Cluster
 template <index_t BlockSize,
+          index_t EPack,
           class BlockMatrixA,
           class BlockMatrixB,
           class ThreadMatrixC,
@@ -113,6 +114,151 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
     }

 #if CK_USE_AMD_INLINE_ASM
+    // 1 in specialized template represent pack size. fp32 = 1
+    template <index_t PACKSIZE>
+    __device__ void outerProduct(const typename vector_type<float, 4>::MemoryType& a,
+                                 const typename vector_type<float, 4>::MemoryType& b,
+                                 typename vector_type<float, 4>::MemoryType* c) const
+    {
+        static_assert(1 == PACKSIZE, "only packsize of 1 is supported with float datatype!");
+        constexpr index_t NRepeat = 2;
+        outerProduct1x4(a.x, b, c[0 * NRepeat]);
+        outerProduct1x4(a.y, b, c[1 * NRepeat]);
+        outerProduct1x4(a.z, b, c[2 * NRepeat]);
+        outerProduct1x4(a.w, b, c[3 * NRepeat]);
+    }
+
+    // 1 in specialized template represent pack size. fp32 = 1
+    template <index_t PACKSIZE>
+    __device__ void outerProduct(const typename vector_type<float, 2>::MemoryType& a,
+                                 const typename vector_type<float, 4>::MemoryType& b,
+                                 typename vector_type<float, 4>::MemoryType* c) const
+    {
+        static_assert(1 == PACKSIZE, "only packsize of 1 is supported with float datatype!");
+        constexpr index_t NRepeat = 2;
+        outerProduct1x4(a.x, b, c[0 * NRepeat]);
+        outerProduct1x4(a.y, b, c[1 * NRepeat]);
+    }
+
+    // 1 in specialized template represent pack size. fp32 = 1
+    template <index_t PACKSIZE>
+    __device__ void outerProduct(const typename vector_type<float, 4>::MemoryType& a,
+                                 const typename vector_type<float, 2>::MemoryType& b,
+                                 typename vector_type<float, 2>::MemoryType* c) const
+    {
+        static_assert(1 == PACKSIZE, "only packsize of 1 is supported with float datatype!");
+        constexpr index_t NRepeat = 2;
+        outerProduct1x2(a.x, b, c[0 * NRepeat]);
+        outerProduct1x2(a.y, b, c[1 * NRepeat]);
+        outerProduct1x2(a.z, b, c[2 * NRepeat]);
+        outerProduct1x2(a.w, b, c[3 * NRepeat]);
+    }
+
+    // 1 in specialized template represent pack size. fp32 = 1
+    template <index_t PACKSIZE>
+    __device__ void outerProduct(const typename vector_type<float, 2>::MemoryType& a,
+                                 const typename vector_type<float, 2>::MemoryType& b,
+                                 typename vector_type<float, 2>::MemoryType* c) const
+    {
+        static_assert(1 == PACKSIZE, "only packsize of 1 is supported with float datatype!");
+        constexpr index_t NRepeat = 2;
+        outerProduct1x2(a.x, b, c[0 * NRepeat]);
+        outerProduct1x2(a.y, b, c[1 * NRepeat]);
+    }
+
+    // PACKSIZE for fp16 could be 4 or 2
+    template <index_t PACKSIZE>
+    __device__ void outerProduct(
+        const typename vector_type<typename vector_type<half, PACKSIZE>::MemoryType, 4>::MemoryType& a,
+        const typename vector_type<typename vector_type<half, PACKSIZE>::MemoryType, 4>::MemoryType& b,
+        typename vector_type<float, 4>::MemoryType* c) const
+    {
+        static_assert(2 == PACKSIZE || 4 == PACKSIZE,
+                      "only packsize of 2,4 is supported with float datatype!");
+        constexpr index_t NRepeat = 2;
+
+        const typename vector_type<half, PACKSIZE>::MemoryType* reg_a =
+            reinterpret_cast<const typename vector_type<half, PACKSIZE>::MemoryType*>(&a);
+
+        outerProduct1x4Half<PACKSIZE>(reg_a[0], b, c[0 * NRepeat]);
+        outerProduct1x4Half<PACKSIZE>(reg_a[1], b, c[1 * NRepeat]);
+        outerProduct1x4Half<PACKSIZE>(reg_a[2], b, c[2 * NRepeat]);
+        outerProduct1x4Half<PACKSIZE>(reg_a[3], b, c[3 * NRepeat]);
+    }
+
+    // PACKSIZE for fp16 could be 4 or 2
+    template <index_t PACKSIZE>
+    __device__ void outerProduct(
+        const typename vector_type<typename vector_type<half, PACKSIZE>::MemoryType, 2>::MemoryType& a,
+        const typename vector_type<typename vector_type<half, PACKSIZE>::MemoryType, 2>::MemoryType& b,
+        typename vector_type<float, 2>::MemoryType* c) const
+    {
+        static_assert(2 == PACKSIZE || 4 == PACKSIZE,
+                      "only packsize of 2,4 is supported with float datatype!");
+        constexpr index_t NRepeat = 2;
+
+        const typename vector_type<half, PACKSIZE>::MemoryType* reg_a =
+            reinterpret_cast<const typename vector_type<half, PACKSIZE>::MemoryType*>(&a);
+
+        outerProduct1x2Half<PACKSIZE>(reg_a[0], b, c[0 * NRepeat]);
+        outerProduct1x2Half<PACKSIZE>(reg_a[1], b, c[1 * NRepeat]);
+    }
+
+    // PACKSIZE for fp16 could be 4 or 2
+    template <index_t PACKSIZE>
+    __device__ void outerProduct1x4Half(
+        const typename vector_type<half, PACKSIZE>::MemoryType& a,
+        const typename vector_type<typename vector_type<half, PACKSIZE>::MemoryType, 4>::MemoryType& b,
+        vector_type<float, 4>::MemoryType& c) const
+    {
+        static_if<PACKSIZE == 4>{}([&](auto) {
+            outerProduct1x4dot2TwoTimes(reinterpret_cast<const half2*>(&a),
+                                        reinterpret_cast<const half2*>(&b),
+                                        reinterpret_cast<float*>(&c));
+        }).Else([&](auto) {
+            static_if<PACKSIZE == 2>{}([&](auto) {
+                outerProduct1x4dot2(reinterpret_cast<const half2*>(&a),
+                                    reinterpret_cast<const half2*>(&b),
+                                    reinterpret_cast<float*>(&c));
+            }).Else([&](auto fwd) {
+                // not implemented
+                static_assert(fwd(false), "wrong! packsize = 1 for fp16 is insensible.");
+            });
+        });
+    }
+
+    // PACKSIZE for fp16 could be 4 or 2
+    template <index_t PACKSIZE>
+    __device__ void outerProduct1x2Half(
+        const typename vector_type<half, PACKSIZE>::MemoryType& a,
+        const typename vector_type<typename vector_type<half, PACKSIZE>::MemoryType, 2>::MemoryType& b,
+        vector_type<float, 2>::MemoryType& c) const
+    {
+        static_if<PACKSIZE == 4>{}([&](auto) {
+            outerProduct1x2dot2TwoTimes(reinterpret_cast<const half2*>(&a),
+                                        reinterpret_cast<const half2*>(&b),
+                                        reinterpret_cast<float*>(&c));
+        }).Else([&](auto) {
+            static_if<PACKSIZE == 2>{}([&](auto) {
+                outerProduct1x2dot2(reinterpret_cast<const half2*>(&a),
+                                    reinterpret_cast<const half2*>(&b),
+                                    reinterpret_cast<float*>(&c));
+            }).Else([&](auto fwd) {
+                // not implemented
+                static_assert(fwd(false), "wrong! packsize = 1 for fp16 is insensible.");
+            });
+        });
+    }
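Whatever the pack size, the math these overloads accumulate is the same rank-1 update c[m][n] += a[m] * b[n]; PACKSIZE only decides whether each a[m]/b[n] is a scalar float or a 2/4-wide half vector that gets reduced by the dot2 helpers. A scalar host-side reference of the 4x4 case, as a sketch only:

    #include <cstdio>

    int main()
    {
        float a[4] = {1, 2, 3, 4};
        float b[4] = {10, 20, 30, 40};
        float c[4][4] = {};

        // rank-1 update: every (m, n) accumulator gets a[m] * b[n]
        for(int m = 0; m < 4; ++m)
            for(int n = 0; n < 4; ++n)
                c[m][n] += a[m] * b[n];

        std::printf("c[3][3] = %.0f\n", c[3][3]); // 160
    }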
     template <class FloatA, class FloatB, class FloatC>
     __device__ void Run_amd_asm(const FloatA* __restrict__ p_a_block,
                                 const FloatB* __restrict__ p_b_block,
...
@@ -131,91 +277,60 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
         // thread A, B for GEMM
         constexpr auto a_thread_mtx =
-            make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<MPerThread>{});
+            make_ConstantMatrixDescriptor_packed(Number<KPerThreadLoop>{}, Number<MPerThread>{});

         constexpr auto b_thread_mtx =
-            make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<NPerThread>{});
+            make_ConstantMatrixDescriptor_packed(Number<KPerThreadLoop>{}, Number<NPerThread>{});

-        FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
-        FloatB p_b_thread[b_thread_mtx.GetElementSpace()];

         constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
         constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;

-        static_assert(MPerThreadSubC == 4 && NPerThreadSubC == 4 && KPerThreadLoop == 1 &&
-                          MPerThread == 8 && NPerThread == 8,
-                      "Run_amd_asm cannot deal with this GEMM shape yet");
+        static_assert((MPerThreadSubC == 4 || MPerThreadSubC == 2) &&
+                          (NPerThreadSubC == 4 || NPerThreadSubC == 2) && KPerThreadLoop == 1,
+                      "M/NPerThreadSubC wrong!");

-        static_assert(DataPerReadA == 4 && DataPerReadB == 4, "Run_amd_asm only do float4 read");
+        static_assert(MPerThread % 4 == 0 && NPerThread % 4 == 0, "M/NPerThread % 4 != 0");

-        // If A and B datatype is float
-        static_if<std::is_same<FloatA, float>::value &&
-                  std::is_same<FloatB, float>::value>{}([&](auto) {
-            using Float4 = vector_type<float, 4>::MemoryType;
-
-            Float4* reg_a = reinterpret_cast<Float4*>(p_a_thread);
-            Float4* reg_b = reinterpret_cast<Float4*>(p_b_thread);
-            Float4* reg_c = reinterpret_cast<Float4*>(p_c_thread);
-
-            reg_a[0] = *reinterpret_cast<const Float4*>(&p_a_block[mMyThreadOffsetA]);
-            reg_b[0] = *reinterpret_cast<const Float4*>(&p_b_block[mMyThreadOffsetB]);
-            reg_b[1] = *reinterpret_cast<const Float4*>(&p_b_block[mMyThreadOffsetB + NPerLevel1Cluster]);
-            reg_a[1] = *reinterpret_cast<const Float4*>(&p_a_block[mMyThreadOffsetA + MPerLevel1Cluster]);
-            outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
-            outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);
-#pragma unroll
-            for(index_t k = 1; k < K; ++k)
-            {
-                reg_a[0] = *reinterpret_cast<const Float4*>(&p_a_block[mMyThreadOffsetA + k * M]);
-                outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
-                reg_b[0] = *reinterpret_cast<const Float4*>(&p_b_block[mMyThreadOffsetB + k * N]);
-                outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
-                reg_b[1] = *reinterpret_cast<const Float4*>(
-                    &p_b_block[mMyThreadOffsetB + k * N + NPerLevel1Cluster]);
-                reg_a[1] = *reinterpret_cast<const Float4*>(
-                    &p_a_block[mMyThreadOffsetA + k * M + MPerLevel1Cluster]);
-                outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
-                outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);
-            }
-            outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
-            outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
-        }).Else([&](auto) {
-            // If A and B datatype is bfloat16/float16
-            using Half4x4 = vector_type<vector_type<half, 4>, 4>;
-            using Float4  = vector_type<float, 4>::MemoryType;
-
-            Half4x4* reg_a = reinterpret_cast<Half4x4*>(p_a_thread);
-            Half4x4* reg_b = reinterpret_cast<Half4x4*>(p_b_thread);
-            Float4* reg_c  = reinterpret_cast<Float4*>(p_c_thread);
-
-            reg_a[0] = *reinterpret_cast<const Half4x4*>(&p_a_block[mMyThreadOffsetA]);
-            reg_b[0] = *reinterpret_cast<const Half4x4*>(&p_b_block[mMyThreadOffsetB]);
-            reg_b[1] = *reinterpret_cast<const Half4x4*>(&p_b_block[mMyThreadOffsetB + NPerLevel1Cluster]);
-            reg_a[1] = *reinterpret_cast<const Half4x4*>(&p_a_block[mMyThreadOffsetA + MPerLevel1Cluster]);
-            outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
-            outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);
+        constexpr index_t MRepeat = M / (MPerThreadSubC * MLevel0Cluster * MLevel1Cluster);
+        constexpr index_t NRepeat = N / (NPerThreadSubC * NLevel0Cluster * NLevel1Cluster);
+
+        static_assert(MRepeat == 2 && NRepeat == 2, "M/NRepeat != 2");
+
+        using typeA = typename vector_type<FloatA, MPerThreadSubC>::MemoryType;
+        using typeB = typename vector_type<FloatB, NPerThreadSubC>::MemoryType;
+        using typeC = typename vector_type<FloatC, NPerThreadSubC>::MemoryType;
+
+        FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
+        FloatB p_b_thread[b_thread_mtx.GetElementSpace()];
+
+        typeA* reg_a = reinterpret_cast<typeA*>(p_a_thread);
+        typeB* reg_b = reinterpret_cast<typeB*>(p_b_thread);
+        typeC* reg_c = reinterpret_cast<typeC*>(p_c_thread);
+
+        reg_a[0] = *reinterpret_cast<const typeA*>(&p_a_block[mMyThreadOffsetA]);
+        reg_b[0] = *reinterpret_cast<const typeB*>(&p_b_block[mMyThreadOffsetB]);
+        reg_b[1] = *reinterpret_cast<const typeB*>(&p_b_block[(mMyThreadOffsetB + NPerLevel1Cluster)]);
+        reg_a[1] = *reinterpret_cast<const typeA*>(&p_a_block[(mMyThreadOffsetA + MPerLevel1Cluster)]);
+        outerProduct<EPack>(reg_a[0], reg_b[0], &reg_c[0]);
+        outerProduct<EPack>(reg_a[0], reg_b[1], &reg_c[1]);
 #pragma unroll
         for(index_t k = 1; k < K; ++k)
         {
-            reg_a[0] = *reinterpret_cast<const Half4x4*>(&p_a_block[mMyThreadOffsetA + k * M]);
-            outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
-            reg_b[0] = *reinterpret_cast<const Half4x4*>(&p_b_block[mMyThreadOffsetB + k * N]);
-            outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
-            reg_b[1] = *reinterpret_cast<const Half4x4*>(
-                &p_b_block[mMyThreadOffsetB + k * N + NPerLevel1Cluster]);
-            reg_a[1] = *reinterpret_cast<const Half4x4*>(
-                &p_a_block[mMyThreadOffsetA + k * M + MPerLevel1Cluster]);
-            outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
-            outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);
+            reg_a[0] = *reinterpret_cast<const typeA*>(&p_a_block[(mMyThreadOffsetA + k * M)]);
+            outerProduct<EPack>(reg_a[1], reg_b[0], &reg_c[NRepeat * MPerThreadSubC]);
+            reg_b[0] = *reinterpret_cast<const typeB*>(&p_b_block[(mMyThreadOffsetB + k * N)]);
+            outerProduct<EPack>(reg_a[1], reg_b[1], &reg_c[NRepeat * MPerThreadSubC + 1]);
+            reg_b[1] = *reinterpret_cast<const typeB*>(
+                &p_b_block[(mMyThreadOffsetB + k * N + NPerLevel1Cluster)]);
+            reg_a[1] = *reinterpret_cast<const typeA*>(
+                &p_a_block[(mMyThreadOffsetA + k * M + MPerLevel1Cluster)]);
+            outerProduct<EPack>(reg_a[0], reg_b[0], &reg_c[0]);
+            outerProduct<EPack>(reg_a[0], reg_b[1], &reg_c[1]);
         }
-        outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
-        outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
-        });
+        outerProduct<EPack>(reg_a[1], reg_b[0], &reg_c[NRepeat * MPerThreadSubC]);
+        outerProduct<EPack>(reg_a[1], reg_b[1], &reg_c[NRepeat * MPerThreadSubC + 1]);
     }
 #endif
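The indices passed to outerProduct<EPack> in the hunk above follow the same accumulator layout as the old outerProduct4x4 calls: the per-thread 8x8 C tile is stored as 16 vector registers, and sub-tile (m_repeat, n_repeat), row r, lives at reg_c[m_repeat * MPerThreadSubC * NRepeat + r * NRepeat + n_repeat]. A host-side sketch that prints this mapping for the shape asserted above (the constants are those from the static_asserts; everything else is illustrative):

    #include <cstdio>

    int main()
    {
        constexpr int MPerThreadSubC = 4, NRepeat = 2, MRepeat = 2;

        for(int m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
            for(int n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
                for(int r = 0; r < MPerThreadSubC; ++r)
                {
                    int idx = m_repeat * MPerThreadSubC * NRepeat + r * NRepeat + n_repeat;
                    // reproduces the 0,2,4,6 / 1,3,5,7 / 8,10,12,14 / 9,11,13,15 pattern
                    std::printf("sub-tile(%d,%d) row %d -> reg_c[%d]\n", m_repeat, n_repeat, r, idx);
                }
    }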
...
@@ -250,8 +365,8 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
         constexpr auto b_thread_sub_mtx = make_ConstantMatrixDescriptor(
             Number<KPerThreadLoop>{}, Number<NPerThreadSubC>{}, Number<NPerThread>{});

-        FloatA p_a_thread[a_thread_mtx.GetElementSpace() * 4];
-        FloatB p_b_thread[b_thread_mtx.GetElementSpace() * 4];
+        FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
+        FloatB p_b_thread[b_thread_mtx.GetElementSpace()];

         constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
         constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
@@ -270,10 +385,10 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
             threadwise_matrix_copy(
                 a_block_mtx,
                 p_a_block +
-                    (a_block_mtx.GetOffsetFromMultiIndex(k_begin, m_repeat * MPerLevel1Cluster) +
-                     mMyThreadOffsetA) * 4,
+                    a_block_mtx.GetOffsetFromMultiIndex(k_begin, m_repeat * MPerLevel1Cluster) +
+                    mMyThreadOffsetA,
                 a_thread_mtx,
-                p_a_thread + (a_thread_mtx.GetOffsetFromMultiIndex(0, m_repeat * MPerThreadSubC)) * 4,
+                p_a_thread + a_thread_mtx.GetOffsetFromMultiIndex(0, m_repeat * MPerThreadSubC),
                 a_thread_sub_mtx.GetLengths(),
                 Number<DataPerReadA>{});
         }
@@ -285,10 +400,10 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
             threadwise_matrix_copy(
                 b_block_mtx,
                 p_b_block +
-                    (b_block_mtx.GetOffsetFromMultiIndex(k_begin, n_repeat * NPerLevel1Cluster) +
-                     mMyThreadOffsetB) * 4,
+                    b_block_mtx.GetOffsetFromMultiIndex(k_begin, n_repeat * NPerLevel1Cluster) +
+                    mMyThreadOffsetB,
                 b_thread_mtx,
-                p_b_thread + (b_thread_mtx.GetOffsetFromMultiIndex(0, n_repeat * NPerThreadSubC)) * 4,
+                p_b_thread + b_thread_mtx.GetOffsetFromMultiIndex(0, n_repeat * NPerThreadSubC),
                 b_thread_sub_mtx.GetLengths(),
                 Number<DataPerReadB>{});
         }
...
@@ -306,156 +421,21 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
             }
         }
     }

-    template <class FloatA, class FloatB, class FloatC>
-    __device__ void RunRegisterDoubleBuffer_source(FloatA* const p_a_block,
-                                                   FloatB* const p_b_block,
-                                                   FloatC* p_c_thread) const
-    {
-        constexpr auto True  = integral_constant<bool, true>{};
-        constexpr auto False = integral_constant<bool, false>{};
-
-        constexpr auto a_block_mtx  = BlockMatrixA{};
-        constexpr auto b_block_mtx  = BlockMatrixB{};
-        constexpr auto c_thread_mtx = ThreadMatrixC{};
-
-        constexpr index_t K          = a_block_mtx.NRow();
-        constexpr index_t MPerThread = c_thread_mtx.NRow();
-        constexpr index_t NPerThread = c_thread_mtx.NCol();
-
-        // thread A, B for GEMM
-        constexpr auto a_thread_mtx =
-            make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<MPerThread>{});
-
-        constexpr auto b_thread_mtx =
-            make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<NPerThread>{});
-
-        // thread A-sub, B-sub for copy
-        constexpr auto a_thread_sub_mtx = make_ConstantMatrixDescriptor(
-            Number<KPerThreadLoop>{}, Number<MPerThreadSubC>{}, Number<MPerThread>{});
-
-        constexpr auto b_thread_sub_mtx = make_ConstantMatrixDescriptor(
-            Number<KPerThreadLoop>{}, Number<NPerThreadSubC>{}, Number<NPerThread>{});
-
-        // register
-        FloatA p_a_thread_0[a_thread_mtx.GetElementSpace()];
-        FloatB p_b_thread_0[b_thread_mtx.GetElementSpace()];
-
-        FloatA p_a_thread_1[a_thread_mtx.GetElementSpace()];
-        FloatB p_b_thread_1[b_thread_mtx.GetElementSpace()];
-
-        constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
-        constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
-
-        constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
-        constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
-
-        // preload A, B
-#pragma unroll
-        for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
-        {
-            // copy A-sub to form A
-            threadwise_matrix_copy(a_block_mtx,
-                                   p_a_block + mMyThreadOffsetA + m_repeat * MPerLevel1Cluster,
-                                   a_thread_sub_mtx,
-                                   p_a_thread_0 + m_repeat * MPerThreadSubC,
-                                   a_thread_sub_mtx.GetLengths(),
-                                   Number<DataPerReadA>{});
-        }
-
-#pragma unroll
-        for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
-        {
-            // copy B-sub to form B
-            threadwise_matrix_copy(b_block_mtx,
-                                   p_b_block + mMyThreadOffsetB + n_repeat * NPerLevel1Cluster,
-                                   b_thread_sub_mtx,
-                                   p_b_thread_0 + n_repeat * NPerThreadSubC,
-                                   b_thread_sub_mtx.GetLengths(),
-                                   Number<DataPerReadB>{});
-        }
-
-        bool even_loop = true;
-
-#pragma unroll
-        for(index_t k_begin = 0; k_begin + KPerThreadLoop < K;
-            k_begin += KPerThreadLoop, even_loop = !even_loop)
-        { // loop over k
-            FloatA* p_a_thread_now = even_loop ? p_a_thread_0 : p_a_thread_1;
-            FloatB* p_b_thread_now = even_loop ? p_b_thread_0 : p_b_thread_1;
-
-            FloatA* p_a_thread_next = even_loop ? p_a_thread_1 : p_a_thread_0;
-            FloatB* p_b_thread_next = even_loop ? p_b_thread_1 : p_b_thread_0;
-
-            // preload next A, B
-#pragma unroll
-            for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
-            {
-                // copy A-sub to form A
-                threadwise_matrix_copy(a_block_mtx,
-                                       p_a_block + mMyThreadOffsetA +
-                                           (k_begin + 1) * a_block_mtx.RowStride() +
-                                           m_repeat * MPerLevel1Cluster,
-                                       a_thread_sub_mtx,
-                                       p_a_thread_next + m_repeat * MPerThreadSubC,
-                                       a_thread_sub_mtx.GetLengths(),
-                                       Number<DataPerReadA>{});
-            }
-
-#pragma unroll
-            for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
-            {
-                // copy B-sub to form B
-                threadwise_matrix_copy(b_block_mtx,
-                                       p_b_block + mMyThreadOffsetB +
-                                           (k_begin + 1) * b_block_mtx.RowStride() +
-                                           n_repeat * NPerLevel1Cluster,
-                                       b_thread_sub_mtx,
-                                       p_b_thread_next + n_repeat * NPerThreadSubC,
-                                       b_thread_sub_mtx.GetLengths(),
-                                       Number<DataPerReadB>{});
-            }
-
-            // C = A * B
-            threadwise_gemm(a_thread_mtx,
-                            True,
-                            p_a_thread_now,
-                            b_thread_mtx,
-                            False,
-                            p_b_thread_now,
-                            c_thread_mtx,
-                            False,
-                            p_c_thread);
-        }
-
-        // last loop
-        {
-            FloatA* p_a_thread_now = even_loop ? p_a_thread_0 : p_a_thread_1;
-            FloatB* p_b_thread_now = even_loop ? p_b_thread_0 : p_b_thread_1;
-
-            // C = A * B
-            threadwise_gemm(a_thread_mtx,
-                            True,
-                            p_a_thread_now,
-                            b_thread_mtx,
-                            False,
-                            p_b_thread_now,
-                            c_thread_mtx,
-                            False,
-                            p_c_thread);
-        }
-    }
-
     template <class FloatA, class FloatB, class FloatC>
     __device__ void Run(const FloatA* __restrict__ p_a_block,
                         const FloatB* __restrict__ p_b_block,
                         FloatC* __restrict__ p_c_thread) const
     {
 #if CK_USE_AMD_INLINE_ASM && CK_BLOCKWISE_GEMM_USE_AMD_INLINE_ASM
+        static_if<std::is_same<FloatA, ushort>::value && std::is_same<FloatB, ushort>::value>{}(
+            // The assembly path doesn't support bfloat16 using asm instructions
+            [&](auto) { Run_source(p_a_block, p_b_block, p_c_thread); })
+#if MIOPEN_USE_BFP16 == 1
+            .Else([&](auto) {
+                // If A and B datatype is bfloat16/float16
+                Run_amd_asm(p_a_block, p_b_block, p_c_thread);
+            });
+#else
         Run_source(p_a_block, p_b_block, p_c_thread);
-#else
-        Run_amd_asm(p_a_block, p_b_block, p_c_thread);
 #endif
+#else
+        Run_source(p_a_block, p_b_block, p_c_thread);
+#endif // CK_USE_AMD_INLINE_ASM
     }
 };
...
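The Run() dispatch above selects a path per element type at compile time through ck's static_if helper. As a sketch only, the same idea written with if constexpr (which static_if predates); run_asm/run_source here are placeholder functions, not the kernel's real entry points, and bfloat16 is represented as ushort as in this code base:

    #include <cstdio>
    #include <type_traits>

    void run_asm()    { std::printf("asm path\n"); }
    void run_source() { std::printf("generic path\n"); }

    template <class FloatA, class FloatB>
    void run()
    {
        // the asm path does not support bfloat16 (carried as ushort), so fall back to source
        if constexpr(std::is_same<FloatA, unsigned short>::value &&
                     std::is_same<FloatB, unsigned short>::value)
            run_source();
        else
            run_asm();
    }

    int main()
    {
        run<float, float>();                   // asm path
        run<unsigned short, unsigned short>(); // generic path
    }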
composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp (new file, 0 → 100644)
#ifndef CK_BLOCKWISE_GEMM_XDLOPS_HPP
#define CK_BLOCKWISE_GEMM_XDLOPS_HPP

#include "common_header.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "threadwise_gemm.hpp"

namespace ck {

template <class input_type>
struct mfma_info
{
};

template <>
struct mfma_info<float>
{
    static constexpr index_t group_size      = 4;
    static constexpr index_t num_groups_blk  = 4;
    static constexpr index_t num_blks_wave   = 2;
    static constexpr index_t num_regs_blk    = group_size * num_groups_blk;
    static constexpr index_t num_regs_xdlops = num_regs_blk * num_blks_wave;
    static constexpr index_t num_threads_blk = 32;
    static constexpr index_t m               = 32;
    static constexpr index_t n               = 32;
    static constexpr index_t k               = 1;
    static constexpr index_t wave_size       = 64;
};

template <>
struct mfma_info<half>
{
    static const index_t group_size      = 4;
    static const index_t num_groups_blk  = 4;
    static const index_t num_blks_wave   = 2;
    static const index_t num_regs_blk    = group_size * num_groups_blk;
    static const index_t num_regs_xdlops = num_regs_blk * num_blks_wave;
    static const index_t num_threads_blk = 32;
    static const index_t m               = 32;
    static const index_t n               = 32;
    static const index_t k               = 4;
    static const index_t wave_size       = 64;
};

template <>
struct mfma_info<ushort>
{
    static const index_t group_size      = 4;
    static const index_t num_groups_blk  = 4;
    static const index_t num_blks_wave   = 2;
    static const index_t num_regs_blk    = group_size * num_groups_blk;
    static const index_t num_regs_xdlops = num_regs_blk * num_blks_wave;
    static const index_t num_threads_blk = 32;
    static const index_t m               = 32;
    static const index_t n               = 32;
    static const index_t k               = 2;
    static const index_t wave_size       = 64;
};
// emulate xdlops
template <index_t M,
          index_t N,
          index_t K,
          index_t MPerWave,
          index_t GemmDataPerReadA,
          index_t GemmDataPerReadB,
          class mfma_info,
          class FloatA,
          class FloatB,
          class FloatC>
__device__ void WaveWiseGemmMx64(const FloatA* const __restrict__ p_a_wave,
                                 const FloatB* const __restrict__ p_b_wave,
                                 FloatC* const __restrict__ p_c_thread)
{
    static_assert(GemmDataPerReadA == 1 && GemmDataPerReadB == 1, "GemmDataPerReadA/B != 1");

    const index_t laneId = get_thread_local_1d_id() % mfma_info::wave_size;
    const index_t blk_id = laneId / mfma_info::num_threads_blk;
    const index_t lane_b = laneId % mfma_info::num_threads_blk;

    for(index_t k = 0; k < K; ++k)
    {
        for(index_t b = 0; b < MPerWave / mfma_info::m; ++b)
        {
            index_t a_off = k * M + b * mfma_info::m;
            index_t b_off = k * N;

            // pseudo mfma
            for(index_t n = 0; n < mfma_info::num_blks_wave; ++n)
            {
                index_t output_m = mfma_info::num_regs_blk;
                for(index_t m = 0; m < output_m; ++m)
                {
                    index_t aindex = m % mfma_info::group_size + blk_id * mfma_info::group_size +
                                     m / mfma_info::group_size *
                                         (mfma_info::group_size * mfma_info::num_blks_wave) +
                                     a_off; // A is transposed
                    index_t bindex = b_off + lane_b + n * mfma_info::num_threads_blk;

                    p_c_thread[m + n * output_m + b * output_m * mfma_info::num_blks_wave] +=
                        math::inner_product_with_conversion<FloatC>{}(p_a_wave[aindex],
                                                                      p_b_wave[bindex]);
                }
            }
        }
    }
}
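WaveWiseGemmMx64 only redistributes who owns which output element; the math it emulates is a plain C = A^T * B with A stored K x M (transposed) and B stored K x N. A host reference of that accumulation, as a sketch with arbitrary fill values:

    #include <cstdio>
    #include <vector>

    int main()
    {
        const int M = 32, N = 64, K = 2;
        std::vector<float> A(K * M, 1.0f), B(K * N, 2.0f), C(M * N, 0.0f);

        for(int k = 0; k < K; ++k)
            for(int m = 0; m < M; ++m)
                for(int n = 0; n < N; ++n)
                    C[m * N + n] += A[k * M + m] * B[k * N + n]; // A is transposed: A[k][m]

        std::printf("C[0] = %.0f\n", C[0]); // K * 1 * 2 = 4
    }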
#if 0
template <index_t M,
index_t N,
index_t K,
index_t MPerWave,
index_t GemmDataPerReadA,
index_t GemmDataPerReadB,
class mfma_info>
__device__ void WaveWiseGemmMx64_xdlops(const float* const __restrict__ p_a_wave,
const float* const __restrict__ p_b_wave,
float* const __restrict__ p_c_thread)
{
static_assert(MPerWave == 32 || MPerWave == 64, "only support MPerWave = 32/64");
static_assert(GemmDataPerReadA == 1 && GemmDataPerReadB == 1, "GemmDataPerReadA/B != 1");
const index_t laneId = get_thread_local_1d_id() % mfma_info::wave_size;
for(index_t k = 0; k < K; k += mfma_info::k)
{
float reg_a = p_a_wave[k * M + laneId];
float reg_b = p_b_wave[k * N + laneId];
float32_t* reg_c = reinterpret_cast<float32_t*>(p_c_thread);
gcnasm_mfma_f32_32x32x1f32<MPerWave>(reg_a, reg_b, reg_c);
}
}
template <index_t M,
index_t N,
index_t K,
index_t MPerWave,
index_t GemmDataPerReadA,
index_t GemmDataPerReadB,
class mfma_info>
__device__ void WaveWiseGemmMx64_xdlops(
const typename vector_type<half, 4>::MemoryType* const __restrict__ p_a_wave,
const typename vector_type<half, 4>::MemoryType* const __restrict__ p_b_wave,
float* const __restrict__ p_c_thread)
{
static_assert(MPerWave == 32 || MPerWave == 64, "only support MPerWave = 32/64");
const index_t laneId = threadIdx.x % mfma_info::wave_size;
for(index_t k = 0; k < K; k += mfma_info::k / 4)
{
typename vector_type<half, 4>::MemoryType reg_a = p_a_wave[k * M + laneId];
typename vector_type<half, 4>::MemoryType reg_b = p_b_wave[k * N + laneId];
float32_t* reg_c = reinterpret_cast<float32_t*>(p_c_thread);
gcnasm_mfma_f32_32x32x4f16<MPerWave>(reg_a, reg_b, reg_c);
}
}
template <index_t M,
index_t N,
index_t K,
index_t MPerWave,
index_t GemmDataPerReadA,
index_t GemmDataPerReadB,
class mfma_info>
__device__ void WaveWiseGemmMx64_xdlops(
const typename vector_type<ushort, 2>::MemoryType* const __restrict__ p_a_wave,
const typename vector_type<ushort, 2>::MemoryType* const __restrict__ p_b_wave,
float* const __restrict__ p_c_thread)
{
static_assert(MPerWave == 32 || MPerWave == 64, "only support MPerWave = 32/64");
const index_t laneId = threadIdx.x % mfma_info::wave_size;
for(index_t k = 0; k < K; k += mfma_info::k / 2)
{
typename vector_type<ushort, 2>::MemoryType reg_a = p_a_wave[k * M + laneId];
typename vector_type<ushort, 2>::MemoryType reg_b = p_b_wave[k * N + laneId];
float32_t* reg_c = reinterpret_cast<float32_t*>(p_c_thread);
gcnasm_mfma_f32_32x32x2bf16<MPerWave>(reg_a, reg_b, reg_c);
}
}
#endif
template <index_t BlockSize,
          class BlockMatrixA,
          class BlockMatrixB,
          class mfma_info,
          bool EnableXdlops,
          index_t GemmMPerWave,
          index_t GemmNPerWave,
          index_t GemmMWaves,
          index_t GemmNWaves,
          index_t GemmDataPerReadA,
          index_t GemmDataPerReadB>
struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_xdlops
{
    struct MatrixIndex
    {
        index_t row;
        index_t col;
    };

    struct OutputLayout_t
    {
        static constexpr index_t M3 = GemmMPerWave / mfma_info::m;
        static constexpr index_t M2 = mfma_info::num_groups_blk;
        static constexpr index_t M1 = mfma_info::num_blks_wave;
        static constexpr index_t M0 = mfma_info::group_size;
    };

    index_t mMyWaveOffsetA;
    index_t mMyWaveOffsetB;
    OutputLayout_t OutputLayout;

    __device__ BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_xdlops()
    {
        static_assert(BlockMatrixA::NRow() == BlockMatrixB::NRow(),
                      "wrong! K dimension not consistent\n");

        constexpr index_t M = BlockMatrixA::NCol(); // A is transposed
        constexpr index_t N = BlockMatrixB::NCol();

        static_assert(GemmNPerWave == 64, "Only support GemmNPerWave == 64 for xdlops");
        static_assert(GemmMPerWave == 32 || GemmMPerWave == 64,
                      "Only support GemmMPerWave == 32 or 64 for xdlops");

        static_assert(GemmMPerWave * GemmMWaves == M, "GemmMWaves * GemmMPerWave != M");
        static_assert(GemmNPerWave * GemmNWaves == N, "GemmNWaves * GemmNPerWave != N");

        static_assert(BlockSize == GemmMWaves * GemmNWaves * 64,
                      "BlockSize != GemmMWaves * GemmNWaves * 64\n");

        const index_t waveId   = get_thread_local_1d_id() / mfma_info::wave_size;
        const index_t waveId_m = waveId / GemmNWaves;
        const index_t waveId_n = waveId % GemmNWaves;

        mMyWaveOffsetA = waveId_m * GemmMPerWave;
        mMyWaveOffsetB = waveId_n * GemmNPerWave;
    }

    template <class FloatA, class FloatB, class FloatC>
    __device__ void Run(const FloatA* __restrict__ p_a_block,
                        const FloatB* __restrict__ p_b_block,
                        FloatC* __restrict__ p_c_thread) const
    {
        constexpr index_t M = BlockMatrixA::NCol(); // A is transposed
        constexpr index_t N = BlockMatrixB::NCol();
        constexpr index_t K = BlockMatrixA::NRow();

        // static_if<EnableXdlops>{}([&](auto) {
        //     WaveWiseGemmMx64_xdlops<M,
        //                             N,
        //                             K,
        //                             GemmMPerWave,
        //                             GemmDataPerReadA,
        //                             GemmDataPerReadB,
        //                             mfma_info>(
        //         &p_a_block[mMyWaveOffsetA], &p_b_block[mMyWaveOffsetB], p_c_thread);
        // }).Else([&](auto) {
        WaveWiseGemmMx64<M, N, K, GemmMPerWave, GemmDataPerReadA, GemmDataPerReadB, mfma_info>(
            &p_a_block[mMyWaveOffsetA], &p_b_block[mMyWaveOffsetB], p_c_thread);
        // });
    }

    __device__ static MatrixIndex GetBeginOfThreadMatrixC(index_t i)
    {
        const index_t laneId = get_thread_local_1d_id() % mfma_info::wave_size;
        const index_t waveId = get_thread_local_1d_id() / mfma_info::wave_size;

        const index_t col_i = i % mfma_info::num_blks_wave;
        const index_t col   = waveId % GemmNWaves * mfma_info::wave_size +
                            laneId % mfma_info::num_threads_blk +
                            col_i * mfma_info::num_threads_blk;

        const index_t row_i = i / mfma_info::num_blks_wave;
        const index_t row   = waveId / GemmNWaves * GemmMPerWave +
                            laneId / mfma_info::num_threads_blk * mfma_info::group_size +
                            row_i * mfma_info::num_threads_blk;

        return MatrixIndex{row, col};
    }

    __device__ constexpr auto GetThreadMatrixCDescriptor() const
    {
        constexpr index_t num_xdlops = GemmMPerWave / mfma_info::m;
        return make_ConstantMatrixDescriptor_packed(
            Number<mfma_info::num_regs_xdlops * num_xdlops>{}, Number<1>{});
    }
};

} // namespace ck
#endif
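To see where GetBeginOfThreadMatrixC places a given accumulator, the row/col formulas above can be evaluated on the host. The sketch below mirrors those formulas with the fp32 mfma_info constants; the thread id, register index and GemmNWaves/GemmMPerWave values are example inputs, not fixed by the class:

    #include <cstdio>

    int main()
    {
        const int wave_size = 64, num_threads_blk = 32, group_size = 4, num_blks_wave = 2;
        const int GemmNWaves = 2, GemmMPerWave = 32;

        const int tid    = 70;             // example thread id within the block
        const int laneId = tid % wave_size;
        const int waveId = tid / wave_size;
        const int i      = 3;              // example per-thread output register index

        const int col_i = i % num_blks_wave;
        const int col   = waveId % GemmNWaves * wave_size + laneId % num_threads_blk +
                          col_i * num_threads_blk;

        const int row_i = i / num_blks_wave;
        const int row   = waveId / GemmNWaves * GemmMPerWave +
                          laneId / num_threads_blk * group_size + row_i * num_threads_blk;

        std::printf("thread %d, reg %d -> C(row=%d, col=%d)\n", tid, i, row, col);
    }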
composable_kernel/include/tensor_operation/blockwise_generic_tensor_slice_copy.hpp
...
@@ -10,12 +10,541 @@
...
@@ -10,12 +10,541 @@
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_BLOCKWISE_GENERIC_SLICE_COPY_V1 1
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_BLOCKWISE_GENERIC_SLICE_COPY_V1 1
#endif
#endif
#define JOINTCAT(x, y) x##y
#define ASSERT_MSG_ARG1(msg, var1) JOINTCAT(msg, var1)
#define ASSERT_MSG_ARG2(msg, var1, va2) ASSERT_MSG_ARG1(JOINTCAT(msg, var1), var2)
namespace
ck
{
namespace
ck
{
// Slice a (normal or merged) tensor, and copy it into another (normal or merged) tensor
// memory layout (ordering of dimensions) can be different between src and dst.
// This functions assume each thread is reading and writing a normal (not merged) tensor,
// to simplify index calculations. To satisfy this assumption, the user need to make sure
// that, on a merged dimension that constains multiple original dimensions, the length of
// the last original dimension need to be evenly dividable by its sub-lengths. Also, the
// repeat-length on the merged dimension need to be 1. These sanity checks are performed
// in constructor of BlockwiseGenericTensorSliceCopy_v1
template
<
index_t
BlockSize
,
class
SrcDesc
,
class
DstDesc
,
class
SliceLengths
,
class
SubLengths
,
class
ThreadClusterLengths
,
class
ThreadClusterArrangeOrder
,
class
SrcDimAccessOrder
,
class
DstDimAccessOrder
,
index_t
SrcVectorAccessDim
,
index_t
DstVectorAccessDim
,
index_t
SrcDataPerAccess
,
index_t
DstDataPerAccess
>
struct
BlockwiseGenericTensorSliceCopy_v1
{
static
constexpr
index_t
nDim
=
SrcDesc
::
GetNumOfDimension
();
static
constexpr
index_t
nOriginalDimSrc
=
SrcDesc
::
GetOriginalTensorDescriptor
().
GetNumOfDimension
();
static
constexpr
index_t
nOriginalDimDst
=
DstDesc
::
GetOriginalTensorDescriptor
().
GetNumOfDimension
();
// per-thread offset
index_t
mThreadSrcOffset
;
index_t
mThreadDstOffset
;
// "mThreadSrcOriginalMultiId", "mThreadSrcPartialOffsets, "mThreadDstOriginalMultiId",
// "mThreadDstPartialOffsets" are always calculated inside constructor, and would be
// updated if slicing-window is moved. However, they will not be used if you always move
// the slicing-window along a non-merged dimension. In that case, compiler should be
// able to remove these calculation.
// TODO: make sure compiler would actually remove them in that case
// partial offset in each (merged) dimension
Array
<
index_t
,
nDim
>
mThreadSrcPartialOffsets
;
Array
<
index_t
,
nDim
>
mThreadDstPartialOffsets
;
// multi-id of original tensor
Array
<
index_t
,
nOriginalDimSrc
>
mThreadSrcOriginalMultiId
;
Array
<
index_t
,
nOriginalDimDst
>
mThreadDstOriginalMultiId
;
__device__
BlockwiseGenericTensorSliceCopy_v1
(
Array
<
index_t
,
nDim
>
src_block_data_id_begin
,
Array
<
index_t
,
nDim
>
dst_block_data_id_begin
)
{
// check NDim consistency
static_assert
(
nDim
==
SrcDesc
::
GetNumOfDimension
()
&&
nDim
==
DstDesc
::
GetNumOfDimension
()
&&
nDim
                          == SliceLengths::GetSize() && nDim == SubLengths::GetSize() &&
                          nDim == ThreadClusterLengths::GetSize() &&
                          nDim == ThreadClusterArrangeOrder::GetSize() &&
                          nDim == SrcDimAccessOrder::GetSize() &&
                          nDim == DstDimAccessOrder::GetSize(),
                      "wrong");

        // check thread arrange order and read/write access order are valid
        static_assert(is_valid_sequence_map<ThreadClusterArrangeOrder>::value &&
                          is_valid_sequence_map<SrcDimAccessOrder>::value &&
                          is_valid_sequence_map<DstDimAccessOrder>::value,
                      "wrong!");

        // thread cluster
        constexpr auto thread_cluster_desc = make_ConstantTensorDescriptor_packed(
            ThreadClusterLengths::ReorderGivenNew2Old(ThreadClusterArrangeOrder{}));

        // BlockSize
        static_assert(BlockSize == thread_cluster_desc.GetElementSize(), "wrong! BlockSize");

        // divide work
        constexpr auto data_per_cluster_per_dims = SubLengths{} * ThreadClusterLengths{};

        static_for<0, nDim, 1>{}([&](auto IDim) {
            static_assert(SliceLengths::Get(IDim) % data_per_cluster_per_dims.Get(IDim) == 0,
                          "wrong! cannot evenly divide sliced tensor into cluster");
        });

        constexpr auto repeat_lengths = SliceLengths{} / data_per_cluster_per_dims;

        // additional check for merged dimension
        static_for<0, nDim, 1>{}([&](auto IDim_) {
            // src
            static_if<SrcDesc::ContainMultipleOriginalDimensions(IDim_)>{}([&](auto) {
                constexpr auto IDim = decltype(IDim_){};

                // On a merged dimension that contains multiple original dimensions, the length
                // of the last original dimension needs to be evenly divisible by its sub-length,
                // so each thread is effectively reading a normal (not merged) tensor.
                constexpr auto idim_last_original_src =
                    SrcDesc::GetContainedOriginalDimensions(IDim).Back();

                static_assert(
                    SrcDesc::GetOriginalTensorDescriptor().GetLength(idim_last_original_src) %
                            SubLengths::Get(IDim) ==
                        0,
                    "wrong!");

                // merged dimension should have repeat_lengths = 1
                static_assert(repeat_lengths[IDim] == 1,
                              "wrong! repeat_lengths should be 1 on merged dimension");
            });

            // dst
            static_if<DstDesc::ContainMultipleOriginalDimensions(IDim_)>{}([&](auto) {
                constexpr auto IDim = decltype(IDim_){};

                // On a merged dimension that contains multiple original dimensions, the length
                // of the last original dimension needs to be evenly divisible by its sub-length,
                // so each thread is effectively reading a normal (not merged) tensor.
                constexpr auto idim_last_original_dst =
                    DstDesc::GetContainedOriginalDimensions(IDim).Back();

                static_assert(
                    DstDesc::GetOriginalTensorDescriptor().GetLength(idim_last_original_dst) %
                            SubLengths::Get(IDim) ==
                        0,
                    "wrong!");

                // merged dimension should have repeat_lengths = 1
                static_assert(repeat_lengths[IDim] == 1,
                              "wrong! repeat_lengths should be 1 on merged dimension");
            });
        });

        // calculate mThreadSrcOffset, mThreadDstOffset
        const auto thread_cluster_id =
            thread_cluster_desc.GetMultiIndexFrom1dIndex(get_thread_local_1d_id());

        const auto data_cluster_id =
            reorder_array_given_old2new(thread_cluster_id, ThreadClusterArrangeOrder{});

        const auto thread_data_id_begin = data_cluster_id * SubLengths{};

        // original multi-id
        mThreadSrcOriginalMultiId = SrcDesc::GetOriginalMultiIndexFromMultiIndex(
            src_block_data_id_begin + thread_data_id_begin);

        mThreadDstOriginalMultiId = DstDesc::GetOriginalMultiIndexFromMultiIndex(
            dst_block_data_id_begin + thread_data_id_begin);

        // partial offset on each dimension
        static_for<0, nDim, 1>{}([&](auto IDim) {
            constexpr auto src_partial_original_dims =
                SrcDesc::GetContainedOriginalDimensions(IDim);

            constexpr auto src_partial_original_desc =
                SrcDesc::GetOriginalTensorDescriptor().Extract(src_partial_original_dims);

            mThreadSrcPartialOffsets(IDim) = src_partial_original_desc.GetOffsetFromMultiIndex(
                extract_array(mThreadSrcOriginalMultiId, src_partial_original_dims));
        });

        static_for<0, nDim, 1>{}([&](auto IDim) {
            constexpr auto dst_partial_original_dims =
                DstDesc::GetContainedOriginalDimensions(IDim);

            constexpr auto dst_partial_original_desc =
                DstDesc::GetOriginalTensorDescriptor().Extract(dst_partial_original_dims);

            mThreadDstPartialOffsets(IDim) = dst_partial_original_desc.GetOffsetFromMultiIndex(
                extract_array(mThreadDstOriginalMultiId, dst_partial_original_dims));
        });

        // complete offset
        mThreadSrcOffset = accumulate_on_array(
            mThreadSrcPartialOffsets, math::plus<index_t>{}, static_cast<index_t>(0));

        mThreadDstOffset = accumulate_on_array(
            mThreadDstPartialOffsets, math::plus<index_t>{}, static_cast<index_t>(0));
    }
    __device__ static constexpr auto GetRegisterBufferDescriptor()
    {
        constexpr auto repeat_lengths = SliceLengths{} / (SubLengths{} * ThreadClusterLengths{});

        return make_ConstantTensorDescriptor_packed(SubLengths{} * repeat_lengths);
    }

    __device__ static constexpr index_t GetRegisterBufferSize()
    {
        return GetRegisterBufferDescriptor().GetElementSpace();
    }
    template <class TData>
    __device__ void RunLoadRegisterBuffer(const TData* __restrict__ p_src,
                                          TData* __restrict__ p_buffer) const
    {
        constexpr auto thread_sub_tensor_lengths = SubLengths{};

        constexpr auto data_per_cluster_per_dims =
            thread_sub_tensor_lengths * ThreadClusterLengths{};

        constexpr auto repeat_lengths = SliceLengths{} / (SubLengths{} * ThreadClusterLengths{});

        constexpr auto thread_buffer_desc = GetRegisterBufferDescriptor();

#if CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_BLOCKWISE_GENERIC_SLICE_COPY_V1
        static_ford<decltype(repeat_lengths)>{}([&](auto repeat_id) {
            constexpr auto src_thread_data_id_begin = repeat_id * data_per_cluster_per_dims;

            constexpr auto buffer_data_id_begin = repeat_id * thread_sub_tensor_lengths;

            constexpr index_t src_offset =
                SrcDesc::GetOffsetFromMultiIndex(src_thread_data_id_begin);

            constexpr index_t buffer_offset =
                thread_buffer_desc.GetOffsetFromMultiIndex(buffer_data_id_begin);
#else
        ford<decltype(repeat_lengths)>{}([&](auto repeat_id) {
            const auto src_thread_data_id_begin = repeat_id * data_per_cluster_per_dims;

            const auto buffer_data_id_begin = repeat_id * thread_sub_tensor_lengths;

            const index_t src_offset = SrcDesc::GetOffsetFromMultiIndex(src_thread_data_id_begin);

            const index_t buffer_offset =
                thread_buffer_desc.GetOffsetFromMultiIndex(buffer_data_id_begin);
#endif

            // The origin of the per-thread window is positioned at the point where the
            // multi-index of the SrcDesc (which might be a merged tensor) is all-zero. This
            // threadwise slice copy assumes each thread is copying a normal (not merged) tensor.
            // To satisfy this assumption, the user needs to make sure that, on a merged dimension
            // that contains multiple original dimensions, the length of the last original
            // dimension is evenly divisible by its sub-length. Also, the repeat-length on the
            // merged dimension needs to be 1. These sanity checks are performed in the
            // constructor of BlockwiseGenericTensorSliceCopy_v1.
            ThreadwiseGenericTensorSliceCopy_v1r2<SrcDesc,
                                                  decltype(thread_buffer_desc),
                                                  SubLengths,
                                                  SrcDimAccessOrder,
                                                  SrcVectorAccessDim,
                                                  SrcDataPerAccess,
                                                  1>(make_zero_array<index_t, nDim>(),
                                                     make_zero_array<index_t, nDim>())
                .Run(p_src + src_offset + mThreadSrcOffset, p_buffer + buffer_offset);
        });
    }
    template <class TData>
    __device__ void RunStoreRegisterBuffer(const TData* __restrict__ p_buffer,
                                           TData* __restrict__ p_dst) const
    {
        constexpr auto thread_sub_tensor_lengths = SubLengths{};

        constexpr auto data_per_cluster_per_dims =
            thread_sub_tensor_lengths * ThreadClusterLengths{};

        constexpr auto repeat_lengths = SliceLengths{} / (SubLengths{} * ThreadClusterLengths{});

        constexpr auto thread_buffer_desc = GetRegisterBufferDescriptor();

#if CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_BLOCKWISE_GENERIC_SLICE_COPY_V1
        static_ford<decltype(repeat_lengths)>{}([&](auto repeat_id) {
            constexpr auto buffer_data_id_begin = repeat_id * thread_sub_tensor_lengths;

            constexpr auto dst_data_id_begin = repeat_id * data_per_cluster_per_dims;

            constexpr index_t buffer_offset =
                thread_buffer_desc.GetOffsetFromMultiIndex(buffer_data_id_begin);

            constexpr index_t dst_offset = DstDesc::GetOffsetFromMultiIndex(dst_data_id_begin);
#else
        ford<decltype(repeat_lengths)>{}([&](auto repeat_id) {
            const auto buffer_data_id_begin = repeat_id * thread_sub_tensor_lengths;

            const auto dst_data_id_begin = repeat_id * data_per_cluster_per_dims;

            const index_t buffer_offset =
                thread_buffer_desc.GetOffsetFromMultiIndex(buffer_data_id_begin);

            const index_t dst_offset = DstDesc::GetOffsetFromMultiIndex(dst_data_id_begin);
#endif

            // The origin of the per-thread window is positioned at the point where the
            // multi-index of the SrcDesc (which might be a merged tensor) is all-zero. This
            // threadwise slice copy assumes each thread is copying a normal (not merged) tensor.
            // To satisfy this assumption, the user needs to make sure that, on a merged dimension
            // that contains multiple original dimensions, the length of the last original
            // dimension is evenly divisible by its sub-length. Also, the repeat-length on the
            // merged dimension needs to be 1. These sanity checks are performed in the
            // constructor of BlockwiseGenericTensorSliceCopy_v1.
            ThreadwiseGenericTensorSliceCopy_v1r2<decltype(thread_buffer_desc),
                                                  DstDesc,
                                                  SubLengths,
                                                  DstDimAccessOrder,
                                                  DstVectorAccessDim,
                                                  1,
                                                  DstDataPerAccess>(make_zero_array<index_t, nDim>(),
                                                                    make_zero_array<index_t, nDim>())
                .Run(p_buffer + buffer_offset, p_dst + dst_offset + mThreadDstOffset);
        });
    }
    template <class TData>
    __device__ void Run(const TData* __restrict__ p_src, TData* __restrict__ p_dst) const
    {
        TData p_buffer[GetRegisterBufferSize()];

        RunLoadRegisterBuffer(p_src, p_buffer);
        RunStoreRegisterBuffer(p_buffer, p_dst);
    }
    // When moving the slicing window along a merged dimension, if the strides of the original
    // dimensions contained by the merged dimension are not in descending order, then there is
    // no guarantee that the new offset will be larger than the old offset for movement in the
    // positive direction (and vice versa for movement in the negative direction). As a result,
    // the offset calculation may result in unsigned integer underflow (due to the "-" operation).
    // However, this hazard should not happen, as long as the user makes sure the slicing window
    // is never moved out of the boundary of the tensor being sliced. For performance reasons,
    // this function does not do a runtime sanity check for an out-of-bound slicing window.
    template <index_t IDim_, index_t StepSize, bool PositiveDirection>
    __device__ void MoveSlicingWindowOnSourceTensor(
        Number<IDim_>, Number<StepSize>, integral_constant<bool, PositiveDirection> direction)
    {
        constexpr auto IDim = Number<IDim_>{};

        static_if<SrcDesc::ContainMultipleOriginalDimensions(IDim)>{}([&](auto) {
            // Logic for a merged dimension. It also works for a non-merged dimension, but its
            // logic may be unnecessarily complicated for the compiler to remove calculations
            // that are useless for a non-merged dimension.

            // extract partial original dimensions
            constexpr auto src_partial_original_dims =
                SrcDesc::GetContainedOriginalDimensions(IDim);

            constexpr auto src_partial_original_desc =
                SrcDesc::GetOriginalTensorDescriptor().Extract(src_partial_original_dims);

            // calculate new partial original multi-id
            auto old_src_partial_original_id =
                extract_array(mThreadSrcOriginalMultiId, src_partial_original_dims);

            auto new_src_partial_original_id =
                src_partial_original_desc.UpdateMultiIndexGivenStepSizeOf1dIndex(
                    old_src_partial_original_id, StepSize, direction);

            // update "mThreadSrcOriginalMultiId"
            static_for<0, decltype(src_partial_original_dims)::GetSize(), 1>{}([&](auto I) {
                constexpr auto IDimOriginal = src_partial_original_dims[I];

                mThreadSrcOriginalMultiId(IDimOriginal) = new_src_partial_original_id[I];
            });

            // calculate new partial offset on this merged dimension
            const index_t old_src_partial_offset = mThreadSrcPartialOffsets[IDim];

            const index_t new_src_partial_offset =
                src_partial_original_desc.GetOffsetFromMultiIndex(new_src_partial_original_id);

            // update "mThreadSrcPartialOffsets"
            mThreadSrcPartialOffsets(IDim) = new_src_partial_offset;

            // update "mThreadSrcOffset"; do "+" before "-" to avoid underflow
            mThreadSrcOffset =
                (mThreadSrcOffset + new_src_partial_offset) - old_src_partial_offset;
        }).Else([&](auto) {
            // Logic for a non-merged dimension. If the slicing window is never moved on a merged
            // dimension, then "mThreadSrcOriginalMultiId" and "mThreadSrcPartialOffsets", which
            // are being calculated here, will never be used later. In this case, the compiler
            // should be able to remove these calculations.
            // TODO: make sure the compiler actually removes them in this case.
            // It is the user's responsibility to make sure the slicing window will not be moved
            // out of the boundary of the tensor being sliced. Otherwise, there might be a hazard
            // like unsigned integer underflow. There is NO runtime sanity check to prevent the
            // hazard.
            constexpr auto IDimOriginal = SrcDesc::GetContainedOriginalDimensions(IDim).Front();

            static_if<PositiveDirection>{}([&](auto fwd) {
                mThreadSrcOffset += StepSize * fwd(SrcDesc{}).GetStride(IDim);

                mThreadSrcOriginalMultiId(IDimOriginal) += StepSize;

                mThreadSrcPartialOffsets(IDim) += StepSize * fwd(SrcDesc{}).GetStride(IDim);
            }).Else([&](auto fwd) {
                mThreadSrcOffset -= StepSize * fwd(SrcDesc{}).GetStride(IDim);

                mThreadSrcOriginalMultiId(IDimOriginal) -= StepSize;

                mThreadSrcPartialOffsets(IDim) -= StepSize * fwd(SrcDesc{}).GetStride(IDim);
            });
        });
    }

    template <class T, bool PositiveDirection>
    __device__ void MoveSrcSlicingWindow(T step_sizes,
                                         integral_constant<bool, PositiveDirection> positive_direction)
    {
        static_for<0, nDim, 1>{}([&](auto idim) {
            if(step_sizes[idim] != 0)
            {
                MoveSlicingWindowOnSourceTensor(idim, step_sizes[idim], positive_direction);
            }
        });
    }
};
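As a rough illustration of the work division done in the constructor above, the following standalone sketch decomposes a linear thread id into a thread-cluster multi-index, derives each thread's data origin as data_cluster_id * SubLengths, and computes the repeat lengths. All numbers (an 8x128 slice, 4x1 sub-lengths, a 2x128 thread cluster) are hypothetical and not taken from any kernel configuration in this commit; the snippet only mirrors the index arithmetic, not the library API.

#include <cstdio>

int main()
{
    // Hypothetical 2D configuration, chosen only for illustration.
    constexpr int slice_lengths[2]          = {8, 128}; // per-block slice of the tensor
    constexpr int sub_lengths[2]            = {4, 1};   // data copied by one thread per repeat
    constexpr int thread_cluster_lengths[2] = {2, 128}; // thread arrangement inside the block

    constexpr int block_size = thread_cluster_lengths[0] * thread_cluster_lengths[1]; // 256

    for(int tid = 0; tid < block_size; ++tid)
    {
        // GetMultiIndexFrom1dIndex analogue for a packed 2x128 cluster descriptor
        const int cluster_id[2] = {tid / thread_cluster_lengths[1],
                                   tid % thread_cluster_lengths[1]};

        // thread_data_id_begin = data_cluster_id * SubLengths
        const int data_begin[2] = {cluster_id[0] * sub_lengths[0], cluster_id[1] * sub_lengths[1]};

        // repeat_lengths = SliceLengths / (SubLengths * ThreadClusterLengths)
        const int repeats[2] = {slice_lengths[0] / (sub_lengths[0] * thread_cluster_lengths[0]),
                                slice_lengths[1] / (sub_lengths[1] * thread_cluster_lengths[1])};

        if(tid < 3) // print only a few threads
            std::printf("tid %d -> data origin (%d,%d), repeats (%d,%d)\n",
                        tid, data_begin[0], data_begin[1], repeats[0], repeats[1]);
    }
    return 0;
}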
template <index_t BlockSize,
          class SrcDesc,
          class DstDesc,
          class SrcCoordinate,
          class DstCoordinate,
          class SliceLengths,
          class SubLengths,
          class ThreadClusterLengths,
          class ThreadClusterArrangeOrder,
          class SrcDimAccessOrder,
          class DstDimAccessOrder,
          index_t SrcVectorAccessDim,
          index_t DstVectorAccessDim,
          index_t SrcDataPerAccess,
          index_t DstDataPerAccess>
struct BlockwiseGenericTensorSliceCopy_v2
{
    static constexpr index_t nDim = SrcDesc::GetNumOfDimension();

    __device__ constexpr BlockwiseGenericTensorSliceCopy_v2(SrcCoordinate src_block_slice_origin,
                                                            DstCoordinate dst_block_slice_origin)
    {
        static_assert(nDim == SrcDesc::GetNumOfDimension() &&
                          nDim == DstDesc::GetNumOfDimension() &&
                          nDim == SliceLengths::GetSize() && nDim == SubLengths::GetSize() &&
                          nDim == ThreadClusterLengths::GetSize() &&
                          nDim == ThreadClusterArrangeOrder::GetSize(),
                      "wrong! nDim not consistent");

        static_assert(is_same<SliceLengths, decltype(SubLengths{} * ThreadClusterLengths{})>{},
                      "wrong! threads should be mapped to cover entire slicing window");

        constexpr auto thread_cluster_desc = make_ConstantTensorDescriptor_packed(
            ThreadClusterLengths::ReorderGivenNew2Old(ThreadClusterArrangeOrder{}));

        static_assert(BlockSize == thread_cluster_desc.GetElementSize(),
                      "wrong! BlockSize not consistent with ThreadClusterLengths");

        const auto thread_cluster_id =
            thread_cluster_desc.GetMultiIndexFrom1dIndex(get_thread_local_1d_id());

        const auto data_cluster_id =
            reorder_array_given_old2new(thread_cluster_id, ThreadClusterArrangeOrder{});

        const auto thread_data_id_begin = data_cluster_id * SubLengths{};

        mThreadwiseLoad.SetSrcSliceOrigin(src_block_slice_origin + thread_data_id_begin);
        mThreadwiseLoad.SetDstSliceOrigin(make_zero_array<index_t, nDim>());

        mThreadwiseStore.SetSrcSliceOrigin(make_zero_array<index_t, nDim>());
        mThreadwiseStore.SetDstSliceOrigin(dst_block_slice_origin + thread_data_id_begin);
    }

    __device__ static constexpr index_t GetRegisterBufferSize()
    {
        return RegisterBufferDesc::GetElementSpace();
    }

    template <class TData>
    __device__ void RunLoadRegisterBuffer(const TData* p_src, TData* p_buffer) const
    {
        mThreadwiseLoad.Run(p_src, p_buffer);
    }

    template <class TData>
    __device__ void RunStoreRegisterBuffer(const TData* p_buffer, TData* p_dst) const
    {
        mThreadwiseStore.Run(p_buffer, p_dst);
    }

    template <class TData>
    __device__ void Run(const TData* p_src, TData* p_dst) const
    {
        TData p_buffer[GetRegisterBufferSize()];

        mThreadwiseLoad.Run(p_src, p_buffer);
        mThreadwiseStore.Run(p_buffer, p_dst);
    }

    template <class T, bool PositiveDirection>
    __device__ void MoveSrcSlicingWindow(T step_sizes,
                                         integral_constant<bool, PositiveDirection> positive_direction)
    {
        mThreadwiseLoad.MoveSrcSlicingWindow(step_sizes, positive_direction);
    }

    template <class T, bool PositiveDirection>
    __device__ void MoveDstSlicingWindow(T step_sizes,
                                         integral_constant<bool, PositiveDirection> positive_direction)
    {
        mThreadwiseStore.MoveDstSlicingWindow(step_sizes, positive_direction);
    }

    private:
    using RegisterBufferDesc = decltype(make_ConstantTensorDescriptor_packed(SubLengths{}));

    using ThreadwiseLoad =
        ThreadwiseGenericTensorSliceCopy_v2r1<SrcDesc,
                                              RegisterBufferDesc,
                                              SrcCoordinate,
                                              NormalTensorCoordinate<RegisterBufferDesc>,
                                              SubLengths,
                                              SrcDimAccessOrder,
                                              SrcDimAccessOrder,
                                              SrcVectorAccessDim,
                                              SrcVectorAccessDim,
                                              SrcDataPerAccess,
                                              1>;

    using ThreadwiseStore =
        ThreadwiseGenericTensorSliceCopy_v2r1<RegisterBufferDesc,
                                              DstDesc,
                                              NormalTensorCoordinate<RegisterBufferDesc>,
                                              DstCoordinate,
                                              SubLengths,
                                              DstDimAccessOrder,
                                              DstDimAccessOrder,
                                              DstVectorAccessDim,
                                              DstVectorAccessDim,
                                              1,
                                              DstDataPerAccess>;

    ThreadwiseLoad mThreadwiseLoad;
    ThreadwiseStore mThreadwiseStore;
};
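The split into RunLoadRegisterBuffer and RunStoreRegisterBuffer exists so a caller can interleave the global-to-register load with other work before writing the buffer out (for example in the LDS double-buffered kernels elsewhere in this commit). The host-side sketch below only mirrors that load/move/store rhythm with plain scalar loops; the buffer size, the window movement and the function names are simplified stand-ins, not the library's interface.

#include <cstdio>

constexpr int kSliceLen = 4;

// RunLoadRegisterBuffer analogue: fill a small per-thread buffer from the current source window.
void run_load(const float* src, int src_origin, float* buffer)
{
    for(int i = 0; i < kSliceLen; ++i)
        buffer[i] = src[src_origin + i];
}

// RunStoreRegisterBuffer analogue: write the buffer to the current destination window.
void run_store(const float* buffer, float* dst, int dst_origin)
{
    for(int i = 0; i < kSliceLen; ++i)
        dst[dst_origin + i] = buffer[i];
}

int main()
{
    float src[16], dst[16] = {};
    for(int i = 0; i < 16; ++i) src[i] = float(i);

    float buffer[kSliceLen]; // register-buffer analogue

    int src_origin = 0, dst_origin = 0;
    for(int iter = 0; iter < 4; ++iter)
    {
        run_load(src, src_origin, buffer);  // load current window
        run_store(buffer, dst, dst_origin); // in the real kernel, compute can sit between the two
        src_origin += kSliceLen;            // MoveSrcSlicingWindow analogue
        dst_origin += kSliceLen;
    }

    std::printf("dst[5] = %g\n", dst[5]);
    return 0;
}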
// this will be deprecated
// slice a (normal or merged) tensor, and copy it into another (normal or merged) tensor
// memory layout (ordering of dimensions) can be different between src and dst
// For now, only support SubLengths[...] == 1 on a merged dimension
...
@@ -31,7 +560,7 @@ template <index_t BlockSize,
           class DstAccessOrder,
           index_t SrcDataPerRead,
           index_t DstDataPerWrite>
-struct BlockwiseGenericTensorSliceCopy_v1
+struct BlockwiseGenericTensorSliceCopy_v1_deprecated
 {
     static constexpr index_t nDim = SrcDesc::GetNumOfDimension();
...
@@ -59,9 +588,9 @@ struct BlockwiseGenericTensorSliceCopy_v1
     Array<index_t, nOriginalDimSrc> mThreadSrcOriginalMultiId;
     Array<index_t, nOriginalDimDst> mThreadDstOriginalMultiId;

     __device__
-    BlockwiseGenericTensorSliceCopy_v1(Array<index_t, nDim> src_block_data_multi_id_begin,
-                                       Array<index_t, nDim> dst_block_data_multi_id_begin)
+    BlockwiseGenericTensorSliceCopy_v1_deprecated(
+        Array<index_t, nDim> src_block_data_multi_id_begin,
+        Array<index_t, nDim> dst_block_data_multi_id_begin)
     {
         // check NDim consistency
         static_assert(nDim == SrcDesc::GetNumOfDimension() &&
...
@@ -213,15 +742,16 @@ struct BlockwiseGenericTensorSliceCopy_v1
                 thread_tensor_desc.GetOffsetFromMultiIndex(clipboard_data_multi_id_begin);
 #endif

-            threadwise_generic_tensor_slice_copy_v1(SrcDesc{},
-                                                    p_src + src_offset + mThreadSrcOffset,
-                                                    make_zero_array<index_t, nDim>(),
-                                                    thread_tensor_desc,
-                                                    make_zero_array<index_t, nDim>(),
-                                                    p_clipboard + clipboard_offset,
-                                                    thread_sub_tensor_lengths,
-                                                    SrcAccessOrder{},
-                                                    Number<SrcDataPerRead>{});
+            threadwise_generic_tensor_slice_copy_v1_deprecated(SrcDesc{},
+                                                               p_src + src_offset +
+                                                                   mThreadSrcOffset,
+                                                               make_zero_array<index_t, nDim>(),
+                                                               thread_tensor_desc,
+                                                               make_zero_array<index_t, nDim>(),
+                                                               p_clipboard + clipboard_offset,
+                                                               thread_sub_tensor_lengths,
+                                                               SrcAccessOrder{},
+                                                               Number<SrcDataPerRead>{});
         });
     }
...
@@ -264,15 +794,16 @@ struct BlockwiseGenericTensorSliceCopy_v1
             const index_t dst_offset = DstDesc{}.GetOffsetFromMultiIndex(dst_data_multi_id_begin);
 #endif

-            threadwise_generic_tensor_slice_copy_v1(thread_tensor_desc,
-                                                    p_clipboard + clipboard_offset,
-                                                    make_zero_array<index_t, nDim>(),
-                                                    DstDesc{},
-                                                    p_dst + dst_offset + mThreadDstOffset,
-                                                    make_zero_array<index_t, nDim>(),
-                                                    thread_sub_tensor_lengths,
-                                                    DstAccessOrder{},
-                                                    Number<DstDataPerWrite>{});
+            threadwise_generic_tensor_slice_copy_v1_deprecated(thread_tensor_desc,
+                                                               p_clipboard + clipboard_offset,
+                                                               make_zero_array<index_t, nDim>(),
+                                                               DstDesc{},
+                                                               p_dst + dst_offset +
+                                                                   mThreadDstOffset,
+                                                               make_zero_array<index_t, nDim>(),
+                                                               thread_sub_tensor_lengths,
+                                                               DstAccessOrder{},
+                                                               Number<DstDataPerWrite>{});
         });
     }
...
composable_kernel/include/tensor_operation/threadwise_gemm.hpp
...
@@ -3,7 +3,7 @@
 #include "common_header.hpp"
 #include "ConstantMatrixDescriptor.hpp"
-#include "float_types.h"
+#include "math.hpp"

 namespace ck {
...
@@ -37,58 +37,18 @@ __device__ void threadwise_matrix_copy(SrcMatrix,
     constexpr auto src_mtx = SrcMatrix{};
     constexpr auto dst_mtx = DstMatrix{};

-    // Depending upon datatype i.e float/half/bfloat16, carry out data movement
-    // in appropriate vectorized form
-    // float - 4, half - 4, bfloat16 - 2
-    static_if<std::is_same<Float, float>::value>{}([&](auto) {
-        using vector_t = typename vector_type<float, DataPerRead>::MemoryType;
-        for(index_t i = 0; i < NRow; ++i)
-        {
-            for(index_t j = 0; j < NCol; j += DataPerRead)
-            {
-                const index_t src_index = src_mtx.GetOffsetFromMultiIndex(i, j);
-                const index_t dst_index = dst_mtx.GetOffsetFromMultiIndex(i, j);
-
-                *reinterpret_cast<vector_t*>(&p_dst[dst_index]) =
-                    *reinterpret_cast<const vector_t*>(&p_src[src_index]);
-            }
-        }
-    }).Else([&](auto) {
-        static_if<std::is_same<Float, half>::value>{}([&](auto) {
-            // If src/dst matrix datatype is bfloat16/float16 (vector size 2/4 respectively)
-            using vector_t = typename vector_type<Float, 4>::MemoryType;
-            for(index_t i = 0; i < NRow; ++i)
-            {
-                for(index_t j = 0; j < NCol; ++j)
-                {
-                    const index_t src_index = src_mtx.GetOffsetFromMultiIndex(i, j);
-                    const index_t dst_index = dst_mtx.GetOffsetFromMultiIndex(i, j);
-
-                    *reinterpret_cast<vector_t*>(&p_dst[dst_index * 4]) =
-                        *reinterpret_cast<const vector_t*>(&p_src[src_index * 4]);
-                }
-            }
-        }).Else([&](auto) {
-            using vector_t = typename vector_type<Float, 2>::MemoryType;
-            for(index_t i = 0; i < NRow; ++i)
-            {
-                for(index_t j = 0; j < NCol; ++j)
-                {
-                    const index_t src_index = src_mtx.GetOffsetFromMultiIndex(i, j);
-                    const index_t dst_index = dst_mtx.GetOffsetFromMultiIndex(i, j);
-
-                    *reinterpret_cast<vector_t*>(&p_dst[dst_index * 2]) =
-                        *reinterpret_cast<const vector_t*>(&p_src[src_index * 2]);
-                }
-            }
-        });
-    });
+    using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;
+
+    for(index_t i = 0; i < NRow; ++i)
+    {
+        for(index_t j = 0; j < NCol; j += DataPerRead)
+        {
+            const index_t src_index = src_mtx.GetOffsetFromMultiIndex(i, j);
+            const index_t dst_index = dst_mtx.GetOffsetFromMultiIndex(i, j);
+
+            *reinterpret_cast<vector_t*>(&p_dst[dst_index]) =
+                *reinterpret_cast<const vector_t*>(&p_src[src_index]);
+        }
+    }
 }
 template <class MatrixA,
...
@@ -119,7 +79,6 @@ __device__ void threadwise_gemm(MatrixA,
     constexpr index_t N = c_mtx.NCol();
     constexpr index_t K = a_mtx.NRow(); // A is transposed

     for(index_t k = 0; k < K; ++k)
     {
         for(index_t i = 0; i < M; ++i)
...
@@ -130,32 +89,8 @@ __device__ void threadwise_gemm(MatrixA,
                 const index_t bindex = b_mtx.GetOffsetFromMultiIndex(k, j);
                 const index_t cindex = c_mtx.GetOffsetFromMultiIndex(i, j);

-                static_if<std::is_same<FloatA, float>::value>{}([&](auto) {
-                    p_c_thread[cindex] += CVT_FLOAT2ACCUM(p_a_thread[aindex]) *
-                                          CVT_FLOAT2ACCUM(p_b_thread[bindex]);
-                }).Else([&](auto) {
-                    static_if<std::is_same<FloatA, half>::value>{}([&](auto) {
-                        // If src/dst matrix datatype is bfloat16/float16 (vector size 2/4
-                        // respectively)
-                        float acc = 0.0;
-                        for(index_t v = 0; v < 4; ++v)
-                        {
-                            acc += CVT_FLOAT2ACCUM(p_a_thread[aindex * 4 + v]) *
-                                   CVT_FLOAT2ACCUM(p_b_thread[bindex * 4 + v]);
-                        }
-                        p_c_thread[cindex] += acc;
-                    }).Else([&](auto) {
-                        // If src/dst matrix datatype is bfloat16/float16 (vector size 2/4
-                        // respectively)
-                        float acc = 0.0;
-                        for(index_t v = 0; v < 2; ++v)
-                        {
-                            acc += CVT_FLOAT2ACCUM(p_a_thread[aindex * 2 + v]) *
-                                   CVT_FLOAT2ACCUM(p_b_thread[bindex * 2 + v]);
-                        }
-                        p_c_thread[cindex] += acc;
-                    });
-                });
+                p_c_thread[cindex] += math::inner_product_with_conversion<FloatC>{}(
+                    p_a_thread[aindex], p_b_thread[bindex]);
             }
         }
     }
 }
...
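The branch removed above accumulated each low-precision sub-vector product into a float accumulator before adding it to C, while the new code delegates the conversion to math::inner_product_with_conversion. The standalone sketch below reproduces only the accumulation pattern of the removed path; the "lowp" type and "to_accum" helper are hypothetical stand-ins for half/bfloat16 storage and the CVT_FLOAT2ACCUM conversion, not the library's types.

#include <cstdio>

// Hypothetical low-precision storage type and its conversion to the float accumulator.
struct lowp { float stored; };
static float to_accum(lowp x) { return x.stored; }

int main()
{
    constexpr int K = 8, VEC = 4; // the removed code used VEC = 4 for fp16 and 2 for bfloat16
    lowp a[K * VEC], b[K * VEC];
    for(int i = 0; i < K * VEC; ++i) { a[i].stored = 0.5f; b[i].stored = 2.0f; }

    float c = 0.0f;
    for(int k = 0; k < K; ++k)
    {
        // accumulate one VEC-wide sub-product in float, then fold it into C
        float acc = 0.0f;
        for(int v = 0; v < VEC; ++v)
            acc += to_accum(a[k * VEC + v]) * to_accum(b[k * VEC + v]);
        c += acc;
    }

    std::printf("c = %g (expected %g)\n", c, 1.0f * K * VEC);
    return 0;
}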
composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy.hpp
...
@@ -4,123 +4,802 @@
 #include "common_header.hpp"
 #include "ConstantTensorDescriptor.hpp"
 #include "ConstantMergedTensorDescriptor.hpp"
+#include "tensor_coordinate.hpp"
 #include "float_types.h"

-#ifndef CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1
-#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1 0
+#ifndef CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1R1
+#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1R1 0
+#endif
+
+#ifndef CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1R2
+#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1R2 0
+#endif
+
+#ifndef CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V2R1
+#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V2R1 0
 #endif

 namespace ck {
// This threadwise copy allows vector access of src and dst.
// It allows the dimensions of vector access to be different on src and dst.
// It also allows the vector size to be different on src and dst.
// It also allows the order of access to be different on src and dst.
// It uses registers as a buffer to hold all data moving from src to dst.
// It is designed for copying a small amount of data, where src and dst are
// device memory or LDS.
// When copying a large amount of data, let's hope the compiler will reduce the registers
// used for the buffer.
template <class SrcDesc,
          class DstDesc,
          class SliceLengths,
          class SrcDimAccessOrder,
          class DstDimAccessOrder,
          index_t SrcVectorAccessDim,
          index_t DstVectorAccessDim,
          index_t SrcDataPerAccess,
          index_t DstDataPerAccess>
struct ThreadwiseGenericTensorSliceCopy_v1r1
{
    static constexpr index_t nDim = SliceLengths::GetSize();

    __device__ constexpr ThreadwiseGenericTensorSliceCopy_v1r1(
        Array<index_t, nDim> src_slice_origin, Array<index_t, nDim> dst_slice_origin)
        : mSrcSliceOrigin(src_slice_origin), mDstSliceOrigin(dst_slice_origin)
    {
        static_assert(nDim == SrcDesc::GetNumOfDimension() &&
                          nDim == DstDesc::GetNumOfDimension() &&
                          nDim == SliceLengths::GetSize() &&
                          nDim == SrcDimAccessOrder::GetSize() &&
                          nDim == DstDimAccessOrder::GetSize(),
                      "wrong! # of dimensions not the same");

        static_assert(is_valid_sequence_map<SrcDimAccessOrder>::value &&
                          is_valid_sequence_map<DstDimAccessOrder>::value,
                      "wrong! map is not valid");

        static_assert(SliceLengths{}[SrcVectorAccessDim] % SrcDataPerAccess == 0 &&
                          SliceLengths{}[DstVectorAccessDim] % DstDataPerAccess == 0,
                      "wrong! cannot evenly divide");

        // check vectorized memory access
        constexpr auto src_vector_access_dim = Number<SrcVectorAccessDim>{};
        constexpr auto dst_vector_access_dim = Number<DstVectorAccessDim>{};

        static_if<!SrcDesc::ContainMultipleOriginalDimensions(src_vector_access_dim)>{}(
            [&](auto fwd) {
                static_assert((fwd(SrcDesc{}).GetStride(src_vector_access_dim) == 1 ||
                               SrcDataPerAccess == 1),
                              "wrong! vectorized access is allowed only if stride == 1");
            })
            .Else([&](auto fwd) {
                static_assert(
                    (fwd(SrcDesc{}).GetLastOriginalDimensionStride(src_vector_access_dim) == 1 ||
                     SrcDataPerAccess == 1),
                    "wrong! vectorized access is allowed only if stride == 1");
            });

        static_if<!DstDesc::ContainMultipleOriginalDimensions(dst_vector_access_dim)>{}(
            [&](auto fwd) {
                static_assert((fwd(DstDesc{}).GetStride(dst_vector_access_dim) == 1 ||
                               DstDataPerAccess == 1),
                              "wrong! vectorized access is allowed only if stride == 1");
            })
            .Else([&](auto fwd) {
                static_assert(
                    (fwd(DstDesc{}).GetLastOriginalDimensionStride(dst_vector_access_dim) == 1 ||
                     DstDataPerAccess == 1),
                    "wrong! vectorized access is allowed only if stride == 1");
            });
    }

    __device__ constexpr ThreadwiseGenericTensorSliceCopy_v1r1()
        : ThreadwiseGenericTensorSliceCopy_v1r1(make_zero_array<index_t, nDim>(),
                                                make_zero_array<index_t, nDim>())
    {
    }

    __device__ void SetSrcSliceOrigin(Array<index_t, nDim> src_slice_origin)
    {
        mSrcSliceOrigin = src_slice_origin;
    }

    __device__ void SetDstSliceOrigin(Array<index_t, nDim> dst_slice_origin)
    {
        mDstSliceOrigin = dst_slice_origin;
    }

    template <class SrcData, class DstData>
    __device__ void Run(const SrcData* p_src, DstData* p_dst) const
    {
        constexpr auto buffer_desc = make_ConstantTensorDescriptor_packed(SliceLengths{});

        SrcData p_src_buffer_[buffer_desc.GetElementSpace()];
        SrcData* p_src_buffer = p_src_buffer_;

        // copy data from src into src buffer
        {
            using src_vector_t = typename vector_type<SrcData, SrcDataPerAccess>::MemoryType;

            constexpr auto src_vector_access_dim = Number<SrcVectorAccessDim>{};
            constexpr auto src_data_per_access   = Number<SrcDataPerAccess>{};

            constexpr auto src_access_lengths = SliceLengths::Modify(
                src_vector_access_dim,
                SliceLengths::Get(src_vector_access_dim) / src_data_per_access);

#if CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1R1
            static_ford<decltype(src_access_lengths), SrcDimAccessOrder>{}([&](auto src_access_id) {
                constexpr auto src_data_begin_id = src_access_id.Modify(
                    src_vector_access_dim,
                    src_access_id[src_vector_access_dim] * src_data_per_access);

                const index_t src_offset =
                    SrcDesc::GetOffsetFromMultiIndex(mSrcSliceOrigin + src_data_begin_id);

                // load vector from src
                const src_vector_t src_vector_data =
                    *reinterpret_cast<const src_vector_t*>(&p_src[src_offset]);

                // unpack vector into buffer
                static_for<0, SrcDataPerAccess, 1>{}([&](auto i) {
                    constexpr auto scalar_id = typename uniform_sequence_gen<nDim, 0>::type{}.Modify(
                        src_vector_access_dim, i);

                    constexpr index_t buffer_offset =
                        buffer_desc.GetOffsetFromMultiIndex(src_data_begin_id + scalar_id);

                    p_src_buffer[buffer_offset] =
                        reinterpret_cast<const SrcData*>(&src_vector_data)[i];
                });
            });
#else
            ford<decltype(src_access_lengths), SrcDimAccessOrder>{}([&](auto src_access_id) {
                auto src_data_begin_id = src_access_id;
                src_data_begin_id(src_vector_access_dim) =
                    src_access_id[src_vector_access_dim] * src_data_per_access;

                const index_t src_offset =
                    SrcDesc::GetOffsetFromMultiIndex(mSrcSliceOrigin + src_data_begin_id);

                // load vector from src
                const src_vector_t src_vector_data =
                    *reinterpret_cast<const src_vector_t*>(&p_src[src_offset]);

                // unpack vector into buffer
                for(index_t i = 0; i < SrcDataPerAccess; ++i)
                {
                    auto scalar_id = make_zero_array<index_t, nDim>();
                    scalar_id(src_vector_access_dim) = i;

                    const index_t buffer_offset =
                        buffer_desc.GetOffsetFromMultiIndex(src_data_begin_id + scalar_id);

                    p_src_buffer[buffer_offset] =
                        reinterpret_cast<const SrcData*>(&src_vector_data)[i];
                }
            });
#endif
        }

        // copy data from buffer to dst
        {
            using dst_vector_t = typename vector_type<DstData, DstDataPerAccess>::MemoryType;

            constexpr auto dst_vector_access_dim = Number<DstVectorAccessDim>{};
            constexpr auto dst_data_per_access   = Number<DstDataPerAccess>{};

            constexpr auto dst_access_lengths = SliceLengths::Modify(
                dst_vector_access_dim,
                SliceLengths::Get(dst_vector_access_dim) / dst_data_per_access);

#if CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1R1
            static_ford<decltype(dst_access_lengths), DstDimAccessOrder>{}([&](auto dst_access_id) {
                constexpr auto dst_data_begin_id = dst_access_id.Modify(
                    dst_vector_access_dim,
                    dst_access_id[dst_vector_access_dim] * dst_data_per_access);

                dst_vector_t dst_vector_data;

                // pack vector from buffer and type conversion
                static_for<0, DstDataPerAccess, 1>{}([&](auto i) {
                    constexpr auto scalar_id = typename uniform_sequence_gen<nDim, 0>::type{}.Modify(
                        dst_vector_access_dim, i);

                    constexpr index_t buffer_offset =
                        buffer_desc.GetOffsetFromMultiIndex(dst_data_begin_id + scalar_id);

                    // SrcData to DstData type conversion is done here
                    reinterpret_cast<DstData*>(&dst_vector_data)[i] =
                        type_convert<DstData>{}(p_src_buffer[buffer_offset]);
                });

                const index_t dst_offset =
                    DstDesc::GetOffsetFromMultiIndex(mDstSliceOrigin + dst_data_begin_id);

                // store vector into dst
                *reinterpret_cast<dst_vector_t*>(&p_dst[dst_offset]) = dst_vector_data;
            });
#else
            ford<decltype(dst_access_lengths), DstDimAccessOrder>{}([&](auto dst_access_id) {
                auto dst_data_begin_id = dst_access_id;
                dst_data_begin_id(dst_vector_access_dim) =
                    dst_access_id[dst_vector_access_dim] * dst_data_per_access;

                dst_vector_t dst_vector_data;

                // pack vector from buffer and type conversion
                for(index_t i = 0; i < DstDataPerAccess; ++i)
                {
                    auto scalar_id = make_zero_array<index_t, nDim>();
                    scalar_id(dst_vector_access_dim) = i;

                    const index_t buffer_offset =
                        buffer_desc.GetOffsetFromMultiIndex(dst_data_begin_id + scalar_id);

                    // SrcData to DstData type conversion is done here
                    reinterpret_cast<DstData*>(&dst_vector_data)[i] =
                        type_convert<DstData>{}(p_src_buffer[buffer_offset]);
                }

                const index_t dst_offset =
                    DstDesc::GetOffsetFromMultiIndex(mDstSliceOrigin + dst_data_begin_id);

                // store vector into dst
                *reinterpret_cast<dst_vector_t*>(&p_dst[dst_offset]) = dst_vector_data;
            });
#endif
        }
    }

private:
    Array<index_t, nDim> mSrcSliceOrigin;
    Array<index_t, nDim> mDstSliceOrigin;
};
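The essence of v1r1 is: read the source with its own vector width, land the elements in a scalar register buffer, convert element-wise, and write the destination with a possibly different vector width and element type. The standalone sketch below walks through that sequence with plain arrays and a float-to-double conversion; the sizes and types are hypothetical and it does not use the descriptor machinery.

#include <cstdio>

int main()
{
    // Hypothetical 1-D slice: source is read 4 elements per access, destination written 2 per access.
    constexpr int slice_len           = 8;
    constexpr int src_data_per_access = 4;
    constexpr int dst_data_per_access = 2;

    float src[slice_len];
    for(int i = 0; i < slice_len; ++i) src[i] = float(i) + 0.25f;

    double buffer[slice_len]; // per-thread register buffer in the destination's element type
    double dst[slice_len];

    // "unpack" wide src accesses into the scalar buffer (per-element type conversion)
    for(int a = 0; a < slice_len / src_data_per_access; ++a)
        for(int i = 0; i < src_data_per_access; ++i)
            buffer[a * src_data_per_access + i] = double(src[a * src_data_per_access + i]);

    // "pack" the buffer back out with the destination's own access width
    for(int a = 0; a < slice_len / dst_data_per_access; ++a)
        for(int i = 0; i < dst_data_per_access; ++i)
            dst[a * dst_data_per_access + i] = buffer[a * dst_data_per_access + i];

    std::printf("dst[7] = %g\n", dst[7]);
    return 0;
}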
The free function below is removed by this commit (its calls are renamed to the _deprecated variant shown earlier); the lines marked with "-" are the removed code:

-// user need to make sure alignment requirement is satisfied when setting DataPerAccesss > 1
-template <class SrcFloat,
-          class DesFloat,
-          class SrcDesc,
-          class DstDesc,
-          class SliceLengths,
-          class DimAccessOrder,
-          index_t DataPerAccess>
-__device__ void threadwise_generic_tensor_slice_copy_v1(
-    SrcDesc,
-    const SrcFloat* __restrict__ p_src,
-    Array<index_t, SrcDesc::GetNumOfDimension()> src_multi_id_begin,
-    DstDesc,
-    DesFloat* __restrict__ p_dst,
-    Array<index_t, DstDesc::GetNumOfDimension()> dst_multi_id_begin,
-    SliceLengths,
-    DimAccessOrder,
-    Number<DataPerAccess>)
-{
-    constexpr index_t nDim = SrcDesc::GetNumOfDimension();
-
-    static_assert(nDim == SrcDesc::GetNumOfDimension() && nDim == DstDesc::GetNumOfDimension() &&
-                      nDim == SliceLengths::GetSize() && nDim == DimAccessOrder::GetSize(),
-                  "wrong! # of dimensions not the same");
-
-    static_assert(is_valid_sequence_map<DimAccessOrder>::value, "wrong! map is not valid");
-
-    // TODO: do more sanity-check here, something like:
-    // constexpr auto src_strides_in_access_order =
-    //     SrcDesc::ReorderGivenNew2Old(DimAccessOrder{}).GetStride(Number<nDim-1>{});
-    // constexpr auto dst_strides_in_access_order =
-    //     SrcDesc::ReorderGivenNew2Old(DimAccessOrder{}).GetStride(Number<nDim-1>{});
-    // // check src/dst stride on the lowest access dimension
-    // static_assert((DataPerAccess == 1 || src_strides_in_access_order.Back() == 1) &&
-    //                   (DataPerAccess == 1 || dst_strides_in_access_order.Back() == 1),
-    //               "wrong! src/dst stride on the lowest access dimension needs to be 1 for "
-    //               "vectorized read/write");
-
-    constexpr auto slice_lengths_in_access_order =
-        SliceLengths::ReorderGivenNew2Old(DimAccessOrder{});
-
-    // check slice length on the lowest access dimension
-    static_assert(slice_lengths_in_access_order.Back() % DataPerAccess == 0,
-                  "wrong! slice length on the lowest access dimension should be evenly divided by "
-                  "DataPerAccess");
-
-    constexpr index_t num_access_on_lowest_access_dimension =
-        slice_lengths_in_access_order.Back() / DataPerAccess;
-
-    constexpr auto access_lengths = slice_lengths_in_access_order.Modify(
-        Number<nDim - 1>{}, Number<num_access_on_lowest_access_dimension>{});
-
-    using vector_src_t  = typename vector_type<SrcFloat, DataPerAccess>::MemoryType;
-    using vector_dest_t = typename vector_type<DesFloat, DataPerAccess>::MemoryType;
-
-#if CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1
-    static_ford<decltype(access_lengths)>{}([&](auto access_multi_id) {
-        constexpr index_t itmp = access_multi_id.Back() * DataPerAccess;
-
-        constexpr auto data_multi_id_in_access_order =
-            access_multi_id.Modify(Number<nDim - 1>{}, Number<itmp>{});
-
-        constexpr auto data_multi_id = reorder_array_given_old2new(
-            sequence2array(data_multi_id_in_access_order), DimAccessOrder{});
-
-        const index_t src_index =
-            SrcDesc::GetOffsetFromMultiIndex(src_multi_id_begin + data_multi_id);
-
-        const index_t dst_index =
-            DstDesc::GetOffsetFromMultiIndex(dst_multi_id_begin + data_multi_id);
-
-        static_if<std::is_same<vector_src_t, vector_dest_t>::value>{}([&](auto) {
-            *reinterpret_cast<vector_dest_t*>(&p_dst[dst_index]) =
-                *reinterpret_cast<const vector_src_t*>(&p_src[src_index]);
-        }).Else([&](auto) {
-            for(unsigned int data_idx = 0; data_idx < DataPerAccess; ++data_idx)
-            {
-                p_dst[dst_index + data_idx] = CVT_ACCUM2FLOAT(p_src[src_index + data_idx]);
-            }
-        });
-    });
-#else
-    ford<decltype(access_lengths)>{}([&](auto access_multi_id) {
-        auto data_multi_id_in_access_order = access_multi_id;
-        data_multi_id_in_access_order(nDim - 1) = access_multi_id[nDim - 1] * DataPerAccess;
-
-        const auto data_multi_id =
-            reorder_array_given_old2new(data_multi_id_in_access_order, DimAccessOrder{});
-
-        const index_t src_index =
-            SrcDesc::GetOffsetFromMultiIndex(src_multi_id_begin + data_multi_id);
-
-        const index_t dst_index =
-            DstDesc::GetOffsetFromMultiIndex(dst_multi_id_begin + data_multi_id);
-
-        static_if<std::is_same<vector_src_t, vector_dest_t>::value>{}([&](auto) {
-            *reinterpret_cast<vector_dest_t*>(&p_dst[dst_index]) =
-                *reinterpret_cast<const vector_src_t*>(&p_src[src_index]);
-        }).Else([&](auto) {
-            for(unsigned int data_idx = 0; data_idx < DataPerAccess; ++data_idx)
-            {
-                p_dst[dst_index + data_idx] = CVT_ACCUM2FLOAT(p_src[src_index + data_idx]);
-            }
-        });
-    });
-#endif
-}

// This threadwise copy allows vector access of src and dst.
// It allows the vector size to be different on src and dst.
// The dimensions of vector access should be the same on src and dst.
// The dimension access order should be the same on src and dst.
// It is designed for cases where one of src and dst is register, and
// the other is device memory or LDS.
template <class SrcDesc,
          class DstDesc,
          class SliceLengths,
          class DimAccessOrder,
          index_t VectorAccessDim,
          index_t SrcDataPerAccess,
          index_t DstDataPerAccess>
struct ThreadwiseGenericTensorSliceCopy_v1r2
{
    static constexpr index_t nDim = SliceLengths::GetSize();

    __device__ constexpr ThreadwiseGenericTensorSliceCopy_v1r2(
        Array<index_t, nDim> src_slice_origin, Array<index_t, nDim> dst_slice_origin)
        : mSrcSliceOrigin(src_slice_origin), mDstSliceOrigin(dst_slice_origin)
    {
        static_assert(nDim == SrcDesc::GetNumOfDimension() &&
                          nDim == DstDesc::GetNumOfDimension() &&
                          nDim == SliceLengths::GetSize() && nDim == DimAccessOrder::GetSize(),
                      "wrong! # of dimensions not the same");

        static_assert(is_valid_sequence_map<DimAccessOrder>::value, "wrong! map is not valid");

        static_assert(SliceLengths{}[VectorAccessDim] %
                              math::lcm(SrcDataPerAccess, DstDataPerAccess) ==
                          0,
                      "wrong! cannot evenly divide");

        // check vectorized memory access
        constexpr auto vector_access_dim = Number<VectorAccessDim>{};

        static_if<!SrcDesc::ContainMultipleOriginalDimensions(vector_access_dim)>{}([&](auto fwd) {
            static_assert(
                (fwd(SrcDesc{}).GetStride(vector_access_dim) == 1 || SrcDataPerAccess == 1),
                "wrong! vectorized access is allowed only if stride == 1");
        }).Else([&](auto fwd) {
            static_assert((fwd(SrcDesc{}).GetLastOriginalDimensionStride(vector_access_dim) == 1 ||
                           SrcDataPerAccess == 1),
                          "wrong! vectorized access is allowed only if stride == 1");
        });

        static_if<!DstDesc::ContainMultipleOriginalDimensions(vector_access_dim)>{}([&](auto fwd) {
            static_assert(
                (fwd(DstDesc{}).GetStride(vector_access_dim) == 1 || DstDataPerAccess == 1),
                "wrong! vectorized access is allowed only if stride == 1");
        }).Else([&](auto fwd) {
            static_assert((fwd(DstDesc{}).GetLastOriginalDimensionStride(vector_access_dim) == 1 ||
                           DstDataPerAccess == 1),
                          "wrong! vectorized access is allowed only if stride == 1");
        });
    }

    __device__ constexpr ThreadwiseGenericTensorSliceCopy_v1r2()
        : ThreadwiseGenericTensorSliceCopy_v1r2(make_zero_array<index_t, nDim>(),
                                                make_zero_array<index_t, nDim>())
    {
    }

    __device__ void SetSrcSliceOrigin(Array<index_t, nDim> src_slice_origin)
    {
        mSrcSliceOrigin = src_slice_origin;
    }

    __device__ void SetDstSliceOrigin(Array<index_t, nDim> dst_slice_origin)
    {
        mDstSliceOrigin = dst_slice_origin;
    }

    template <class SrcData, class DstData>
    __device__ void Run(const SrcData* p_src, DstData* p_dst) const
    {
        using src_vector_t = typename vector_type<SrcData, SrcDataPerAccess>::MemoryType;
        using dst_vector_t = typename vector_type<DstData, DstDataPerAccess>::MemoryType;

        constexpr auto vector_access_dim = Number<VectorAccessDim>{};

        constexpr auto src_data_per_access = Number<SrcDataPerAccess>{};
        constexpr auto dst_data_per_access = Number<DstDataPerAccess>{};

        constexpr auto long_vector_size = Number<math::lcm(SrcDataPerAccess, DstDataPerAccess)>{};

        constexpr auto long_vector_access_lengths = SliceLengths::Modify(
            vector_access_dim, SliceLengths::Get(vector_access_dim) / long_vector_size);

#if CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1R2
        static_ford<decltype(long_vector_access_lengths), DimAccessOrder>{}(
            [&](auto long_vector_access_id) {
                // data id w.r.t slicing-window
                constexpr auto long_vector_data_begin_id = long_vector_access_id.Modify(
                    vector_access_dim, long_vector_access_id[vector_access_dim] * long_vector_size);

                // buffer to hold a long-vector
                SrcData p_src_long_vector[long_vector_size];
                DstData p_dst_long_vector[long_vector_size];

                // load data from src to the long-vector buffer
                static_for<0, long_vector_size / src_data_per_access, 1>{}([&](auto i) {
                    constexpr auto scalar_id = typename uniform_sequence_gen<nDim, 0>::type{}.Modify(
                        vector_access_dim, i * src_data_per_access);

                    const index_t src_offset = SrcDesc::GetOffsetFromMultiIndex(
                        mSrcSliceOrigin + (long_vector_data_begin_id + scalar_id));

                    constexpr index_t buffer_offset = i * src_data_per_access;

                    *reinterpret_cast<src_vector_t*>(&p_src_long_vector[buffer_offset]) =
                        *reinterpret_cast<const src_vector_t*>(&p_src[src_offset]);
                });

                // type conversion
                for(index_t i = 0; i < long_vector_size; ++i)
                {
                    p_dst_long_vector[i] = type_convert<DstData>{}(p_src_long_vector[i]);
                }

                // store data from the long-vector buffer to dst
                static_for<0, long_vector_size / dst_data_per_access, 1>{}([&](auto i) {
                    constexpr auto scalar_id = typename uniform_sequence_gen<nDim, 0>::type{}.Modify(
                        vector_access_dim, i * dst_data_per_access);

                    constexpr index_t buffer_offset = i * dst_data_per_access;

                    const index_t dst_offset = DstDesc::GetOffsetFromMultiIndex(
                        mDstSliceOrigin + (long_vector_data_begin_id + scalar_id));

                    *reinterpret_cast<dst_vector_t*>(&p_dst[dst_offset]) =
                        *reinterpret_cast<dst_vector_t*>(&p_dst_long_vector[buffer_offset]);
                });
            });
#else
        ford<decltype(long_vector_access_lengths), DimAccessOrder>{}(
            [&](auto long_vector_access_id) {
                // data id w.r.t slicing-window
                auto long_vector_data_begin_id = long_vector_access_id;
                long_vector_data_begin_id(vector_access_dim) =
                    long_vector_size * long_vector_access_id[vector_access_dim];

                // buffer to hold a long-vector
                SrcData p_src_long_vector[long_vector_size];
                DstData p_dst_long_vector[long_vector_size];

                // load data from src to the long-vector buffer
                for(index_t i = 0; i < long_vector_size / src_data_per_access; ++i)
                {
                    auto scalar_id = make_zero_array<index_t, nDim>();
                    scalar_id(vector_access_dim) = i * src_data_per_access;

                    const index_t src_offset = SrcDesc::GetOffsetFromMultiIndex(
                        mSrcSliceOrigin + (long_vector_data_begin_id + scalar_id));

                    const index_t buffer_offset = i * src_data_per_access;

                    *reinterpret_cast<src_vector_t*>(&p_src_long_vector[buffer_offset]) =
                        *reinterpret_cast<const src_vector_t*>(&p_src[src_offset]);
                }

                // type conversion
                for(index_t i = 0; i < long_vector_size; ++i)
                {
                    p_dst_long_vector[i] = type_convert<DstData>{}(p_src_long_vector[i]);
                }

                // store data from the long-vector buffer to dst
                for(index_t i = 0; i < long_vector_size / dst_data_per_access; ++i)
                {
                    auto scalar_id = make_zero_array<index_t, nDim>();
                    scalar_id(vector_access_dim) = i * dst_data_per_access;

                    const index_t buffer_offset = i * dst_data_per_access;

                    const index_t dst_offset = DstDesc::GetOffsetFromMultiIndex(
                        mDstSliceOrigin + (long_vector_data_begin_id + scalar_id));

                    *reinterpret_cast<dst_vector_t*>(&p_dst[dst_offset]) =
                        *reinterpret_cast<dst_vector_t*>(&p_dst_long_vector[buffer_offset]);
                }
            });
#endif
    }

private:
    Array<index_t, nDim> mSrcSliceOrigin;
    Array<index_t, nDim> mDstSliceOrigin;
};
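v1r2 chooses the long-vector size as the least common multiple of the two access widths, so one long vector is filled by a whole number of source accesses and drained by a whole number of destination accesses. The small standalone sketch below only evaluates that arithmetic for hypothetical widths; it mirrors what math::lcm is used for in the struct above but does not call the library.

#include <cstdio>

constexpr int gcd(int a, int b) { return b == 0 ? a : gcd(b, a % b); }
constexpr int lcm(int a, int b) { return a / gcd(a, b) * b; }

int main()
{
    // Hypothetical access widths, chosen only for illustration.
    constexpr int src_data_per_access = 4;
    constexpr int dst_data_per_access = 2;
    constexpr int long_vector_size    = lcm(src_data_per_access, dst_data_per_access); // 4

    std::printf("long vector size      : %d\n", long_vector_size);
    std::printf("src accesses per chunk: %d\n", long_vector_size / src_data_per_access); // 1
    std::printf("dst accesses per chunk: %d\n", long_vector_size / dst_data_per_access); // 2
    return 0;
}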
// This threadwise copy allow vector access of src and dst.
// It allows the dimensions of vector access to be different on src and dst.
// It also allows the vector size to be different on src and dst.
// It also allows order of access to be different on src and dst.
// It use register as buffer to hold all data moving from src to dst.
// It is designed for copying small amount of data, and src and dst are
// device memory or LDS.
// When copying large amout of data, let's hope compiler will reduce register
// used for the buffer.
template
<
class
SrcDesc
,
class
DstDesc
,
class
SrcCoordinate
,
class
DstCoordinate
,
class
SliceLengths
,
class
SrcDimAccessOrder
,
class
DstDimAccessOrder
,
index_t
SrcVectorAccessDim
,
index_t
DstVectorAccessDim
,
index_t
SrcDataPerAccess
,
index_t
DstDataPerAccess
>
struct
ThreadwiseGenericTensorSliceCopy_v2r1
{
static
constexpr
index_t
nDim
=
SliceLengths
::
GetSize
();
__device__
constexpr
ThreadwiseGenericTensorSliceCopy_v2r1
(
SrcCoordinate
src_slice_origin
,
DstCoordinate
dst_slice_origin
)
:
mSrcSliceOrigin
(
src_slice_origin
),
mDstSliceOrigin
(
dst_slice_origin
)
{
static_assert
(
nDim
==
SrcDesc
::
GetNumOfDimension
()
&&
nDim
==
DstDesc
::
GetNumOfDimension
()
&&
nDim
==
SliceLengths
::
GetSize
()
&&
nDim
==
SrcDimAccessOrder
::
GetSize
()
&&
nDim
==
DstDimAccessOrder
::
GetSize
(),
"wrong! # of dimensions not the same"
);
static_assert
(
is_valid_sequence_map
<
SrcDimAccessOrder
>::
value
&&
is_valid_sequence_map
<
DstDimAccessOrder
>::
value
,
"wrong! map is not valid"
);
static_assert
(
SliceLengths
{}[
SrcVectorAccessDim
]
%
SrcDataPerAccess
==
0
&&
SliceLengths
{}[
DstVectorAccessDim
]
%
DstDataPerAccess
==
0
,
"wrong! cannot evenly divide"
);
// check vectorized memory access
constexpr
auto
src_vector_access_dim
=
Number
<
SrcVectorAccessDim
>
{};
constexpr
auto
dst_vector_access_dim
=
Number
<
DstVectorAccessDim
>
{};
static_if
<!
SrcDesc
::
ContainMultipleOriginalDimensions
(
src_vector_access_dim
)
>
{}(
[
&
](
auto
fwd
)
{
static_assert
(
(
fwd
(
SrcDesc
{}).
GetStride
(
src_vector_access_dim
)
==
1
||
SrcDataPerAccess
==
1
),
"wrong! vectorized access is allowed only if stride == 1"
);
})
.
Else
([
&
](
auto
fwd
)
{
static_assert
(
(
fwd
(
SrcDesc
{}).
GetLastOriginalDimensionStride
(
src_vector_access_dim
)
==
1
||
SrcDataPerAccess
==
1
),
"wrong! vectorized access is allowed only if stride == 1"
);
});
static_if
<!
DstDesc
::
ContainMultipleOriginalDimensions
(
dst_vector_access_dim
)
>
{}(
[
&
](
auto
fwd
)
{
static_assert
(
(
fwd
(
DstDesc
{}).
GetStride
(
dst_vector_access_dim
)
==
1
||
DstDataPerAccess
==
1
),
"wrong! vectorized access is allowed only if stride == 1"
);
})
.
Else
([
&
](
auto
fwd
)
{
static_assert
(
(
fwd
(
DstDesc
{}).
GetLastOriginalDimensionStride
(
dst_vector_access_dim
)
==
1
||
DstDataPerAccess
==
1
),
"wrong! vectorized access is allowed only if stride == 1"
);
});
}
__device__
constexpr
ThreadwiseGenericTensorSliceCopy_v2r1
()
:
ThreadwiseGenericTensorSliceCopy_v2r1
(
make_zero_array
<
index_t
,
nDim
>
(),
make_zero_array
<
index_t
,
nDim
>
())
{
}
__device__
void
SetSrcSliceOrigin
(
SrcCoordinate
src_slice_origin
)
{
mSrcSliceOrigin
=
src_slice_origin
;
}
__device__
void
SetDstSliceOrigin
(
DstCoordinate
dst_slice_origin
)
{
mDstSliceOrigin
=
dst_slice_origin
;
}
template
<
class
TDesc
,
class
Lengths
>
struct
IsolateMergedDimLengths
{
template
<
class
IDim
>
__device__
constexpr
index_t
operator
()(
IDim
idim
)
const
{
return
TDesc
::
ContainMultipleOriginalDimensions
(
idim
)
?
Lengths
{}[
idim
]
:
1
;
}
};
template
<
class
SrcTData
,
class
DstTData
>
__device__
void
Run
(
const
SrcTData
*
p_src
,
DstTData
*
p_dst
)
const
{
constexpr
auto
buffer_desc
=
make_ConstantTensorDescriptor_packed
(
SliceLengths
{});
SrcTData
p_buffer_
[
buffer_desc
.
GetElementSpace
()];
SrcTData
*
p_buffer
=
p_buffer_
;
// copy data from src into buffer
{
using
src_vector_t
=
typename
vector_type
<
SrcTData
,
SrcDataPerAccess
>::
MemoryType
;
constexpr
auto
src_vector_access_dim
=
Number
<
SrcVectorAccessDim
>
{};
constexpr
auto
src_data_per_access
=
Number
<
SrcDataPerAccess
>
{};
constexpr
auto
src_access_lengths
=
SliceLengths
::
Modify
(
src_vector_access_dim
,
SliceLengths
::
Get
(
src_vector_access_dim
)
/
src_data_per_access
);
// Offset w.r.t merged dimensions need to be calculated at run-time. Offset w.r.t
// normal dimensions is known at compile time.
// Below is a hack to isolate merged dimension id from normal dimension id, so the
// corresponding offset can be calculated seperately at run-time and compile-time.
// src_merged_dim_access_lengths has the same value as src_access_lengths on src's
// merged dimensions, and has value = 1 on normal dimensions;
// src_merged_dim_access_lengths has the same value as src_access_lengths on src's
// normal dimensions, and has value = 1 on merged dimensions;
constexpr
auto
src_merged_dim_access_lengths
=
typename
sequence_gen
<
nDim
,
IsolateMergedDimLengths
<
SrcDesc
,
decltype
(
src_access_lengths
)
>>::
type
{};
constexpr
auto
src_normal_dim_access_lengths
=
src_access_lengths
+
Number
<
1
>
{}
-
src_merged_dim_access_lengths
;
#if CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V2R1
// offset w.r.t. merged dimension need to be computed at run-time
static_ford
<
decltype
(
src_merged_dim_access_lengths
),
SrcDimAccessOrder
>
{}([
&
](
auto
src_merged_dim_access_id_
)
{
constexpr
auto
src_merged_dim_access_id
=
decltype
(
src_merged_dim_access_id_
){};
constexpr
auto
src_merged_dim_data_id
=
src_merged_dim_access_id
.
Modify
(
src_vector_access_dim
,
src_merged_dim_access_id
[
src_vector_access_dim
]
*
src_data_per_access
);
const
SrcTData
*
p_src_tmp
=
p_src
+
(
mSrcSliceOrigin
+
src_merged_dim_data_id
).
GetOffset
();
// offset w.r.t. normal dimension can be computed at compile-time
static_ford
<
decltype
(
src_normal_dim_access_lengths
),
SrcDimAccessOrder
>
{}([
&
](
auto
src_normal_dim_access_id_
)
{
constexpr
auto
src_normal_dim_access_id
=
decltype
(
src_normal_dim_access_id_
){};
constexpr
auto
src_normal_dim_data_id
=
src_normal_dim_access_id
.
Modify
(
src_vector_access_dim
,
src_normal_dim_access_id
[
src_vector_access_dim
]
*
src_data_per_access
);
constexpr
index_t
src_normal_offset
=
SrcDesc
::
GetOffsetFromMultiIndex
(
src_normal_dim_data_id
);
// load vector from src
const
src_vector_t
vector_data
=
*
reinterpret_cast
<
const
src_vector_t
*>
(
&
p_src_tmp
[
src_normal_offset
]);
// unpack vector into buffer
static_for
<
0
,
SrcDataPerAccess
,
1
>
{}([
&
](
auto
i
)
{
constexpr
auto
scalar_id
=
typename
uniform_sequence_gen
<
nDim
,
0
>::
type
{}.
Modify
(
src_vector_access_dim
,
i
);
constexpr
index_t
buffer_offset
=
buffer_desc
.
GetOffsetFromMultiIndex
(
src_merged_dim_data_id
+
src_normal_dim_data_id
+
scalar_id
);
p_buffer
[
buffer_offset
]
=
reinterpret_cast
<
const
SrcTData
*>
(
&
vector_data
)[
i
];
});
});
});
#else
ford
<
decltype
(
src_merged_dim_access_lengths
),
SrcDimAccessOrder
>
{}(
[
&
](
auto
src_merged_dim_access_id
)
{
auto
src_merged_dim_data_id
=
src_merged_dim_access_id
;
src_merged_dim_data_id
(
src_vector_access_dim
)
=
src_merged_dim_access_id
[
src_vector_access_dim
]
*
src_data_per_access
;
const
SrcTData
*
p_src_tmp
=
p_src
+
(
mSrcSliceOrigin
+
src_merged_dim_data_id
).
GetOffset
();
// these should be compile-time known
ford
<
decltype
(
src_normal_dim_access_lengths
),
SrcDimAccessOrder
>
{}([
&
](
auto
src_normal_dim_access_id
)
{
auto
src_normal_dim_data_id
=
src_normal_dim_access_id
;
src_normal_dim_data_id
(
src_vector_access_dim
)
=
src_normal_dim_access_id
[
src_vector_access_dim
]
*
src_data_per_access
;
const
index_t
src_normal_offset
=
SrcDesc
::
GetOffsetFromMultiIndex
(
src_normal_dim_data_id
);
// load vector from src
const
src_vector_t
vector_data
=
*
reinterpret_cast
<
const
src_vector_t
*>
(
&
p_src_tmp
[
src_normal_offset
]);
// unpack vector into buffer
for
(
index_t
i
=
0
;
i
<
SrcDataPerAccess
;
++
i
)
{
auto
scalar_id
=
make_zero_array
<
index_t
,
nDim
>
();
scalar_id
(
src_vector_access_dim
)
=
i
;
const
index_t
buffer_offset
=
buffer_desc
.
GetOffsetFromMultiIndex
(
src_merged_dim_data_id
+
src_normal_dim_data_id
+
scalar_id
);
p_buffer
[
buffer_offset
]
=
reinterpret_cast
<
const
SrcTData
*>
(
&
vector_data
)[
i
];
}
});
});
#endif
}
// copy data from buffer into dst
{
using
dst_vector_t
=
typename
vector_type
<
DstTData
,
DstDataPerAccess
>::
MemoryType
;
constexpr
auto
dst_vector_access_dim
=
Number
<
DstVectorAccessDim
>
{};
constexpr
auto
dst_data_per_access
=
Number
<
DstDataPerAccess
>
{};
constexpr
auto
dst_access_lengths
=
SliceLengths
::
Modify
(
dst_vector_access_dim
,
SliceLengths
::
Get
(
dst_vector_access_dim
)
/
                dst_data_per_access);

        constexpr auto dst_merged_dim_access_lengths = typename sequence_gen<
            nDim, IsolateMergedDimLengths<DstDesc, decltype(dst_access_lengths)>>::type{};

        constexpr auto dst_normal_dim_access_lengths =
            dst_access_lengths + Number<1>{} - dst_merged_dim_access_lengths;

#if CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V2R1
        // offset w.r.t. merged dimension need to be computed at run-time
        static_ford<decltype(dst_merged_dim_access_lengths), DstDimAccessOrder>{}(
            [&](auto dst_merged_dim_access_id_) {
                constexpr auto dst_merged_dim_access_id = decltype(dst_merged_dim_access_id_){};

                constexpr auto dst_merged_dim_data_id = dst_merged_dim_access_id.Modify(
                    dst_vector_access_dim,
                    dst_merged_dim_access_id[dst_vector_access_dim] * dst_data_per_access);

                DstTData* p_dst_tmp =
                    p_dst + (mDstSliceOrigin + dst_merged_dim_data_id).GetOffset();

                // offset w.r.t. normal dimension can be computed at compile-time
                static_ford<decltype(dst_normal_dim_access_lengths), DstDimAccessOrder>{}(
                    [&](auto dst_normal_dim_access_id_) {
                        constexpr auto dst_normal_dim_access_id =
                            decltype(dst_normal_dim_access_id_){};

                        constexpr auto dst_normal_dim_data_id = dst_normal_dim_access_id.Modify(
                            dst_vector_access_dim,
                            dst_normal_dim_access_id[dst_vector_access_dim] * dst_data_per_access);

                        dst_vector_t vector_data{};

                        // pack vector from buffer
                        static_for<0, DstDataPerAccess, 1>{}([&](auto i) {
                            constexpr auto scalar_id =
                                typename uniform_sequence_gen<nDim, 0>::type{}.Modify(
                                    dst_vector_access_dim, i);

                            constexpr index_t buffer_offset = buffer_desc.GetOffsetFromMultiIndex(
                                dst_merged_dim_data_id + dst_normal_dim_data_id + scalar_id);

                            reinterpret_cast<DstTData*>(&vector_data)[i] =
                                type_convert<DstTData>{}(p_buffer[buffer_offset]);
                        });

                        constexpr index_t dst_normal_offset =
                            DstDesc::GetOffsetFromMultiIndex(dst_normal_dim_data_id);

                        // write vector into dst
                        *reinterpret_cast<dst_vector_t*>(&p_dst_tmp[dst_normal_offset]) =
                            vector_data;
                    });
            });
#else
        // offset w.r.t. merged dimension need to be computed at run-time
        ford<decltype(dst_merged_dim_access_lengths), DstDimAccessOrder>{}(
            [&](auto dst_merged_dim_access_id) {
                auto dst_merged_dim_data_id = dst_merged_dim_access_id;
                dst_merged_dim_data_id(dst_vector_access_dim) =
                    dst_merged_dim_access_id[dst_vector_access_dim] * dst_data_per_access;

                DstTData* p_dst_tmp =
                    p_dst + (mDstSliceOrigin + dst_merged_dim_data_id).GetOffset();

                // offset w.r.t. normal dimension can be computed at compile-time
                ford<decltype(dst_normal_dim_access_lengths), DstDimAccessOrder>{}(
                    [&](auto dst_normal_dim_access_id) {
                        auto dst_normal_dim_data_id = dst_normal_dim_access_id;
                        dst_normal_dim_data_id(dst_vector_access_dim) =
                            dst_normal_dim_access_id[dst_vector_access_dim] * dst_data_per_access;

                        dst_vector_t vector_data{};

                        // pack vector from buffer
                        for(index_t i = 0; i < DstDataPerAccess; ++i)
                        {
                            auto scalar_id                   = make_zero_array<index_t, nDim>();
                            scalar_id(dst_vector_access_dim) = i;

                            const index_t buffer_offset = buffer_desc.GetOffsetFromMultiIndex(
                                dst_merged_dim_data_id + dst_normal_dim_data_id + scalar_id);

                            reinterpret_cast<DstTData*>(&vector_data)[i] =
                                type_convert<DstTData>{}(p_buffer[buffer_offset]);
                        }

                        const index_t dst_normal_offset =
                            DstDesc::GetOffsetFromMultiIndex(dst_normal_dim_data_id);

                        // write vector into dst
                        *reinterpret_cast<dst_vector_t*>(&p_dst_tmp[dst_normal_offset]) =
                            vector_data;
                    });
            });
#endif
        }
    }

    // T can be Sequence or Array
    template <class T, bool PositiveDirection>
    __device__ void MoveSrcSlicingWindow(T step_sizes, integral_constant<bool, PositiveDirection>)
    {
        static_if<PositiveDirection>{}([&](auto) { mSrcSliceOrigin += step_sizes; })
            .Else([&](auto) { mSrcSliceOrigin -= step_sizes; });
    }

    template <class T, bool PositiveDirection>
    __device__ void MoveDstSlicingWindow(T step_sizes, integral_constant<bool, PositiveDirection>)
    {
        static_if<PositiveDirection>{}([&](auto) { mDstSliceOrigin += step_sizes; })
            .Else([&](auto) { mDstSliceOrigin -= step_sizes; });
    }

    private:
    SrcCoordinate mSrcSliceOrigin;
    DstCoordinate mDstSliceOrigin;
};

} // namespace ck
#endif
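Editor's note: the slicing-window movers above reduce to an elementwise add or subtract of a step multi-index onto the stored slice origin; the compile-time static_if only selects the sign, and the real class applies the same step to its SrcCoordinate/DstCoordinate members. A minimal host-side sketch of that behaviour, with plain std::array standing in for the coordinate type (names chosen here for illustration only):

#include <array>
#include <cstddef>
#include <iostream>

// Simplified analogue of MoveSrcSlicingWindow/MoveDstSlicingWindow: the slice origin
// is a multi-index, and a "move" is an elementwise += or -= of a step multi-index.
template <std::size_t N>
std::array<std::size_t, N>& move_window(std::array<std::size_t, N>& origin,
                                        const std::array<std::size_t, N>& step,
                                        bool positive_direction)
{
    for(std::size_t d = 0; d < N; ++d)
        origin[d] = positive_direction ? origin[d] + step[d] : origin[d] - step[d];
    return origin;
}

int main()
{
    std::array<std::size_t, 4> origin{0, 0, 0, 0}; // e.g. an {N, C, H, W} slice origin
    std::array<std::size_t, 4> step{0, 8, 0, 0};   // advance 8 channels per iteration

    for(int iter = 0; iter < 3; ++iter)
        move_window(origin, step, /*positive_direction=*/true);

    std::cout << "origin after 3 moves: " << origin[0] << ' ' << origin[1] << ' '
              << origin[2] << ' ' << origin[3] << '\n'; // prints 0 24 0 0
}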
composable_kernel/include/utility/Array.hpp
...
@@ -9,7 +9,8 @@ namespace ck {

 template <class TData, index_t NSize>
 struct Array
 {
     using Type = Array<TData, NSize>;
+    using data_type = TData;

     static constexpr index_t nSize = NSize;
@@ -20,7 +21,7 @@ struct Array
     {
     }

-    __host__ __device__ constexpr index_t GetSize() const { return NSize; }
+    __host__ __device__ static constexpr index_t GetSize() { return NSize; }

     template <index_t I>
     __host__ __device__ constexpr TData operator[](Number<I>) const
@@ -208,6 +209,21 @@ __host__ __device__ constexpr auto operator-(Array<TData, NSize> a, Array<TData,
     return result;
 }

+// Array += Array
+template <class TData, index_t NSize>
+__host__ __device__ constexpr auto operator+=(Array<TData, NSize>& a, Array<TData, NSize> b)
+{
+    a = a + b;
+    return a;
+}
+
+// Array -= Array
+template <class TData, index_t NSize>
+__host__ __device__ constexpr auto operator-=(Array<TData, NSize>& a, Array<TData, NSize> b)
+{
+    a = a - b;
+    return a;
+}
+
 // Array = Array + Sequence
 template <class TData, index_t NSize, index_t... Is>
 __host__ __device__ constexpr auto operator+(Array<TData, NSize> a, Sequence<Is...> b)
...
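Editor's note: besides the new += / -= overloads, this hunk turns GetSize() from a const member into a static constexpr function, which lets the size be queried from the type alone inside constant expressions. A standalone illustration with a simplified stand-in type (not ck::Array itself):

#include <cstddef>

// Why a static constexpr GetSize() is convenient: it can be used without an
// instance, e.g. to size another array at compile time.
template <class T, std::size_t N>
struct MyArray
{
    T data[N];
    static constexpr std::size_t GetSize() { return N; }
};

// Usable directly on the type, no object required.
static_assert(MyArray<float, 8>::GetSize() == 8, "size known at compile time");

int main()
{
    float buffer[MyArray<float, 8>::GetSize()] = {}; // sized from the type alone
    (void)buffer;
}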
composable_kernel/include/utility/Sequence.hpp
...
@@ -6,41 +6,63 @@
 namespace ck {

-template <class Seq>
-struct is_valid_sequence_map;
+template <index_t...>
+struct Sequence;
+
+template <class Seq, index_t I>
+struct sequence_split;
+
+template <class>
+struct sequence_reverse;
+
+template <class>
+struct sequence_map_inverse;
+
+template <class>
+struct is_valid_sequence_map;
+
+template <index_t I, index_t... Is>
+__host__ __device__ constexpr auto sequence_pop_front(Sequence<I, Is...>);
+
+template <class Seq>
+__host__ __device__ constexpr auto sequence_pop_back(Seq);

 template <index_t... Is>
 struct Sequence
 {
     using Type = Sequence;
+    using data_type = index_t;

     static constexpr index_t mSize = sizeof...(Is);

-    __host__ __device__ static constexpr index_t GetSize() { return mSize; }
+    __host__ __device__ static constexpr auto GetSize() { return Number<mSize>{}; }

-    template <index_t I>
-    __host__ __device__ static constexpr index_t Get(Number<I>)
+    __host__ __device__ static constexpr index_t GetImpl(index_t I)
     {
-        static_assert(I < mSize, "wrong! I too large");
-
         // the last dummy element is to prevent compiler complain about empty array, when mSize = 0
         const index_t mData[mSize + 1] = {Is..., 0};
         return mData[I];
     }

     template <index_t I>
-    __host__ __device__ constexpr auto operator[](Number<I>) const
+    __host__ __device__ static constexpr auto Get(Number<I>)
     {
-        return Number<Get(Number<I>{})>{};
+        static_assert(I < mSize, "wrong! I too large");
+
+        return Number<GetImpl(Number<I>{})>{};
     }

-    __host__ __device__ constexpr index_t operator[](index_t I) const
+    // make sure I is constexpr
+    __host__ __device__ static constexpr auto Get(index_t I) { return GetImpl(I); }
+
+    template <index_t I>
+    __host__ __device__ constexpr auto operator[](Number<I>) const
     {
-        const index_t mData[mSize + 1] = {Is..., 0};
-        return mData[I];
+        return Get(Number<I>{});
     }

+    // make sure I is constexpr if you want a constexpr return type
+    __host__ __device__ constexpr index_t operator[](index_t I) const { return GetImpl(I); }
+
     template <index_t... IRs>
     __host__ __device__ static constexpr auto ReorderGivenNew2Old(Sequence<IRs...> /*new2old*/)
     {
...
@@ -52,23 +74,38 @@ struct Sequence
         return Sequence<Type::Get(Number<IRs>{})...>{};
     }

-    __host__ __device__ static constexpr auto Reverse();
+    // MapOld2New is Sequence<...>
+    template <class MapOld2New>
+    __host__ __device__ static constexpr auto ReorderGivenOld2New(MapOld2New)
+    {
+        static_assert(MapOld2New::GetSize() == GetSize(),
+                      "wrong! reorder map should have the same size as Sequence to be reordered");
+
+        static_assert(is_valid_sequence_map<MapOld2New>::value, "wrong! invalid reorder map");
+
+        return ReorderGivenNew2Old(typename sequence_map_inverse<MapOld2New>::type{});
+    }

-    __host__ __device__ static constexpr index_t Front()
+    __host__ __device__ static constexpr auto Reverse()
     {
-        const index_t mData[mSize + 1] = {Is..., 0};
-        return mData[0];
+        return typename sequence_reverse<Type>::type{};
     }

-    __host__ __device__ static constexpr index_t Back()
+    __host__ __device__ static constexpr auto Front()
     {
-        const index_t mData[mSize + 1] = {Is..., 0};
-        return mData[mSize - 1];
+        static_assert(mSize > 0, "wrong!");
+        return Get(Number<0>{});
     }

-    __host__ __device__ static constexpr auto PopFront();
+    __host__ __device__ static constexpr auto Back()
+    {
+        static_assert(mSize > 0, "wrong!");
+        return Get(Number<mSize - 1>{});
+    }
+
+    __host__ __device__ static constexpr auto PopFront() { return sequence_pop_front(Type{}); }

-    __host__ __device__ static constexpr auto PopBack();
+    __host__ __device__ static constexpr auto PopBack() { return sequence_pop_back(Type{}); }

     template <index_t... Xs>
     __host__ __device__ static constexpr auto PushFront(Sequence<Xs...>)
...
@@ -107,7 +144,16 @@ struct Sequence
     }

     template <index_t I, index_t X>
-    __host__ __device__ static constexpr auto Modify(Number<I>, Number<X>);
+    __host__ __device__ static constexpr auto Modify(Number<I>, Number<X>)
+    {
+        static_assert(I < GetSize(), "wrong!");
+
+        using seq_split = sequence_split<Type, I>;
+        constexpr auto seq_left  = typename seq_split::SeqType0{};
+        constexpr auto seq_right = typename seq_split::SeqType1{}.PopFront();
+
+        return seq_left.PushBack(Number<X>{}).PushBack(seq_right);
+    }

     template <class F>
     __host__ __device__ static constexpr auto Transform(F f)
...
@@ -126,48 +172,63 @@ struct sequence_merge<Sequence<Xs...>, Sequence<Ys...>>
     using type = Sequence<Xs..., Ys...>;
 };

-// arithmetic sequence
-template <index_t IBegin, index_t NSize, index_t Increment>
-struct arithmetic_sequence_gen_impl
+// generate sequence
+template <index_t IBegin, index_t NRemain, class F>
+struct sequence_gen_impl
 {
-    static constexpr index_t NSizeLeft = NSize / 2;
+    static constexpr index_t NRemainLeft  = NRemain / 2;
+    static constexpr index_t NRemainRight = NRemain - NRemainLeft;
+    static constexpr index_t IMiddle      = IBegin + NRemainLeft;

     using type = typename sequence_merge<
-        typename arithmetic_sequence_gen_impl<IBegin, NSizeLeft, Increment>::type,
-        typename arithmetic_sequence_gen_impl<IBegin + NSizeLeft * Increment,
-                                              NSize - NSizeLeft,
-                                              Increment>::type>::type;
+        typename sequence_gen_impl<IBegin, NRemainLeft, F>::type,
+        typename sequence_gen_impl<IMiddle, NRemainRight, F>::type>::type;
 };

-template <index_t IBegin, index_t Increment>
-struct arithmetic_sequence_gen_impl<IBegin, 1, Increment>
+template <index_t I, class F>
+struct sequence_gen_impl<I, 1, F>
 {
-    using type = Sequence<IBegin>;
+    static constexpr index_t Is = F{}(Number<I>{});
+    using type                  = Sequence<Is>;
 };

-template <index_t IBegin, index_t Increment>
-struct arithmetic_sequence_gen_impl<IBegin, 0, Increment>
+template <index_t I, class F>
+struct sequence_gen_impl<I, 0, F>
 {
     using type = Sequence<>;
 };

+template <index_t NSize, class F>
+struct sequence_gen
+{
+    using type = typename sequence_gen_impl<0, NSize, F>::type;
+};
+
+// arithmetic sequence
 template <index_t IBegin, index_t IEnd, index_t Increment>
 struct arithmetic_sequence_gen
 {
-    using type =
-        typename arithmetic_sequence_gen_impl<IBegin, IEnd - IBegin, Increment>::type;
+    struct F
+    {
+        __host__ __device__ constexpr index_t operator()(index_t i) const
+        {
+            return i * Increment + IBegin;
+        }
+    };
+
+    using type = typename sequence_gen<(IEnd - IBegin) / Increment, F>::type;
 };
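Editor's note: arithmetic_sequence_gen is now just sequence_gen fed with a small functor, so any index-to-value rule can produce a compile-time sequence. A standalone C++14 sketch of the same idea built on std::integer_sequence (names here are illustrative, not the CK API):

#include <type_traits>
#include <utility>

// Build std::integer_sequence<int, F{}(0), F{}(1), ..., F{}(N-1)> from a stateless
// constexpr functor F, mirroring how ck::sequence_gen builds a Sequence.
template <class F, std::size_t... Is>
constexpr auto generate_impl(std::index_sequence<Is...>)
{
    return std::integer_sequence<int, F{}(Is)...>{};
}

template <std::size_t N, class F>
constexpr auto generate()
{
    return generate_impl<F>(std::make_index_sequence<N>{});
}

// The arithmetic rule i -> IBegin + i * Increment, like arithmetic_sequence_gen::F.
struct Arithmetic
{
    constexpr int operator()(std::size_t i) const { return 2 + static_cast<int>(i) * 3; }
};

static_assert(std::is_same<decltype(generate<4, Arithmetic>()),
                           std::integer_sequence<int, 2, 5, 8, 11>>::value,
              "IBegin = 2, Increment = 3, NSize = 4");

int main() {}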
 // uniform sequence
 template <index_t NSize, index_t I>
 struct uniform_sequence_gen
 {
-    struct return_constant
+    struct F
     {
         __host__ __device__ constexpr index_t operator()(index_t) const { return I; }
     };

-    using type = decltype(
-        typename arithmetic_sequence_gen<0, NSize, 1>::type{}.Transform(return_constant{}));
+    using type = typename sequence_gen<NSize, F>::type;
 };

 // reverse inclusive scan (with init) sequence
...
@@ -236,6 +297,7 @@ struct sequence_reverse<Sequence<I0, I1>>
 template <class Seq>
 struct is_valid_sequence_map
 {
+    // not implemented yet, always return true
     static constexpr integral_constant<bool, true> value = integral_constant<bool, true>{};

     // TODO: add proper check for is_valid, something like:
...
@@ -244,6 +306,34 @@ struct is_valid_sequence_map
 //     typename sequence_sort<Seq>::SortedSeqType>{};
 };

+template <class X2Y, class WorkingY2X, index_t XBegin, index_t XRemain>
+struct sequence_map_inverse_impl
+{
+    private:
+    static constexpr auto new_y2x =
+        WorkingY2X::Modify(X2Y::Get(Number<XBegin>{}), Number<XBegin>{});
+
+    public:
+    using type =
+        typename sequence_map_inverse_impl<X2Y, decltype(new_y2x), XBegin + 1, XRemain - 1>::type;
+};
+
+template <class X2Y, class WorkingY2X, index_t XBegin>
+struct sequence_map_inverse_impl<X2Y, WorkingY2X, XBegin, 0>
+{
+    using type = WorkingY2X;
+};
+
+template <class X2Y>
+struct sequence_map_inverse
+{
+    using type =
+        typename sequence_map_inverse_impl<X2Y,
+                                           typename uniform_sequence_gen<X2Y::GetSize(), 0>::type,
+                                           0,
+                                           X2Y::GetSize()>::type;
+};
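Editor's note: sequence_map_inverse turns an old-to-new dimension map into its inverse by scattering each position back through Modify. The run-time equivalent is a one-line scatter; a small standalone check (file-layout names are illustrative):

#include <array>
#include <cassert>
#include <cstddef>

// Invert a permutation map: if x2y[x] == y, then y2x[y] == x.
// This is what sequence_map_inverse computes at compile time on a Sequence.
template <std::size_t N>
std::array<std::size_t, N> invert_map(const std::array<std::size_t, N>& x2y)
{
    std::array<std::size_t, N> y2x{};
    for(std::size_t x = 0; x < N; ++x)
        y2x[x2y[x]] = x; // scatter: old position x lands at new position x2y[x]
    return y2x;
}

int main()
{
    // e.g. an NCHW -> NHWC reordering map and its inverse
    const std::array<std::size_t, 4> old2new{0, 3, 1, 2};
    const auto new2old = invert_map(old2new);
    assert((new2old == std::array<std::size_t, 4>{0, 2, 3, 1}));
}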
 template <index_t... Xs, index_t... Ys>
 __host__ __device__ constexpr auto operator+(Sequence<Xs...>, Sequence<Ys...>)
 {
...
@@ -355,8 +445,8 @@ __host__ __device__ constexpr auto sequence_pop_front(Sequence<I, Is...>)
 template <class Seq>
 __host__ __device__ constexpr auto sequence_pop_back(Seq)
 {
-    static_assert(Seq{}.GetSize() > 0, "wrong! cannot pop an empty Sequence!");
+    static_assert(Seq::GetSize() > 0, "wrong! cannot pop an empty Sequence!");

-    return sequence_pop_front(Seq{}.Reverse()).Reverse();
+    return sequence_pop_front(Seq::Reverse()).Reverse();
 }

 template <class F, index_t... Xs>
...
@@ -396,37 +486,6 @@ __host__ __device__ constexpr auto inclusive_scan_sequence(Seq, Reduce, Number<I
     return reverse_inclusive_scan_sequence(Seq{}.Reverse(), Reduce{}, Number<Init>{}).Reverse();
 }

-template <index_t... Is>
-__host__ __device__ constexpr auto Sequence<Is...>::PopFront()
-{
-    return sequence_pop_front(Type{});
-}
-
-template <index_t... Is>
-__host__ __device__ constexpr auto Sequence<Is...>::PopBack()
-{
-    return sequence_pop_back(Type{});
-}
-
-template <index_t... Is>
-__host__ __device__ constexpr auto Sequence<Is...>::Reverse()
-{
-    return typename sequence_reverse<Sequence<Is...>>::type{};
-}
-
-template <index_t... Is>
-template <index_t I, index_t X>
-__host__ __device__ constexpr auto Sequence<Is...>::Modify(Number<I>, Number<X>)
-{
-    static_assert(I < GetSize(), "wrong!");
-
-    using seq_split = sequence_split<Type, I>;
-    constexpr auto seq_left  = typename seq_split::SeqType0{};
-    constexpr auto seq_right = typename seq_split::SeqType1{}.PopFront();
-
-    return seq_left.PushBack(Number<X>{}).PushBack(seq_right);
-}
-
 template <index_t... Xs>
 __host__ __device__ void print_Sequence(const char* s, Sequence<Xs...>)
 {
...
composable_kernel/include/utility/amd_inline_asm.hpp
...
@@ -3,80 +3,111 @@
 #include "vector_type.hpp"

+#define WORKAROUND_SWDEV_202749 1
+
 namespace ck {

+#if !CK_USE_INLINE_ASM_XDLOPS
+// A, B, C, cbsz, abid, blgp
+extern "C" __device__ float32_t __llvm_amdgcn_mfma_f32_32x32x1f32(
+    float, float, float32_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x1f32");
+
+extern "C" __device__ float32_t __llvm_amdgcn_mfma_f32_32x32x4f16(
+    half4_t, half4_t, float32_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x4f16");
+
+extern "C" __device__ float32_t __llvm_amdgcn_mfma_f32_32x32x2bf16(
+    ushort2_t, ushort2_t, float32_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x2bf16");
+#endif
+
 // cast a pointer of LDS to its address
-extern "C" __attribute__((address_space(3))) __device__ void* __to_local(void* p);
+extern "C" __attribute__((address_space(3))) __device__ void* __to_local(const void* p);

-__device__ void vmcnt(index_t cnt)
-{
-    if(cnt == 0)
-    {
-        asm volatile("\n \
-        s_waitcnt vmcnt(0) \n \
-        " ::);
-    }
-    else if(cnt == 1)
-    {
-        asm volatile("\n \
-        s_waitcnt vmcnt(1) \n \
-        " ::);
-    }
-    else if(cnt == 2)
-    {
-        asm volatile("\n \
-        s_waitcnt vmcnt(2) \n \
-        " ::);
-    }
-    else if(cnt == 4)
-    {
-        asm volatile("\n \
-        s_waitcnt vmcnt(2) \n \
-        " ::);
-    }
-    else
-    {
-        assert(false);
-    }
-}
-
-__device__ void lgkmcnt(index_t cnt)
-{
-    if(cnt == 0)
-    {
-        asm volatile("\n \
-        s_waitcnt lgkmcnt(0) \n \
-        " ::);
-    }
-    else if(cnt == 1)
-    {
-        asm volatile("\n \
-        s_waitcnt lgkmcnt(1) \n \
-        " ::);
-    }
-    else if(cnt == 2)
-    {
-        asm volatile("\n \
-        s_waitcnt lgkmcnt(2) \n \
-        " ::);
-    }
-    else if(cnt == 3)
-    {
-        asm volatile("\n \
-        s_waitcnt lgkmcnt(3) \n \
-        " ::);
-    }
-    else if(cnt == 4)
-    {
-        asm volatile("\n \
-        s_waitcnt lgkmcnt(4) \n \
-        " ::);
-    }
-    else
-    {
-        assert(false);
-    }
-}
+// clang-format off
+#define REPEATx4(f, off) f(off) f(off + 1) f(off + 2) f(off + 3)
+
+#define REPEATx16(f, off) \
+    REPEATx4(f, off) REPEATx4(f, off + 4) REPEATx4(f, off + 8) REPEATx4(f, off + 12)
+
+#define REPEATx64(f, off) \
+    REPEATx16(f, off) REPEATx16(f, off + 16) REPEATx16(f, off + 32) REPEATx16(f, off + 48)
+
+#define REPEAT_STRIDEx4(f, stride, off) \
+    f(off) f(off + 1 * stride) f(off + 2 * stride) f(off + 3 * stride)
+
+#define REPEAT_STRIDEx16(f, stride, off) \
+    REPEAT_STRIDEx4(f, stride, off) REPEAT_STRIDEx4(f, stride, off + 1 * stride * 4) \
+    REPEAT_STRIDEx4(f, stride, off + 2 * stride * 4) \
+    REPEAT_STRIDEx4(f, stride, off + 3 * stride * 4)
+
+#define REPEAT_STRIDEx64(f, stride, off) \
+    REPEAT_STRIDEx16(f, stride, off) REPEAT_STRIDEx16(f, stride, off + 1 * stride * 16) \
+    REPEAT_STRIDEx16(f, stride, off + 2 * stride * 16) \
+    REPEAT_STRIDEx16(f, stride, off + 3 * stride * 16)
+
+#define NOP(n) asm volatile("\n s_nop " #n " " : :);
+
+#define DS_READ_B32(off) \
+    if(offset == off) \
+    { \
+        asm volatile("ds_read_b32 %0, %1 offset:" #off " " : "=v"(r) : "v"(__to_local(lds))); \
+    }
+
+#define DS_READ_B128(off) \
+    if(offset == off) \
+    { \
+        asm volatile("ds_read_b128 %0, %1 offset:" #off " " : "=v"(r) : "v"(__to_local(lds))); \
+    }
+
+#define DS_WRITE_B128(off) \
+    if(offset == off) \
+    { \
+        asm volatile("ds_write_b128 %0, %1 offset:" #off " " : : "v"(__to_local(lds)), "v"(r)); \
+    }
+
+#define MFMA_F32_32x32x1F32(acc, reg_a, reg_b, cbsz, abid, blgp) \
+    asm volatile("v_mfma_f32_32x32x1f32 a[" #acc ":" #acc "+31], %0, %1, a[" #acc ":" #acc \
+                 "+31] cbsz: " #cbsz " abid: " #abid " blgp:" #blgp " " \
+                 : \
+                 : "v"(reg_a), "v"(reg_b));
+
+#define MFMA_F32_32x32x4F16(acc, reg_a, reg_b, cbsz, abid, blgp) \
+    asm volatile("v_mfma_f32_32x32x4f16 a[" #acc ":" #acc "+31], %0, %1, a[" #acc ":" #acc \
+                 "+31] cbsz: " #cbsz " abid: " #abid " blgp:" #blgp " " \
+                 : \
+                 : "v"(reg_a), "v"(reg_b));
+
+#define MFMA_F32_32x32x2BF16(acc, reg_a, reg_b, cbsz, abid, blgp) \
+    asm volatile("v_mfma_f32_32x32x2bf16 a[" #acc ":" #acc "+31], %0, %1, a[" #acc ":" #acc \
+                 "+31] cbsz: " #cbsz " abid: " #abid " blgp:" #blgp " " \
+                 : \
+                 : "v"(reg_a), "v"(reg_b));
+
+#define ACCVGPR_READ(acc_reg_id) \
+    asm volatile("v_accvgpr_read_b32 %0, a[" #acc_reg_id "]" : "=v"(arch_reg[acc_reg_id]) :);
+
+#define ACCVGPR_WRITE(acc_reg_id) \
+    asm volatile("v_accvgpr_write_b32 a[" #acc_reg_id "], %0" : : "v"(arch_reg[acc_reg_id]));
+
+#define ACCVGPR_ZERO(acc_reg_id) \
+    asm volatile("v_accvgpr_write_b32 a[" #acc_reg_id "], 0" : :);
+
+#define S_WAIT_VMCNT(id) \
+    if(cnt == id) \
+    { \
+        asm volatile("s_waitcnt vmcnt(" #id ")" ::); \
+    }
+
+#define S_WAIT_LGKMCNT(id) \
+    if(cnt == id) \
+    { \
+        asm volatile("s_waitcnt lgkmcnt(" #id ")" ::); \
+    }
+
+__device__ void s_wait_vmcnt(index_t cnt) { REPEATx4(S_WAIT_VMCNT, 0) }
+
+__device__ void s_wait_lgkmcnt(index_t cnt) { REPEATx4(S_WAIT_LGKMCNT, 0) }

 __device__ void outerProduct1x4(const float* a, const float* b, float* c)
 {
...
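Editor's note: the new REPEATx* / REPEAT_STRIDEx* macros stamp a token f(off) out over a range of offsets so that the if(offset == off) bodies above do not have to be written by hand for every case. A host-only demonstration of the same expansion trick, using printf in place of inline asm (everything here is illustrative, not part of the header):

#include <cstdio>

// Same unrolling pattern as the kernel-side macros: REPEATx4 stamps f at four
// consecutive offsets, REPEAT_STRIDEx4 at four strided offsets.
#define REPEATx4(f, off) f(off) f(off + 1) f(off + 2) f(off + 3)
#define REPEAT_STRIDEx4(f, stride, off) \
    f(off) f(off + 1 * stride) f(off + 2 * stride) f(off + 3 * stride)

// Stand-in for DS_READ_B128: only the matching offset "fires".
#define PRINT_IF_MATCH(off)                           \
    if(offset == (off))                               \
    {                                                 \
        std::printf("would read offset %d\n", (off)); \
    }

void dispatch(int offset) { REPEAT_STRIDEx4(PRINT_IF_MATCH, 64, 0) }

int main()
{
    dispatch(128); // prints: would read offset 128
    return 0;
}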
@@ -98,6 +129,23 @@ __device__ void outerProduct1x4(const float* a, const float* b, float* c)
"3"
(
c
[
3
]));
"3"
(
c
[
3
]));
}
}
__device__
void
outerProduct1x2
(
const
float
*
a
,
const
float
*
b
,
float
*
c
)
{
// disable inline asm due to the compiler issue: SWDEV-202749
///\to-do: enable the inline asm after the compiler fix
#if WORKAROUND_SWDEV_202749
c
[
0
]
+=
a
[
0
]
*
b
[
0
];
c
[
1
]
+=
a
[
0
]
*
b
[
1
];
#else
asm
volatile
(
"
\n
\
v_mac_f32 %0, %2, %3
\n
\
v_mac_f32 %1, %2, %4
\n
\
"
:
"=v"
(
c
[
0
]),
"=v"
(
c
[
1
])
:
"v"
(
a
[
0
]),
"v"
(
b
[
0
]),
"v"
(
b
[
1
]),
"0"
(
c
[
0
]),
"1"
(
c
[
1
]));
#endif
}
__device__
void
outerProduct1x4
(
const
float
&
a
,
__device__
void
outerProduct1x4
(
const
float
&
a
,
const
vector_type
<
float
,
4
>::
MemoryType
&
b
,
const
vector_type
<
float
,
4
>::
MemoryType
&
b
,
vector_type
<
float
,
4
>::
MemoryType
&
c
)
vector_type
<
float
,
4
>::
MemoryType
&
c
)
...
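Editor's note: the WORKAROUND_SWDEV_202749 path spells out what the two v_mac_f32 instructions compute, a rank-1 update of a 1x2 tile. A host-side reference of that update, useful as a sanity check against the asm path (plain C++, not the device code):

#include <cassert>

// c[j] += a[0] * b[j] for j = 0,1 -- the scalar meaning of outerProduct1x2,
// i.e. one column of A times one row of B accumulated into C.
void outer_product_1x2_ref(const float* a, const float* b, float* c)
{
    c[0] += a[0] * b[0];
    c[1] += a[0] * b[1];
}

int main()
{
    float a[1] = {2.0f};
    float b[2] = {3.0f, 4.0f};
    float c[2] = {1.0f, 1.0f};

    outer_product_1x2_ref(a, b, c);

    assert(c[0] == 7.0f && c[1] == 9.0f);
}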
@@ -105,20 +153,14 @@ __device__ void outerProduct1x4(const float& a,
     outerProduct1x4(&a, reinterpret_cast<const float*>(&b), reinterpret_cast<float*>(&c));
 }

-__device__ void outerProduct4x4(const vector_type<float, 4>::MemoryType& a,
-                                const vector_type<float, 4>::MemoryType& b,
-                                vector_type<float, 4>::MemoryType& c0,
-                                vector_type<float, 4>::MemoryType& c1,
-                                vector_type<float, 4>::MemoryType& c2,
-                                vector_type<float, 4>::MemoryType& c3)
-{
-    outerProduct1x4(a.x, b, c0);
-    outerProduct1x4(a.y, b, c1);
-    outerProduct1x4(a.z, b, c2);
-    outerProduct1x4(a.w, b, c3);
-}
+__device__ void outerProduct1x2(const float& a,
+                                const vector_type<float, 2>::MemoryType& b,
+                                vector_type<float, 2>::MemoryType& c)
+{
+    outerProduct1x2(&a, reinterpret_cast<const float*>(&b), reinterpret_cast<float*>(&c));
+}

-__device__ void outerProduct1x4(const half2* a, const half2* b, float* c)
+__device__ void outerProduct1x4dot2TwoTimes(const half2* a, const half2* b, float* c)
 {
     asm volatile("\n \
             v_dot2_f32_f16 %0, %4, %6 %0 \n \
...
@@ -147,579 +189,240 @@ __device__ void outerProduct1x4(const half2* a, const half2* b, float* c)
"3"
(
c
[
3
]));
// 3rd Src Acc registers for 2 half2 registers
"3"
(
c
[
3
]));
// 3rd Src Acc registers for 2 half2 registers
}
}
__device__
void
outerProduct1x4Half
(
const
vector_type
<
half
,
4
>&
a
,
__device__
void
outerProduct1x4dot2
(
const
half2
*
a
,
const
half2
*
b
,
float
*
c
)
const
vector_type
<
vector_type
<
half
,
4
>
,
4
>&
b
,
vector_type
<
float
,
4
>::
MemoryType
&
c
)
{
{
outerProduct1x4
(
reinterpret_cast
<
const
half2
*>
(
&
a
),
asm
volatile
(
"
\n
\
reinterpret_cast
<
const
half2
*>
(
&
b
),
v_dot2_f32_f16 %0, %4, %5 %0
\n
\
reinterpret_cast
<
float
*>
(
&
c
));
v_dot2_f32_f16 %1, %4, %6 %1
\n
\
v_dot2_f32_f16 %2, %4, %7 %2
\n
\
v_dot2_f32_f16 %3, %4, %8 %3
\n
\
"
:
"=v"
(
c
[
0
]),
"=v"
(
c
[
1
]),
"=v"
(
c
[
2
]),
"=v"
(
c
[
3
])
// Dest registers
:
"v"
(
a
[
0
]),
// 1st Src register for 1 half2 registers
"v"
(
b
[
0
]),
// 2nd Src register
"v"
(
b
[
1
]),
"v"
(
b
[
2
]),
"v"
(
b
[
3
]),
"0"
(
c
[
0
]),
// 3rd Src register
"1"
(
c
[
1
]),
"2"
(
c
[
2
]),
"3"
(
c
[
3
]));
}
}
__device__
void
outerProduct4x4
(
const
vector_type
<
vector_type
<
half
,
4
>
,
4
>&
a
,
__device__
void
outerProduct1x2dot2TwoTimes
(
const
half2
*
a
,
const
half2
*
b
,
float
*
c
)
const
vector_type
<
vector_type
<
half
,
4
>
,
4
>&
b
,
vector_type
<
float
,
4
>::
MemoryType
&
c0
,
vector_type
<
float
,
4
>::
MemoryType
&
c1
,
vector_type
<
float
,
4
>::
MemoryType
&
c2
,
vector_type
<
float
,
4
>::
MemoryType
&
c3
)
{
{
const
vector_type
<
half
,
4
>*
reg_a
=
reinterpret_cast
<
const
vector_type
<
half
,
4
>*>
(
&
a
);
asm
volatile
(
"
\n
\
outerProduct1x4Half
(
reg_a
[
0
],
b
,
c0
);
v_dot2_f32_f16 %0, %2, %4 %0
\n
\
outerProduct1x4Half
(
reg_a
[
1
],
b
,
c1
);
v_dot2_f32_f16 %1, %2, %6 %1
\n
\
outerProduct1x4Half
(
reg_a
[
2
],
b
,
c2
);
v_dot2_f32_f16 %0, %3, %5 %0
\n
\
outerProduct1x4Half
(
reg_a
[
3
],
b
,
c3
);
v_dot2_f32_f16 %1, %3, %7 %1
\n
\
"
:
"=v"
(
c
[
0
]),
"=v"
(
c
[
1
])
// Dest registers
:
"v"
(
a
[
0
]),
"v"
(
a
[
1
]),
// 1st Src registers for 2 half2 registers
"v"
(
b
[
0
]),
"v"
(
b
[
1
]),
"v"
(
b
[
2
]),
"v"
(
b
[
3
]),
// 2nd Src registers for 2 half2 registers
"0"
(
c
[
0
]),
"1"
(
c
[
1
]));
// 3rd Src Acc registers for 2 half2 registers
}
}
__device__
void
outerProduct8x8
(
const
vector_type
<
float
,
4
>::
MemoryType
*
a
,
__device__
void
outerProduct1x2dot2
(
const
half2
*
a
,
const
half2
*
b
,
float
*
c
)
const
vector_type
<
float
,
4
>::
MemoryType
*
b
,
vector_type
<
float
,
4
>::
MemoryType
*
c
)
{
{
outerProduct4x4
(
a
[
0
],
b
[
0
],
c
[
0
],
c
[
2
],
c
[
4
],
c
[
6
]);
asm
volatile
(
"
\n
\
outerProduct4x4
(
a
[
0
],
b
[
1
],
c
[
1
],
c
[
3
],
c
[
5
],
c
[
7
]);
v_dot2_f32_f16 %0, %2, %3 %0
\n
\
outerProduct4x4
(
a
[
1
],
b
[
0
],
c
[
8
],
c
[
10
],
c
[
12
],
c
[
14
]);
v_dot2_f32_f16 %1, %2, %4 %1
\n
\
outerProduct4x4
(
a
[
1
],
b
[
1
],
c
[
9
],
c
[
11
],
c
[
13
],
c
[
15
]);
"
:
"=v"
(
c
[
0
]),
"=v"
(
c
[
1
])
// Dest registers
:
"v"
(
a
[
0
]),
// 1st Src register for 1 half2 registers
"v"
(
b
[
0
]),
// 2nd Src register
"v"
(
b
[
1
]),
"0"
(
c
[
0
]),
// 3rd Src register
"1"
(
c
[
1
]));
}
}
__device__
void
ds_read_b128
(
vector_type
<
float
,
4
>::
MemoryType
&
r
,
void
*
lds
,
index_t
offset
=
0
)
__device__
void
ds_read_b32
(
float
&
r
,
const
void
*
lds
,
index_t
offset
=
0
)
{
DS_READ_B32
(
0
)
}
__device__
void
ds_read_b128
(
vector_type
<
float
,
4
>::
MemoryType
&
r
,
const
void
*
lds
,
index_t
offset
=
0
)
{
{
if
(
offset
==
0
)
REPEAT_STRIDEx64
(
DS_READ_B128
,
64
,
0
)
{
asm
volatile
(
"
\n
\
ds_read_b128 %0, %1 offset:0
\n
\
"
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
)));
}
if
(
offset
==
64
)
{
asm
volatile
(
"
\n
\
ds_read_b128 %0, %1 offset:64
\n
\
"
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
)));
}
if
(
offset
==
128
)
{
asm
volatile
(
"
\n
\
ds_read_b128 %0, %1 offset:128
\n
\
"
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
)));
}
if
(
offset
==
192
)
{
asm
volatile
(
"
\n
\
ds_read_b128 %0, %1 offset:192
\n
\
"
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
)));
}
if
(
offset
==
256
)
{
asm
volatile
(
"
\n
\
ds_read_b128 %0, %1 offset:256
\n
\
"
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
)));
}
if
(
offset
==
320
)
{
asm
volatile
(
"
\n
\
ds_read_b128 %0, %1 offset:320
\n
\
"
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
)));
}
if
(
offset
==
384
)
{
asm
volatile
(
"
\n
\
ds_read_b128 %0, %1 offset:384
\n
\
"
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
)));
}
if
(
offset
==
448
)
{
asm
volatile
(
"
\n
\
ds_read_b128 %0, %1 offset:448
\n
\
"
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
)));
}
if
(
offset
==
512
)
{
asm
volatile
(
"
\n
\
ds_read_b128 %0, %1 offset:512
\n
\
"
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
)));
}
if
(
offset
==
576
)
{
asm
volatile
(
"
\n
\
ds_read_b128 %0, %1 offset:576
\n
\
"
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
)));
}
if
(
offset
==
640
)
{
asm
volatile
(
"
\n
\
ds_read_b128 %0, %1 offset:640
\n
\
"
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
)));
}
if
(
offset
==
704
)
{
asm
volatile
(
"
\n
\
ds_read_b128 %0, %1 offset:704
\n
\
"
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
)));
}
if
(
offset
==
768
)
{
asm
volatile
(
"
\n
\
ds_read_b128 %0, %1 offset:768
\n
\
"
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
)));
}
if
(
offset
==
832
)
{
asm
volatile
(
"
\n
\
ds_read_b128 %0, %1 offset:832
\n
\
"
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
)));
}
if
(
offset
==
896
)
{
asm
volatile
(
"
\n
\
ds_read_b128 %0, %1 offset:896
\n
\
"
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
)));
}
if
(
offset
==
960
)
{
asm
volatile
(
"
\n
\
ds_read_b128 %0, %1 offset:960
\n
\
"
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
)));
}
if
(
offset
==
1024
)
{
asm
volatile
(
"
\n
\
ds_read_b128 %0, %1 offset:1024
\n
\
"
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
)));
}
if
(
offset
==
1088
)
{
asm
volatile
(
"
\n
\
ds_read_b128 %0, %1 offset:1088
\n
\
"
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
)));
}
if
(
offset
==
1152
)
{
asm
volatile
(
"
\n
\
ds_read_b128 %0, %1 offset:1152
\n
\
"
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
)));
}
if
(
offset
==
1216
)
{
asm
volatile
(
"
\n
\
ds_read_b128 %0, %1 offset:1216
\n
\
"
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
)));
}
if
(
offset
==
1280
)
{
asm
volatile
(
"
\n
\
ds_read_b128 %0, %1 offset:1280
\n
\
"
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
)));
}
if
(
offset
==
1344
)
{
asm
volatile
(
"
\n
\
ds_read_b128 %0, %1 offset:1344
\n
\
"
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
)));
}
if
(
offset
==
1408
)
{
asm
volatile
(
"
\n
\
ds_read_b128 %0, %1 offset:1408
\n
\
"
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
)));
}
if
(
offset
==
1472
)
{
asm
volatile
(
"
\n
\
ds_read_b128 %0, %1 offset:1472
\n
\
"
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
)));
}
if
(
offset
==
1536
)
{
asm
volatile
(
"
\n
\
ds_read_b128 %0, %1 offset:1536
\n
\
"
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
)));
}
if
(
offset
==
1600
)
{
asm
volatile
(
"
\n
\
ds_read_b128 %0, %1 offset:1600
\n
\
"
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
)));
}
if
(
offset
==
1664
)
{
asm
volatile
(
"
\n
\
ds_read_b128 %0, %1 offset:1664
\n
\
"
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
)));
}
if
(
offset
==
1728
)
{
asm
volatile
(
"
\n
\
ds_read_b128 %0, %1 offset:1728
\n
\
"
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
)));
}
if
(
offset
==
1792
)
{
asm
volatile
(
"
\n
\
ds_read_b128 %0, %1 offset:1792
\n
\
"
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
)));
}
if
(
offset
==
1856
)
{
asm
volatile
(
"
\n
\
ds_read_b128 %0, %1 offset:1856
\n
\
"
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
)));
}
if
(
offset
==
1920
)
{
asm
volatile
(
"
\n
\
ds_read_b128 %0, %1 offset:1920
\n
\
"
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
)));
}
if
(
offset
==
1984
)
{
asm
volatile
(
"
\n
\
ds_read_b128 %0, %1 offset:1984
\n
\
"
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
)));
}
if
(
offset
==
2048
)
{
asm
volatile
(
"
\n
\
ds_read_b128 %0, %1 offset:2048
\n
\
"
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
)));
}
if
(
offset
==
2112
)
{
asm
volatile
(
"
\n
\
ds_read_b128 %0, %1 offset:2112
\n
\
"
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
)));
}
if
(
offset
==
2176
)
{
asm
volatile
(
"
\n
\
ds_read_b128 %0, %1 offset:2176
\n
\
"
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
)));
}
if
(
offset
==
2240
)
{
asm
volatile
(
"
\n
\
ds_read_b128 %0, %1 offset:2240
\n
\
"
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
)));
}
if
(
offset
==
2304
)
{
asm
volatile
(
"
\n
\
ds_read_b128 %0, %1 offset:2304
\n
\
"
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
)));
}
if
(
offset
==
2368
)
{
asm
volatile
(
"
\n
\
ds_read_b128 %0, %1 offset:2368
\n
\
"
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
)));
}
if
(
offset
==
2432
)
{
asm
volatile
(
"
\n
\
ds_read_b128 %0, %1 offset:2432
\n
\
"
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
)));
}
if
(
offset
==
2496
)
{
asm
volatile
(
"
\n
\
ds_read_b128 %0, %1 offset:2496
\n
\
"
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
)));
}
if
(
offset
==
2560
)
{
asm
volatile
(
"
\n
\
ds_read_b128 %0, %1 offset:2560
\n
\
"
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
)));
}
if
(
offset
==
2624
)
{
asm
volatile
(
"
\n
\
ds_read_b128 %0, %1 offset:2624
\n
\
"
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
)));
}
if
(
offset
==
2688
)
{
asm
volatile
(
"
\n
\
ds_read_b128 %0, %1 offset:2688
\n
\
"
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
)));
}
if
(
offset
==
2752
)
{
asm
volatile
(
"
\n
\
ds_read_b128 %0, %1 offset:2752
\n
\
"
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
)));
}
if
(
offset
==
2816
)
{
asm
volatile
(
"
\n
\
ds_read_b128 %0, %1 offset:2816
\n
\
"
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
)));
}
if
(
offset
==
2880
)
{
asm
volatile
(
"
\n
\
ds_read_b128 %0, %1 offset:2880
\n
\
"
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
)));
}
if
(
offset
==
2944
)
{
asm
volatile
(
"
\n
\
ds_read_b128 %0, %1 offset:2944
\n
\
"
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
)));
}
if
(
offset
==
3008
)
{
asm
volatile
(
"
\n
\
ds_read_b128 %0, %1 offset:3008
\n
\
"
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
)));
}
if
(
offset
==
3072
)
{
asm
volatile
(
"
\n
\
ds_read_b128 %0, %1 offset:3072
\n
\
"
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
)));
}
if
(
offset
==
3136
)
{
asm
volatile
(
"
\n
\
ds_read_b128 %0, %1 offset:3136
\n
\
"
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
)));
}
if
(
offset
==
3200
)
{
asm
volatile
(
"
\n
\
ds_read_b128 %0, %1 offset:3200
\n
\
"
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
)));
}
if
(
offset
==
3264
)
{
asm
volatile
(
"
\n
\
ds_read_b128 %0, %1 offset:3264
\n
\
"
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
)));
}
if
(
offset
==
3328
)
{
asm
volatile
(
"
\n
\
ds_read_b128 %0, %1 offset:3328
\n
\
"
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
)));
}
if
(
offset
==
3392
)
{
asm
volatile
(
"
\n
\
ds_read_b128 %0, %1 offset:3392
\n
\
"
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
)));
}
if
(
offset
==
3456
)
{
asm
volatile
(
"
\n
\
ds_read_b128 %0, %1 offset:3456
\n
\
"
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
)));
}
if
(
offset
==
3520
)
{
asm
volatile
(
"
\n
\
ds_read_b128 %0, %1 offset:3520
\n
\
"
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
)));
}
if
(
offset
==
3584
)
{
asm
volatile
(
"
\n
\
ds_read_b128 %0, %1 offset:3584
\n
\
"
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
)));
}
if
(
offset
==
3648
)
{
asm
volatile
(
"
\n
\
ds_read_b128 %0, %1 offset:3648
\n
\
"
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
)));
}
if
(
offset
==
3712
)
{
asm
volatile
(
"
\n
\
ds_read_b128 %0, %1 offset:3712
\n
\
"
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
)));
}
if
(
offset
==
3776
)
{
asm
volatile
(
"
\n
\
ds_read_b128 %0, %1 offset:3776
\n
\
"
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
)));
}
if
(
offset
==
3840
)
{
asm
volatile
(
"
\n
\
ds_read_b128 %0, %1 offset:3840
\n
\
"
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
)));
}
if
(
offset
==
3904
)
{
asm
volatile
(
"
\n
\
ds_read_b128 %0, %1 offset:3904
\n
\
"
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
)));
}
if
(
offset
==
3968
)
{
asm
volatile
(
"
\n
\
ds_read_b128 %0, %1 offset:3968
\n
\
"
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
)));
}
if
(
offset
==
4032
)
{
asm
volatile
(
"
\n
\
ds_read_b128 %0, %1 offset:4032
\n
\
"
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
)));
}
if
(
offset
==
4096
)
{
asm
volatile
(
"
\n
\
ds_read_b128 %0, %1 offset:4096
\n
\
"
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
)));
}
}
}
__device__
void
__device__
void
ds_write_b128
(
const
vector_type
<
float
,
4
>::
MemoryType
&
r
,
void
*
lds
,
index_t
offset
=
0
)
ds_write_b128
(
const
vector_type
<
float
,
4
>::
MemoryType
&
r
,
const
void
*
lds
,
index_t
offset
=
0
)
{
{
if
(
offset
==
0
)
REPEAT_STRIDEx64
(
DS_WRITE_B128
,
64
,
0
)
{
}
asm
volatile
(
"
\n
\
ds_write_b128 %0, %1
\n
\
template
<
index_t
Size
>
"
__device__
void
gcnasm_accvgpr_read
(
float
*
)
:
{
:
"v"
(
__to_local
(
lds
)),
"v"
(
r
));
}
}
else
template
<
>
{
__device__
void
gcnasm_accvgpr_read
<
16
>
(
float
*
arch_reg
)
assert
(
false
);
{
}
#if CK_USE_INLINE_ASM_XDLOPS
NOP
(
16
)
REPEATx16
(
ACCVGPR_READ
,
0
)
#else
(
void
)
arch_reg
;
#endif
}
+template <>
+__device__ void gcnasm_accvgpr_read<32>(float* arch_reg)
+{
+#if CK_USE_INLINE_ASM_XDLOPS
+    NOP(16)
+    REPEATx16(ACCVGPR_READ, 0)
+    REPEATx16(ACCVGPR_READ, 16)
+#else
+    (void)arch_reg;
+#endif
+}
+
+template <>
+__device__ void gcnasm_accvgpr_read<64>(float* arch_reg)
+{
+#if CK_USE_INLINE_ASM_XDLOPS
+    NOP(16)
+    REPEATx64(ACCVGPR_READ, 0)
+#else
+    (void)arch_reg;
+#endif
+}
+
+template <index_t MPerWave>
+__device__ void gcnasm_accvgpr_zero()
+{
+}
+
+template <>
+__device__ void gcnasm_accvgpr_zero<32>()
+{
+#if CK_USE_INLINE_ASM_XDLOPS
+    REPEATx16(ACCVGPR_ZERO, 0)
+    REPEATx16(ACCVGPR_ZERO, 16)
+#endif
+}
+
+template <>
+__device__ void gcnasm_accvgpr_zero<64>()
+{
+#if CK_USE_INLINE_ASM_XDLOPS
+    REPEATx64(ACCVGPR_ZERO, 0)
+#endif
+}
+
+template <index_t MPerWave>
+__device__ void gcnasm_mfma_f32_32x32x1f32(float&, float&, float32_t*)
+{
+}
+
+template <>
+__device__ void gcnasm_mfma_f32_32x32x1f32<64>(float& reg_a, float& reg_b, float32_t* reg_c)
+{
+#if CK_USE_INLINE_ASM_XDLOPS
+    NOP(1)
+    (void)reg_c;
+    MFMA_F32_32x32x1F32(0, reg_a, reg_b, 1, 0, 0)
+    MFMA_F32_32x32x1F32(32, reg_a, reg_b, 1, 1, 0)
+#else
+    reg_c[0] = __llvm_amdgcn_mfma_f32_32x32x1f32(reg_a, reg_b, reg_c[0], 1, 0, 0);
+    reg_c[1] = __llvm_amdgcn_mfma_f32_32x32x1f32(reg_a, reg_b, reg_c[1], 1, 1, 0);
+#endif
+}
+
+template <>
+__device__ void gcnasm_mfma_f32_32x32x1f32<32>(float& reg_a, float& reg_b, float32_t* reg_c)
+{
+#if CK_USE_INLINE_ASM_XDLOPS
+    NOP(1)
+    (void)reg_c;
+    MFMA_F32_32x32x1F32(0, reg_a, reg_b, 1, 0, 0)
+#else
+    reg_c[0] = __llvm_amdgcn_mfma_f32_32x32x1f32(reg_a, reg_b, reg_c[0], 1, 0, 0);
+#endif
+}
+
+template <index_t MPerWave>
+__device__ void gcnasm_mfma_f32_32x32x4f16(typename vector_type<half, 4>::MemoryType&,
+                                           typename vector_type<half, 4>::MemoryType&,
+                                           float32_t*)
+{
+}
+
+template <>
+__device__ void gcnasm_mfma_f32_32x32x4f16<64>(typename vector_type<half, 4>::MemoryType& reg_a,
+                                               typename vector_type<half, 4>::MemoryType& reg_b,
+                                               float32_t* reg_c)
+{
+#if CK_USE_INLINE_ASM_XDLOPS
+    NOP(1)
+    (void)reg_c;
+    MFMA_F32_32x32x4F16(0, reg_a, reg_b, 1, 0, 0)
+    MFMA_F32_32x32x4F16(32, reg_a, reg_b, 1, 1, 0)
+#else
+    reg_c[0] = __llvm_amdgcn_mfma_f32_32x32x4f16(reg_a, reg_b, reg_c[0], 1, 0, 0);
+    reg_c[1] = __llvm_amdgcn_mfma_f32_32x32x4f16(reg_a, reg_b, reg_c[1], 1, 1, 0);
+#endif
+}
+
+template <>
+__device__ void gcnasm_mfma_f32_32x32x4f16<32>(typename vector_type<half, 4>::MemoryType& reg_a,
+                                               typename vector_type<half, 4>::MemoryType& reg_b,
+                                               float32_t* reg_c)
+{
+#if CK_USE_INLINE_ASM_XDLOPS
+    NOP(1)
+    (void)reg_c;
+    MFMA_F32_32x32x4F16(0, reg_a, reg_b, 1, 0, 0)
+#else
+    reg_c[0] = __llvm_amdgcn_mfma_f32_32x32x4f16(reg_a, reg_b, reg_c[0], 1, 0, 0);
+#endif
+}
+
+template <index_t MPerWave>
+__device__ void gcnasm_mfma_f32_32x32x2bf16(typename vector_type<ushort, 2>::MemoryType&,
+                                            typename vector_type<ushort, 2>::MemoryType&,
+                                            float32_t*)
+{
+}
+
+template <>
+__device__ void gcnasm_mfma_f32_32x32x2bf16<64>(typename vector_type<ushort, 2>::MemoryType& reg_a,
+                                                typename vector_type<ushort, 2>::MemoryType& reg_b,
+                                                float32_t* reg_c)
+{
+#if CK_USE_INLINE_ASM_XDLOPS
+    NOP(1)
+    (void)reg_c;
+    MFMA_F32_32x32x2BF16(0, reg_a, reg_b, 1, 0, 0)
+    MFMA_F32_32x32x2BF16(32, reg_a, reg_b, 1, 1, 0)
+#else
+    reg_c[0] = __llvm_amdgcn_mfma_f32_32x32x2bf16(reg_a, reg_b, reg_c[0], 1, 0, 0);
+    reg_c[1] = __llvm_amdgcn_mfma_f32_32x32x2bf16(reg_a, reg_b, reg_c[1], 1, 1, 0);
+#endif
+}
+
+template <>
+__device__ void gcnasm_mfma_f32_32x32x2bf16<32>(typename vector_type<ushort, 2>::MemoryType& reg_a,
+                                                typename vector_type<ushort, 2>::MemoryType& reg_b,
+                                                float32_t* reg_c)
+{
+#if CK_USE_INLINE_ASM_XDLOPS
+    NOP(1)
+    (void)reg_c;
+    MFMA_F32_32x32x2BF16(0, reg_a, reg_b, 1, 0, 0)
+#else
+    reg_c[0] = __llvm_amdgcn_mfma_f32_32x32x2bf16(reg_a, reg_b, reg_c[0], 1, 0, 0);
+#endif
+}
+// clang-format on

 } // namespace ck
 #endif
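Editor's note: for readers new to xdlops, v_mfma_f32_32x32x1f32 performs, per issue, a rank-1 (k = 1) update of a 32x32 accumulator tile; the wrappers above differ only in how many such issues they make per wave. A host-side reference of the per-step arithmetic, written as plain C++ so the numerics can be checked independently of the GPU (this models the math only, not the lane/register layout, and the function name is illustrative):

#include <cassert>
#include <vector>

// One k-step of a 32x32 outer-product accumulation: C[i][j] += a[i] * b[j].
// This is the arithmetic a single v_mfma_f32_32x32x1f32 issue applies to a
// 32x32 accumulator tile (register/lane mapping not modeled here).
void mfma_32x32x1_f32_ref(const float (&a)[32], const float (&b)[32], std::vector<float>& c)
{
    for(int i = 0; i < 32; ++i)
        for(int j = 0; j < 32; ++j)
            c[i * 32 + j] += a[i] * b[j];
}

int main()
{
    float a[32], b[32];
    for(int i = 0; i < 32; ++i)
    {
        a[i] = 1.0f;
        b[i] = 2.0f;
    }

    std::vector<float> c(32 * 32, 0.0f);
    mfma_32x32x1_f32_ref(a, b, c); // k = 0
    mfma_32x32x1_f32_ref(a, b, c); // k = 1

    assert(c[0] == 4.0f && c[32 * 32 - 1] == 4.0f); // every element is 2 * (1 * 2)
}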
composable_kernel/include/utility/bfloat16_dev.hpp
...
@@ -26,6 +26,8 @@
 #ifndef BFLOAT16_DEVICE_HPP
 #define BFLOAT16_DEVICE_HPP

+#define __HIP_PLATFORM_HCC__ 1
+
 #ifdef __cplusplus
 extern "C" {
 #endif
...