gaoqiong / composable_kernel / Commits / 545d9305
Commit 545d9305, authored Sep 24, 2019 by Chao Liu

    refactor

Parent: 37f4e2b6

Showing 20 changed files with 251 additions and 183 deletions (whitespace-only changes hidden).
+1   -1    composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp
+1   -1    composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer.hpp
+1   -1    composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_padded.hpp
+21  -19   composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_padded_lds_double_buffer.hpp
+1   -1    composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp
+1   -1    composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer.hpp
+1   -1    composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_padded.hpp
+47  -19   composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_padded_lds_double_buffer.hpp
+5   -6    composable_kernel/include/tensor_description/ConstantTensorDescriptor.hpp
+38  -32   composable_kernel/include/tensor_description/multi_index_transform.hpp
+16  -0    composable_kernel/include/tensor_description/tensor_coordinate_helper.hpp (new file)
+6   -13   composable_kernel/include/tensor_description/tensor_coordinate_v2.hpp
+22  -11   composable_kernel/include/tensor_description/tensor_descriptor.hpp
+1   -1    composable_kernel/include/tensor_operation/blockwise_3d_tensor_op.hpp
+1   -1    composable_kernel/include/tensor_operation/blockwise_4d_tensor_op.hpp
+22  -30   composable_kernel/include/tensor_operation/blockwise_generic_tensor_slice_copy.hpp
+27  -39   composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy.hpp
+1   -2    composable_kernel/include/utility/array.hpp
+6   -0    composable_kernel/include/utility/functional.hpp
+32  -4    composable_kernel/include/utility/sequence.hpp
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp
@@ -100,7 +100,7 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw
         constexpr index_t E = C * Y * X;

         // sanity-check for vectorized memory load
-        static_assert((Ho == 1 || ConvStrideW % InBlockCopySrcDataPerRead_B == 0) &&
+        static_assert((Wo == 1 || (ConvStrideW == 1 || InBlockCopySrcDataPerRead_B == 1)) &&
                           (X == 1 || ConvDilationW % InBlockCopySrcDataPerRead_B == 0),
                       "wrong! aligment requirement for vectorized global load of input tensor will "
                       "be violated");
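Note (not part of the commit): the updated assertion can be read on its own as a compile-time predicate over the convolution parameters. A minimal standalone restatement in host C++ follows; the parameter values are hypothetical, and in the kernel they are template arguments.

    #include <cstdio>

    // Hypothetical compile-time parameters (template arguments in the real kernel).
    constexpr int Wo                          = 14; // output width
    constexpr int X                           = 1;  // filter width (a 1x1 filter here)
    constexpr int ConvStrideW                 = 1;
    constexpr int ConvDilationW               = 1;
    constexpr int InBlockCopySrcDataPerRead_B = 4;  // vector width of the global load

    // Same condition as the updated static_assert above.
    constexpr bool vector_load_ok =
        (Wo == 1 || (ConvStrideW == 1 || InBlockCopySrcDataPerRead_B == 1)) &&
        (X == 1 || ConvDilationW % InBlockCopySrcDataPerRead_B == 0);

    static_assert(vector_load_ok,
                  "wrong! aligment requirement for vectorized global load of input tensor will "
                  "be violated");

    int main() { std::puts(vector_load_ok ? "ok" : "violated"); }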
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer.hpp
@@ -100,7 +100,7 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
         constexpr index_t E = C * Y * X;

         // sanity-check for vectorized memory load
-        static_assert((Ho == 1 || ConvStrideW % InBlockCopySrcDataPerRead_B == 0) &&
+        static_assert((Wo == 1 || (ConvStrideW == 1 || InBlockCopySrcDataPerRead_B == 1)) &&
                           (X == 1 || ConvDilationW % InBlockCopySrcDataPerRead_B == 0),
                       "wrong! aligment requirement for vectorized global load of input tensor will "
                       "be violated");
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_padded.hpp
@@ -107,7 +107,7 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_padded
         constexpr index_t E = C * Y * X;

         // sanity-check for vectorized memory load
-        static_assert((Ho == 1 || ConvStrideW % InBlockCopySrcDataPerRead_B == 0) &&
+        static_assert((Wo == 1 || (ConvStrideW == 1 || InBlockCopySrcDataPerRead_B == 1)) &&
                           (X == 1 || ConvDilationW % InBlockCopySrcDataPerRead_B == 0),
                       "wrong! aligment requirement for vectorized global load of input tensor will "
                       "be violated");
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_padded_lds_double_buffer.hpp
@@ -107,7 +107,7 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_padded_lds_double_buf
         constexpr index_t E = C * Y * X;

         // sanity-check for vectorized memory load
-        static_assert((Ho == 1 || ConvStrideW % InBlockCopySrcDataPerRead_B == 0) &&
+        static_assert((Wo == 1 || (ConvStrideW == 1 || InBlockCopySrcDataPerRead_B == 1)) &&
                           (X == 1 || ConvDilationW % InBlockCopySrcDataPerRead_B == 0),
                       "wrong! aligment requirement for vectorized global load of input tensor will "
                       "be violated");

@@ -174,9 +174,7 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_padded_lds_double_buf
             decltype(in_e_n1_b_n2_global_desc),
             decltype(in_e_n1_b_n2_block_desc),
             Sequence<0, 1, 0, 1>,
-            Sequence<1, 0, 1, 0>,
             Sequence<1, 1, 1, 1>,
-            Sequence<0, 0, 0, 0>,
             decltype(in_e_n1_b_n2_block_desc.GetLengths()),
             InBlockCopySubLengths_E_N1_B_N2,
             InBlockCopyClusterLengths_E_N1_B_N2,

@@ -219,9 +217,7 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_padded_lds_double_buf
             decltype(wei_e_k_global_desc),
             decltype(wei_e_k_block_desc),
             Sequence<1, 1>,
-            Sequence<0, 0>,
             Sequence<1, 1>,
-            Sequence<0, 0>,
             decltype(wei_e_k_block_desc.GetLengths()),
             WeiBlockCopySubLengths_E_K,
             WeiBlockCopyClusterLengths_E_K,

@@ -299,8 +295,10 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_padded_lds_double_buf
         // LDS double buffer: preload data into LDS
         {
-            blockwise_in_copy.Run(p_in_global, p_in_block_double);
-            blockwise_wei_copy.Run(p_wei_global, p_wei_block_double);
+            blockwise_in_copy.template Run<Float, address_space_t::global, address_space_t::lds>(
+                p_in_global, p_in_block_double);
+            blockwise_wei_copy.template Run<Float, address_space_t::global, address_space_t::lds>(
+                p_wei_global, p_wei_block_double);
         }

         // LDS double buffer: main body

@@ -331,15 +329,19 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_padded_lds_double_buf
                 __syncthreads();

                 // LDS doubel buffer: load next data from device mem
-                blockwise_in_copy.RunLoadRegisterBuffer(p_in_global, p_in_register_buffer);
-                blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_global, p_wei_register_buffer);
+                blockwise_in_copy.template RunLoadRegisterBuffer<Float, address_space_t::global>(
+                    p_in_global, p_in_register_buffer);
+                blockwise_wei_copy.template RunLoadRegisterBuffer<Float, address_space_t::global>(
+                    p_wei_global, p_wei_register_buffer);

                 // LDS double buffer: GEMM on current data
                 blockwise_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread);

                 // LDS double buffer: store next data to LDS
-                blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer, p_in_block_next);
-                blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer, p_wei_block_next);
+                blockwise_in_copy.template RunStoreRegisterBuffer<Float, address_space_t::lds>(
+                    p_in_register_buffer, p_in_block_next);
+                blockwise_wei_copy.template RunStoreRegisterBuffer<Float, address_space_t::lds>(
+                    p_wei_register_buffer, p_wei_block_next);
             }
         }

@@ -355,17 +357,19 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_padded_lds_double_buf
             __syncthreads();

             // LDS doubel buffer: load next data from device mem
-            blockwise_in_copy.RunLoadRegisterBuffer(p_in_global, p_in_register_buffer);
-            blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_global, p_wei_register_buffer);
+            blockwise_in_copy.template RunLoadRegisterBuffer<Float, address_space_t::global>(
+                p_in_global, p_in_register_buffer);
+            blockwise_wei_copy.template RunLoadRegisterBuffer<Float, address_space_t::global>(
+                p_wei_global, p_wei_register_buffer);

             // LDS double buffer: GEMM on current data
             blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);

             // LDS double buffer: store next data to LDS
-            blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer,
-                                                     p_in_block_double + in_block_space);
-            blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer,
-                                                      p_wei_block_double + wei_block_space);
+            blockwise_in_copy.template RunStoreRegisterBuffer<Float, address_space_t::lds>(
+                p_in_register_buffer, p_in_block_double + in_block_space);
+            blockwise_wei_copy.template RunStoreRegisterBuffer<Float, address_space_t::lds>(
+                p_wei_register_buffer, p_wei_block_double + wei_block_space);

             // odd iteration
             __syncthreads();

@@ -424,9 +428,7 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_padded_lds_double_buf
         ThreadwiseGenericTensorSliceCopy_v4r2<decltype(out_k0_k1_n1_b_n2_thread_desc),
                                               decltype(out_k0_k1_n1_b_n2_global_desc),
                                               Sequence<1, 1, 1, 1, 1>,
-                                              Sequence<0, 0, 0, 0, 0>,
                                               Sequence<1, 1, 1, 0, 1>,
-                                              Sequence<0, 0, 0, 1, 0>,
                                               decltype(out_k0_k1_n1_b_n2_thread_desc.GetLengths()),
                                               arithmetic_sequence_gen<0, 5, 1>::type,
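Note (illustration only, not the kernel code): the hunks above keep the same "LDS double buffer" schedule and only change how the copies are invoked. The sketch below shows the ping-pong structure of that schedule in plain host C++; load_tile and gemm_on_tile are invented stand-ins for the blockwise copy and blockwise GEMM, and the real kernel overlaps the next load with the current GEMM and separates the phases with __syncthreads(), which this serial sketch omits.

    #include <cstddef>
    #include <iostream>
    #include <vector>

    // Invented stand-ins for the blockwise copy (global memory -> LDS) and the blockwise GEMM.
    static void load_tile(std::vector<float>& buf, std::size_t tile)
    {
        for(auto& v : buf)
            v = static_cast<float>(tile); // pretend global load of tile data
    }

    static float gemm_on_tile(const std::vector<float>& buf)
    {
        float acc = 0.f;
        for(float v : buf)
            acc += v; // pretend GEMM on the buffered tile
        return acc;
    }

    int main()
    {
        const std::size_t num_tiles = 8, tile_elems = 16;
        std::vector<float> buf[2] = {std::vector<float>(tile_elems),
                                     std::vector<float>(tile_elems)};
        float result = 0.f;

        load_tile(buf[0], 0); // "LDS double buffer: preload data into LDS"

        for(std::size_t t = 0; t + 1 < num_tiles; ++t)
        {
            load_tile(buf[(t + 1) % 2], t + 1); // load next tile into the other buffer
            result += gemm_on_tile(buf[t % 2]); // GEMM on the current buffer
        }

        result += gemm_on_tile(buf[(num_tiles - 1) % 2]); // tail iteration

        std::cout << result << '\n';
        return 0;
    }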
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp
@@ -84,7 +84,7 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw
         constexpr index_t B = N * Ho * Wo;

         // sanity-check for vectorized memory load
-        static_assert((Ho == 1 || ConvStrideW % InBlockCopyDataPerAccess_B == 0) &&
+        static_assert((Wo == 1 || (ConvStrideW == 1 || InBlockCopyDataPerAccess_B == 1)) &&
                           (X == 1 || ConvDilationW % InBlockCopyDataPerAccess_B == 0),
                       "wrong! aligment requirement for vectorized global load of input tensor will "
                       "be violated");
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer.hpp
@@ -84,7 +84,7 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
         constexpr index_t B = N * Ho * Wo;

         // sanity-check for vectorized memory load
-        static_assert((Ho == 1 || ConvStrideW % InBlockCopyDataPerAccess_B == 0) &&
+        static_assert((Wo == 1 || (ConvStrideW == 1 || InBlockCopyDataPerAccess_B == 1)) &&
                           (X == 1 || ConvDilationW % InBlockCopyDataPerAccess_B == 0),
                       "wrong! aligment requirement for vectorized global load of input tensor will "
                       "be violated");
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_padded.hpp
@@ -91,7 +91,7 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_padded
         constexpr index_t B = N * Ho * Wo;

         // sanity-check for vectorized memory load
-        static_assert((Ho == 1 || ConvStrideW % InBlockCopyDataPerAccess_B == 0) &&
+        static_assert((Wo == 1 || (ConvStrideW == 1 || InBlockCopyDataPerAccess_B == 1)) &&
                           (X == 1 || ConvDilationW % InBlockCopyDataPerAccess_B == 0),
                       "wrong! aligment requirement for vectorized global load of input tensor will "
                       "be violated");
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_padded_lds_double_buffer.hpp
@@ -90,7 +90,7 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_padded_lds_double_buf
         constexpr index_t B = N * Ho * Wo;

         // sanity-check for vectorized memory load
-        static_assert((Ho == 1 || ConvStrideW % InBlockCopyDataPerAccess_B == 0) &&
+        static_assert((Wo == 1 || (ConvStrideW == 1 || InBlockCopyDataPerAccess_B == 1)) &&
                           (X == 1 || ConvDilationW % InBlockCopyDataPerAccess_B == 0),
                       "wrong! aligment requirement for vectorized global load of input tensor will "
                       "be violated");

@@ -145,6 +145,8 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_padded_lds_double_buf
             BlockwiseGenericTensorSliceCopy_v4<BlockSize,
                                                decltype(in_e_b_global_desc),
                                                decltype(in_e_b_block_desc),
+                                               Sequence<0, 0>,
+                                               Sequence<1, 1>,
                                                decltype(in_e_b_block_desc.GetLengths()),
                                                InBlockCopySubLengths_E_B,
                                                InBlockCopyClusterLengths_E_B,

@@ -157,13 +159,21 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_padded_lds_double_buf
                                                InBlockCopyDataPerAccess_B>(
                 {0, b_block_data_on_global}, {0, 0});

         // weight tensor
         //   global mem
+#if 0
         constexpr auto wei_e_k_global_desc =
             transform_tensor_descriptor(wei_k_c_y_x_global_desc,
                                         make_tuple(Merge<Sequence<C, Y, X>>{}, PassThrough<K>{}),
                                         make_tuple(Sequence<1, 2, 3>{}, Sequence<0>{}),
                                         make_tuple(Sequence<0>{}, Sequence<1>{}));
+#else
+        // hack
+        constexpr auto wei_e_k_global_desc_old =
+            WeiGlobalDesc::Unfold(I1, I3).ReorderGivenNew2Old(Sequence<1, 0>{});
+        constexpr auto wei_e_k_global_desc = make_native_tensor_descriptor(
+            wei_e_k_global_desc_old.GetLengths(), wei_e_k_global_desc_old.GetStrides());
+#endif

         // LDS
         //   be careful of LDS alignment

@@ -176,6 +186,8 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_padded_lds_double_buf
             BlockwiseGenericTensorSliceCopy_v4<BlockSize,
                                                decltype(wei_e_k_global_desc),
                                                decltype(wei_e_k_block_desc),
+                                               Sequence<1, 1>,
+                                               Sequence<1, 1>,
                                                decltype(wei_e_k_block_desc.GetLengths()),
                                                WeiBlockCopySubLengths_E_K,
                                                WeiBlockCopyClusterLengths_E_K,

@@ -253,8 +265,10 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_padded_lds_double_buf
         // LDS double buffer: preload data into LDS
         {
-            blockwise_in_copy.Run(p_in_global, p_in_block_double);
-            blockwise_wei_copy.Run(p_wei_global, p_wei_block_double);
+            blockwise_in_copy.template Run<Float, address_space_t::global, address_space_t::lds>(
+                p_in_global, p_in_block_double);
+            blockwise_wei_copy.template Run<Float, address_space_t::global, address_space_t::lds>(
+                p_wei_global, p_wei_block_double);
         }

         // LDS double buffer: main body

@@ -285,15 +299,19 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_padded_lds_double_buf
                 __syncthreads();

                 // LDS doubel buffer: load next data from device mem
-                blockwise_in_copy.RunLoadRegisterBuffer(p_in_global, p_in_register_buffer);
-                blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_global, p_wei_register_buffer);
+                blockwise_in_copy.template RunLoadRegisterBuffer<Float, address_space_t::global>(
+                    p_in_global, p_in_register_buffer);
+                blockwise_wei_copy.template RunLoadRegisterBuffer<Float, address_space_t::global>(
+                    p_wei_global, p_wei_register_buffer);

                 // LDS double buffer: GEMM on current data
                 blockwise_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread);

                 // LDS double buffer: store next data to LDS
-                blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer, p_in_block_next);
-                blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer, p_wei_block_next);
+                blockwise_in_copy.template RunStoreRegisterBuffer<Float, address_space_t::lds>(
+                    p_in_register_buffer, p_in_block_next);
+                blockwise_wei_copy.template RunStoreRegisterBuffer<Float, address_space_t::lds>(
+                    p_wei_register_buffer, p_wei_block_next);
             }
         }

@@ -309,17 +327,19 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_padded_lds_double_buf
             __syncthreads();

             // LDS doubel buffer: load next data from device mem
-            blockwise_in_copy.RunLoadRegisterBuffer(p_in_global, p_in_register_buffer);
-            blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_global, p_wei_register_buffer);
+            blockwise_in_copy.template RunLoadRegisterBuffer<Float, address_space_t::global>(
+                p_in_global, p_in_register_buffer);
+            blockwise_wei_copy.template RunLoadRegisterBuffer<Float, address_space_t::global>(
+                p_wei_global, p_wei_register_buffer);

             // LDS double buffer: GEMM on current data
             blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);

             // LDS double buffer: store next data to LDS
-            blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer,
-                                                     p_in_block_double + in_block_space);
-            blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer,
-                                                      p_wei_block_double + wei_block_space);
+            blockwise_in_copy.template RunStoreRegisterBuffer<Float, address_space_t::lds>(
+                p_in_register_buffer, p_in_block_double + in_block_space);
+            blockwise_wei_copy.template RunStoreRegisterBuffer<Float, address_space_t::lds>(
+                p_wei_register_buffer, p_wei_block_double + wei_block_space);

             // odd iteration
             __syncthreads();

@@ -367,9 +387,11 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_padded_lds_double_buf
                                    make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));

         // output threadwise copy
-        auto threadwise_out_copy =
             ThreadwiseGenericTensorSliceCopy_v4r2<
                 decltype(out_k0_k1_b0_b1_thread_desc),
                 decltype(out_k0_k1_b0_b1_global_desc),
+                Sequence<1, 1, 1, 1>,
+                Sequence<1, 1, 0, 0>,
                 decltype(out_k0_k1_b0_b1_thread_desc.GetLengths()),
                 arithmetic_sequence_gen<0, 4, 1>::type,
                 3,

@@ -378,9 +400,15 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_padded_lds_double_buf
                 {k_thread_data_on_global / K1,
                  k_thread_data_on_global % K1,
                  b_thread_data_on_global / B1,
-                 b_thread_data_on_global % B1});
-
-        threadwise_out_copy.Run(p_out_thread, p_out_global);
+                 b_thread_data_on_global % B1})
+#if 1
+            .template Run_generic<Float, address_space_t::generic, address_space_t::global>
+#elif 1
+            .template Run_optimized_dst_address_calculation<Float,
+                                                            address_space_t::vgpr,
+                                                            address_space_t::global>
+#endif
+            (p_out_thread, p_out_global);
         }
     }
 };
composable_kernel/include/tensor_description/ConstantTensorDescriptor.hpp
@@ -96,13 +96,12 @@ struct ConstantTensorDescriptor
     __host__ __device__ static constexpr auto GetElementSize()
     {
-        return Number<accumulate_on_sequence(
+        return Number<reduce_on_sequence(
             Lengths{}, math::multiplies<index_t>{}, Number<1>{})>{};
     }

     __host__ __device__ static constexpr auto GetElementSpace()
     {
-        constexpr index_t element_space_unaligned = accumulate_on_sequence(
+        constexpr index_t element_space_unaligned = reduce_on_sequence(
             (GetLengths() - Number<1>{}) * GetStrides(), math::plus<index_t>{}, Number<1>{});

         return Number<element_space_unaligned>{};

@@ -155,7 +154,7 @@ struct ConstantTensorDescriptor
         constexpr auto multi_id = Sequence<Is...>{};

-        return Number<accumulate_on_sequence(
+        return Number<reduce_on_sequence(
             multi_id * GetStrides(), math::plus<index_t>{}, Number<0>{})>{};
     }

@@ -389,7 +388,7 @@ struct ConstantTensorDescriptor
         constexpr auto fold_intervals = Sequence<FoldIntervals...>{};

         constexpr index_t fold_intervals_product =
-            accumulate_on_sequence(fold_intervals, math::multiplies<index_t>{}, Number<1>{});
+            reduce_on_sequence(fold_intervals, math::multiplies<index_t>{}, Number<1>{});

         constexpr auto unfold_length = GetLength(Number<IDim>{});
         constexpr auto unfold_stride = GetStride(Number<IDim>{});

@@ -447,7 +446,7 @@ struct ConstantTensorDescriptor
         static_assert(Type::Extract(middle).AreDimensionsContinuous(), "wrong! not unfoldable");

         // unfolded length, stride
-        constexpr index_t unfold_length = accumulate_on_sequence(
+        constexpr index_t unfold_length = reduce_on_sequence(
             GetLengths().Extract(middle), math::multiplies<index_t>{}, Number<1>{});

         constexpr index_t unfold_stride = GetStride(Number<LastUnfoldDim>{});
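Note (for reference, not from the commit): the two helpers touched above fold a Sequence with multiplies and plus respectively, i.e. element size is the product of the lengths and element space is 1 + sum((length - 1) * stride). The loop below spells that out for a concrete, hypothetical 4-D shape.

    #include <array>
    #include <cstddef>
    #include <iostream>

    int main()
    {
        // Hypothetical 4-D descriptor: lengths and strides.
        const std::array<std::size_t, 4> lengths = {2, 8, 34, 34};
        const std::array<std::size_t, 4> strides = {9248, 1156, 34, 1};

        std::size_t element_size  = 1; // product of lengths             (GetElementSize)
        std::size_t element_space = 1; // 1 + sum((length - 1) * stride) (GetElementSpace)

        for(std::size_t i = 0; i < lengths.size(); ++i)
        {
            element_size *= lengths[i];
            element_space += (lengths[i] - 1) * strides[i];
        }

        std::cout << "element size:  " << element_size << '\n'
                  << "element space: " << element_space << '\n';
        return 0;
    }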
composable_kernel/include/tensor_description/multi_index_transform.hpp
@@ -41,11 +41,10 @@ struct PassThrough
     __host__ __device__ static constexpr bool IsLinearTransform() { return true; }

-    // TODO: should this function be here? should it be specific for padding check?
     __host__ __device__ static constexpr bool
-    IsUpperIndexInPaddingArea(const UpperIndex& /* idx_up */)
+    IsUpperIndexMappedToValidLowerIndex(const UpperIndex& /* idx_up */)
     {
-        return false;
+        return true;
     }
 };

@@ -82,23 +81,38 @@ struct Pad
     __host__ __device__ static constexpr bool IsLinearTransform() { return true; }

-    // TODO: should this function be here? should it be specific for padding check?
     __host__ __device__ constexpr bool
-    IsUpperIndexInPaddingArea(const UpperIndex& idx_up) const
+    IsUpperIndexMappedToValidLowerIndex(const UpperIndex& idx_up) const
     {
-        bool flag = false;
+#if 0
+        struct lambda_no_pad
+        {
+            __host__ __device__ constexpr bool operator()(index_t x) const { return x == 0; }
+        };
+
+        if(sequence_all_of(LeftPads{}, lambda_no_pad{}) &&
+           sequence_all_of(RightPads{}, lambda_no_pad{}))
+        {
+            return true;
+        }
+        else
+#endif
+        {
+            bool flag = true;

             static_for<0, nDim, 1>{}([&](auto idim) {
                 // only check if there is left-padding
                 static_if<(LeftPads::At(idim) != 0)>{}(
-                    [&](auto) { flag = flag || idx_up[idim] < LeftPads::At(idim); });
+                    [&](auto) { flag = flag && idx_up[idim] >= LeftPads::At(idim); });

                 // only check if there is right-padding
                 static_if<(RightPads::At(idim) != 0)>{}([&](auto) {
-                    flag = flag || idx_up[idim] >= LeftPads::At(idim) + LowerLengths::At(idim);
+                    flag = flag && (idx_up[idim] < LeftPads::At(idim) + LowerLengths::At(idim));
                 });
             });

             return flag;
+        }
     }
 };

@@ -155,16 +169,10 @@ struct Merge
                                      LowerLengths::PopFront(), math::multiplies<index_t>{}, Number<1>{})
                                      .PushBack(Number<1>{});

-#if 1 // would these 2 versions be compiled to same ISA?
+        // calculate index in each of the dimensions in the order of their dimension
         static_for<0, nDimLow - 1, 1>{}(
             lambda_CalculateLowerIndex<decltype(pseudo_low_strides)>(itmp, idx_low));

         idx_low(nDimLow - 1) = itmp / pseudo_low_strides[nDimLow - 1];
-#else
-        static_for<0, nDimLow, 1>{}(
-            lambda_CalculateLowerIndex<decltype(pseudo_low_strides)>(itmp, idx_low));
-#endif

         return idx_low;
     }

@@ -244,6 +252,7 @@ struct Merge
         });

         // highest dimension, no out-of-bound check
+
         if(borrow)
         {
             --idx_low_new(0);

@@ -255,11 +264,10 @@ struct Merge
     __host__ __device__ static constexpr bool IsLinearTransform() { return false; }

-    // TODO: should this function be here? should it be specific for padding check?
     __host__ __device__ static constexpr bool
-    IsUpperIndexInPaddingArea(const UpperIndex& /* idx_up */)
+    IsUpperIndexMappedToValidLowerIndex(const UpperIndex& /* idx_up */)
     {
-        return false;
+        return true;
     }
 };

@@ -304,11 +312,10 @@ struct Unmerge
     __host__ __device__ static constexpr bool IsLinearTransform() { return true; }

-    // TODO: should this function be here? should it be specific for padding check?
     __host__ __device__ static constexpr bool
-    IsUpperIndexInPaddingArea(const UpperIndex& /* idx_up */)
+    IsUpperIndexMappedToValidLowerIndex(const UpperIndex& /* idx_up */)
     {
-        return false;
+        return true;
     }
 };

@@ -362,9 +369,9 @@ struct Embed
     __host__ __device__ static constexpr bool IsLinearTransform() { return true; }

     __host__ __device__ static constexpr bool
-    IsUpperIndexInPaddingArea(const UpperIndex& /* idx_up */)
+    IsUpperIndexMappedToValidLowerIndex(const UpperIndex& /* idx_up */)
     {
-        return false;
+        return true;
     }
 };

@@ -404,11 +411,10 @@ struct Vectorize
     __host__ __device__ static constexpr bool IsLinearTransform() { return true; }

-    // TODO: should this function be here? should it be specific for padding check?
     __host__ __device__ static constexpr bool
-    IsUpperIndexInPaddingArea(const UpperIndex& /* idx_up */)
+    IsUpperIndexMappedToValidLowerIndex(const UpperIndex& /* idx_up */)
    {
-        return false;
+        return true;
     }
 };
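Note (illustration only): the Pad transform's predicate is both renamed and inverted above -- it now answers "is this upper index mapped to a valid (non-padding) lower index?" instead of "is it in the padding area?". A standalone restatement of the new check, with made-up pad values and without the library's Sequence/static_for machinery, looks like this:

    #include <array>
    #include <cstddef>

    constexpr std::size_t nDim = 2;

    // Hypothetical pad sizes and unpadded lengths (compile-time Sequences in the library).
    constexpr std::array<int, nDim> LeftPads     = {1, 1};
    constexpr std::array<int, nDim> LowerLengths = {32, 32};

    // Same logic as Pad::IsUpperIndexMappedToValidLowerIndex, written as a plain loop.
    constexpr bool is_valid(const std::array<int, nDim>& idx_up)
    {
        bool flag = true;
        for(std::size_t d = 0; d < nDim; ++d)
        {
            flag = flag && idx_up[d] >= LeftPads[d];                  // not in left padding
            flag = flag && idx_up[d] < LeftPads[d] + LowerLengths[d]; // not in right padding
        }
        return flag;
    }

    static_assert(is_valid({1, 5}), "maps into the unpadded tensor");
    static_assert(!is_valid({0, 5}), "row 0 is left padding");

    int main() { return 0; }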
composable_kernel/include/tensor_description/tensor_coordinate_helper.hpp
new file mode 100644

+#ifndef CK_TENSOR_COORDINATE_HELPER_HPP
+#define CK_TENSOR_COORDINATE_HELPER_HPP
+
+#include "tensor_coordiante_v2.hpp"
+
+namespace ck {
+
+template <typename TensorDesc>
+__host__ __device__ constexpr auto
+make_tensor_coordinate_v2(TensorDesc, MultiIndex<TensorDesc::GetNumOfDimension()> idx)
+{
+    return typename TensorCoordinate_v2<TensorDesc>::type(idx);
+}
+
+} // namespace ck
+#endif
composable_kernel/include/tensor_description/tensor_coordinate_v2.hpp
@@ -76,8 +76,7 @@ struct NativeTensorCoordinate
         return coord;
     }

-    // TODO: should this function be here? should it be specific for padding check?
-    __host__ __device__ static constexpr bool IsAnyLevelIndexInPaddingArea() { return false; }
+    __host__ __device__ static constexpr bool IsUpperIndexMappedToValidOffset() { return true; }

     private:
     // mIndex may be saved and update, however, the value of some (or all) of its entries may

@@ -166,11 +165,11 @@ struct TransformedTensorCoordinate
         return coord_up;
     }

-    // TODO: should this function be here? should it be specific for padding check?
-    __host__ __device__ constexpr bool IsAnyLevelIndexInPaddingArea() const
+    // this function should be inexpensive, because there is no upper-to-lower index transformation
+    __host__ __device__ constexpr bool IsUpperIndexMappedToValidOffset() const
     {
-        return tensor_desc_type::IsUpperIndexInPaddingArea(GetIndex()) ||
-               mCoordLow.IsAnyLevelIndexInPaddingArea();
+        return tensor_desc_type::IsUpperIndexMappedToValidLowerIndex(GetIndex()) &&
+               mCoordLow.IsUpperIndexMappedToValidOffset();
     }

     private:

@@ -206,11 +205,5 @@ struct TensorCoordinate_v2
     using type = decltype(MakeDummyTensorCoordinate(TensorDesc{}));
 };

-template <typename TensorDesc>
-__host__ __device__ constexpr auto
-make_tensor_coordinate_v2(TensorDesc, MultiIndex<TensorDesc::GetNumOfDimension()> idx)
-{
-    return typename TensorCoordinate_v2<TensorDesc>::type(idx);
-}
-
-}
+} // namespace ck

 #endif
composable_kernel/include/tensor_description/tensor_descriptor.hpp
@@ -66,12 +66,12 @@ struct NativeTensorDescriptor
     __host__ __device__ static constexpr index_t GetElementSize()
     {
-        return accumulate_on_sequence(GetLengths(), math::multiplies<index_t>{}, Number<1>{});
+        return reduce_on_sequence(GetLengths(), math::multiplies<index_t>{}, Number<1>{});
     }

     __host__ __device__ static constexpr index_t GetElementSpace()
     {
-        return accumulate_on_sequence(
+        return reduce_on_sequence(
             (GetLengths() - Number<1>{}) * GetStrides(), math::plus<index_t>{}, Number<1>{});
     }

@@ -120,10 +120,10 @@ struct NativeTensorDescriptor
     }
 #endif

-    // TODO: should this function be here? should it be specific for padding check?
     __host__ __device__ static constexpr bool
-    IsUpperIndexInPaddingArea(const Index& /* idx */)
+    IsUpperIndexMappedToValidOffset(const Index& /* idx */)
     {
-        return false;
+        return true;
     }
 };

@@ -290,7 +290,7 @@ struct TransformedTensorDescriptor
     __host__ __device__ static constexpr index_t GetElementSize()
     {
-        return accumulate_on_sequence(GetLengths(), math::multiplies<index_t>{}, Number<1>{});
+        return reduce_on_sequence(GetLengths(), math::multiplies<index_t>{}, Number<1>{});
     }

     __host__ __device__ static constexpr index_t GetElementSpace()

@@ -375,7 +375,7 @@ struct TransformedTensorDescriptor
             constexpr bool is_linear_transform = tran.IsLinearTransform();

             // judge if all lower dimension are linear
-            constexpr bool is_all_low_dim_linear = math::accumulate_on_sequence(
+            constexpr bool is_all_low_dim_linear = math::reduce_on_sequence(
                 pick_sequence_elements_by_mask(
                     GetLowerTensorDescriptor().GetMaskOfLinearDimensions(), LowDimensionId{}),
                 math::logic_and<bool>{},

@@ -441,21 +441,32 @@ struct TransformedTensorDescriptor
     }
 #endif

-    // TODO: should this function be here? should it be specific for padding check?
     __host__ __device__ static constexpr bool
-    IsUpperIndexInPaddingArea(const UpperIndex& idx_up)
+    IsUpperIndexMappedToValidLowerIndex(const UpperIndex& idx_up)
     {
-        bool flag = false;
+        bool flag = true;

         static_for<0, nTransform, 1>{}([&](auto itran) {
             constexpr auto tran = Transforms{}.At(itran);

             const auto idx_up_part = pick_array_element(idx_up, UpDimensionIds{}.At(itran));

-            flag = flag || tran.IsUpperIndexInPaddingArea(to_array(idx_up_part));
+            flag = flag && tran.IsUpperIndexMappedToValidLowerIndex(to_array(idx_up_part));
         });

         return flag;
     }
+
+    // Whenever this function is called, it will call CalculateLowerIndex() recursively
+    // If you have created a tensor coordinate already, instead of calling this function,
+    // you should call TransformedTensorCoordinate::IsUpperIndexMappedToValidOffset()
+    __host__ __device__ static constexpr bool
+    IsUpperIndexMappedToValidOffset(const UpperIndex& idx_up)
+    {
+        return IsUpperIndexMappedToValidLowerIndex(idx_up) &&
+               GetLowerTensorDescriptor().IsUpperIndexMappedToValidOffset(CalculateLowerIndex(idx_up));
+    }
 };

 } // namespace ck
composable_kernel/include/tensor_operation/blockwise_3d_tensor_op.hpp
@@ -162,7 +162,7 @@ struct Blockwise3dTensorCopy3
                       "wrrong! BlockSize is not big enough for ThreadPerDims!");

         constexpr index_t num_active_thread =
-            accumulate_on_sequence(ThreadPerDims{}, math::multiplies<index_t>{}, Number<1>{});
+            reduce_on_sequence(ThreadPerDims{}, math::multiplies<index_t>{}, Number<1>{});

         if(BlockSize > num_active_thread)
         {
composable_kernel/include/tensor_operation/blockwise_4d_tensor_op.hpp
@@ -505,7 +505,7 @@ struct Blockwise4dTensorCopy3
                       "wrrong! BlockSize is not big enough for ThreadPerDims!");

         constexpr index_t num_active_thread =
-            accumulate_on_sequence(ThreadPerDims{}, math::multiplies<index_t>{}, Number<1>{});
+            reduce_on_sequence(ThreadPerDims{}, math::multiplies<index_t>{}, Number<1>{});

         if(BlockSize > num_active_thread)
         {
composable_kernel/include/tensor_operation/blockwise_generic_tensor_slice_copy.hpp
@@ -681,9 +681,7 @@ template <index_t BlockSize,
           typename SrcDesc,
           typename DstDesc,
           typename SrcLinearDimensionMask,
-          typename SrcNonLinearDimensionMask,
           typename DstLinearDimensionMask,
-          typename DstNonLinearDimensionMask,
           typename SliceLengths,
           typename SubLengths,
           typename ThreadClusterLengths,

@@ -738,45 +736,43 @@ struct BlockwiseGenericTensorSliceCopy_v4
         return RegisterBufferDesc::GetElementSpace();
     }

-    template <typename TData>
+    template <typename TData, address_space_t SrcAddressSpace = address_space_t::generic>
     __device__ void RunLoadRegisterBuffer(const TData* p_src, TData* p_buffer) const
     {
-#if 0
-        mThreadwiseLoad.Run_generic(p_src, p_buffer);
-#elif 1
-        // hardcoded: src is global memory
-        mThreadwiseLoad.template Run_generic<TData, address_space_t::global>(p_src, p_buffer);
-#elif 1
-        // hardcoded: src is global memory
-        mThreadwiseLoad.template Run_optimized_src_address_calculation<TData, address_space_t::global>(
-            p_src, p_buffer);
+#if 1
+        mThreadwiseLoad.template Run_generic<TData, SrcAddressSpace, address_space_t::vgpr>(
+            p_src, p_buffer);
+#else
+        mThreadwiseLoad.template Run_optimized_src_address_calculation<TData,
+                                                                       SrcAddressSpace,
+                                                                       address_space_t::vgpr>(
+            p_src, p_buffer);
 #endif
     }

-    template <typename TData>
+    template <typename TData, address_space_t DstAddressSpace = address_space_t::generic>
     __device__ void RunStoreRegisterBuffer(const TData* p_buffer, TData* p_dst) const
     {
-#if 0
-        mThreadwiseStore.Run_generic(p_buffer, p_dst);
-#elif 1
-        // hardcoded: dst is lds
-        mThreadwiseStore.template Run_generic<TData, address_space_t::lds>(p_buffer, p_dst);
-#elif 1
-        // hardcoded: dst is lds
-        mThreadwiseStore.template Run_optimized_dst_address_calculation<TData, address_space_t::lds>(
-            p_buffer, p_dst);
+#if 1
+        mThreadwiseStore.template Run_generic<TData, address_space_t::vgpr, DstAddressSpace>(
+            p_buffer, p_dst);
+#else
+        mThreadwiseStore.template Run_optimized_dst_address_calculation<TData,
+                                                                        address_space_t::vgpr,
+                                                                        DstAddressSpace>(
+            p_buffer, p_dst);
 #endif
     }

-    template <typename TData>
+    template <typename TData,
+              address_space_t SrcAddressSpace = address_space_t::generic,
+              address_space_t DstAddressSpace = address_space_t::generic>
     __device__ void Run(const TData* p_src, TData* p_dst) const
     {
         TData p_buffer[GetRegisterBufferSize()];

-        RunLoadRegisterBuffer(p_src, p_buffer);
-        RunStoreRegisterBuffer(p_buffer, p_dst);
+        RunLoadRegisterBuffer<TData, SrcAddressSpace>(p_src, p_buffer);
+        RunStoreRegisterBuffer<TData, DstAddressSpace>(p_buffer, p_dst);
     }

     template <typename T, bool PositiveDirection>

@@ -802,9 +798,7 @@ struct BlockwiseGenericTensorSliceCopy_v4
         ThreadwiseGenericTensorSliceCopy_v4r2<SrcDesc,
                                               RegisterBufferDesc,
                                               SrcLinearDimensionMask,
-                                              SrcNonLinearDimensionMask,
                                               typename uniform_sequence_gen<nDim, 1>::type,
-                                              typename uniform_sequence_gen<nDim, 0>::type,
                                               SubLengths,
                                               SrcDimAccessOrder,
                                               SrcVectorAccessDim,

@@ -815,9 +809,7 @@ struct BlockwiseGenericTensorSliceCopy_v4
         ThreadwiseGenericTensorSliceCopy_v4r2<RegisterBufferDesc,
                                               DstDesc,
                                               typename uniform_sequence_gen<nDim, 1>::type,
-                                              typename uniform_sequence_gen<nDim, 0>::type,
                                               DstLinearDimensionMask,
-                                              DstNonLinearDimensionMask,
                                               SubLengths,
                                               DstDimAccessOrder,
                                               DstVectorAccessDim,
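Note (sketch, not the library code): the signature changes above follow one pattern -- the source/destination address spaces become template parameters that default to address_space_t::generic, so existing call sites still compile while the gridwise kernels can name the global/LDS paths explicitly. A self-contained miniature of that pattern follows; the copier type, the n parameter and the enum definition here are invented for illustration.

    #include <cstddef>

    // Stand-in for the library's address_space_t enumeration.
    enum class address_space_t { generic, global, lds, vgpr };

    struct copier
    {
        // Defaulted address-space template parameters, mirroring the new
        // Run / RunLoadRegisterBuffer / RunStoreRegisterBuffer signatures.
        template <typename TData,
                  address_space_t SrcAddressSpace = address_space_t::generic,
                  address_space_t DstAddressSpace = address_space_t::generic>
        void Run(const TData* p_src, TData* p_dst, std::size_t n) const
        {
            // A real implementation would select a specialised load/store path based on
            // SrcAddressSpace / DstAddressSpace; this sketch just copies element-wise.
            for(std::size_t i = 0; i < n; ++i)
                p_dst[i] = p_src[i];
        }
    };

    int main()
    {
        float src[4] = {0, 1, 2, 3};
        float dst[4] = {};
        copier c;

        c.Run(src, dst, 4);                                                       // defaults: generic -> generic
        c.Run<float, address_space_t::global, address_space_t::lds>(src, dst, 4); // explicit, as in the kernels
        return 0;
    }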
composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy.hpp
@@ -1131,9 +1131,7 @@ struct ThreadwiseGenericTensorSliceCopy_v3r1
 template <typename SrcDesc,
           typename DstDesc,
           typename SrcLinearDimensionMask,
-          typename SrcNonLinearDimensionMask,
           typename DstLinearDimensionMask,
-          typename DstNonLinearDimensionMask,
           typename SliceLengths,
           typename DimAccessOrder,
           index_t VectorAccessDim,

@@ -1231,8 +1229,7 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
                 // Check src vector's padding situation, only check the first data in this src
                 // vector. It's user's responsiblity to make sure all data in the src vector has
                 // the same padding situation
-                // TODO: not sure a dedicated IsAnyLevelIndexInPaddingArea() function is neccessary
-                if(!src_coord.IsAnyLevelIndexInPaddingArea())
+                if(src_coord.IsUpperIndexMappedToValidOffset())
                 {
                     static_if<SrcAddressSpace == address_space_t::global>{}([&](auto) {
 #if CK_USE_AMD_INTRINSIC && CK_USE_AMD_INTRINSIC_BUFFER_LOAD_STORE

@@ -1260,13 +1257,10 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
                 const auto dst_coord = mDstSliceOrigin + (long_vector_data_begin_id + scalar_id);

                 // Check dst vector's padding situation, only check the first data in this dst
                 // vector. It's user's responsiblity to make sure all data in the dst vector has
                 // the same padding situation
-                // TODO: not sure a dedicated IsAnyLevelIndexInPaddingArea() function is neccessary
-#if 0 // tuning
-                if(!dst_coord.IsAnyLevelIndexInPaddingArea())
-#endif
+                if(dst_coord.IsUpperIndexMappedToValidOffset())
                 {
                     static_if<DstAddressSpace == address_space_t::global>{}([&](auto) {
 #if CK_USE_AMD_INTRINSIC && CK_USE_AMD_INTRINSIC_BUFFER_LOAD_STORE

@@ -1303,7 +1297,9 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
     // Will do padding check on src data: Read 0 if src data is in padding area.
     // Will do padding check on dst data: No write if dst data is in paddin area.
     // This version is optimized for address calculation of src tensor
-    template <typename TData, address_space_t SrcAddressSpace = address_space_t::generic>
+    template <typename TData,
+              address_space_t SrcAddressSpace = address_space_t::generic,
+              address_space_t DstAddressSpace = address_space_t::generic>
     __device__ void Run_optimized_src_address_calculation(const TData* p_src, TData* p_dst) const
     {
         using src_vector_t = typename vector_type<TData, SrcDataPerAccess>::MemoryType;

@@ -1321,8 +1317,9 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
         // TODO:: stop using this hack, once TransformedTensorDescriptor::GetLinearDimensionMask()
         // is implemented
         constexpr auto src_linear_dim_mask    = SrcLinearDimensionMask{};
-        constexpr auto src_nonlinear_dim_mask = SrcNonLinearDimensionMask{};
+        constexpr auto src_nonlinear_dim_mask =
+            SrcLinearDimensionMask::Transform(logical_not<index_t>{});

         static_assert(
             src_linear_dim_mask.At(VectorAccessDim) || long_vector_size == SrcDataPerAccess,

@@ -1392,9 +1389,7 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
                     // Check src vector's padding situation, only check the first data in
                     // this src vector. It's user's responsiblity to make sure all data in
                     // the src vector has the same padding situation
-                    // TODO: not sure a dedicated IsAnyLevelIndexInPaddingArea() function is
-                    // neccessary
-                    if(!src_coord.IsAnyLevelIndexInPaddingArea())
+                    if(src_coord.IsUpperIndexMappedToValidOffset())
                     {
                         static_if<SrcAddressSpace == address_space_t::global>{}([&](auto) {
 #if CK_USE_AMD_INTRINSIC && CK_USE_AMD_INTRINSIC_BUFFER_LOAD_STORE

@@ -1427,14 +1422,10 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
                     const auto dst_coord = mDstSliceOrigin + (nonlinear_dim_data_steps +
                                                               linear_dim_data_steps + scalar_id);

                     // Check dst vector's padding situation, only check the first data in
                     // this dst vector. It's user's responsiblity to make sure all data in
                     // the dst vector has the same padding situation
-                    // TODO: not sure a dedicated IsAnyLevelIndexInPaddingArea() function is
-                    // neccessary
-#if 0 // tuning
-                    if(!dst_coord.IsAnyLevelIndexInPaddingArea())
-#endif
+                    if(dst_coord.IsUpperIndexMappedToValidOffset())
                     {
                         *reinterpret_cast<dst_vector_t*>(&p_dst[dst_coord.GetOffset()]) =
                             *reinterpret_cast<dst_vector_t*>(&p_long_vector[buffer_offset]);

@@ -1450,7 +1441,9 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
     // Will do padding check on src data: Read 0 if src data is in padding area.
     // Will do padding check on dst data: No write if dst data is in paddin area.
     // This version is optimized for address calculation of dst tensor
-    template <typename TData, address_space_t DstAddressSpace = address_space_t::generic>
+    template <typename TData,
+              address_space_t SrcAddressSpace = address_space_t::generic,
+              address_space_t DstAddressSpace = address_space_t::generic>
     __device__ void Run_optimized_dst_address_calculation(const TData* p_src, TData* p_dst) const
     {
         using src_vector_t = typename vector_type<TData, SrcDataPerAccess>::MemoryType;

@@ -1468,8 +1461,9 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
         // TODO:: stop using this hack, once TransformedTensorDescriptor::GetLinearDimensionMask()
         // is implemented
         constexpr auto dst_linear_dim_mask    = DstLinearDimensionMask{};
-        constexpr auto dst_nonlinear_dim_mask = DstNonLinearDimensionMask{};
+        constexpr auto dst_nonlinear_dim_mask =
+            DstLinearDimensionMask::Transform(logical_not<index_t>{});

         static_assert(
             dst_linear_dim_mask.At(VectorAccessDim) || long_vector_size == DstDataPerAccess,

@@ -1535,9 +1529,7 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
                     // Check src vector's padding situation, only check the first data in
                     // this src vector. It's user's responsiblity to make sure all data in
                     // the src vector has the same padding situation
-                    // TODO: not sure a dedicated IsAnyLevelIndexInPaddingArea() function is
-                    // neccessary
-                    if(!src_coord.IsAnyLevelIndexInPaddingArea())
+                    if(src_coord.IsUpperIndexMappedToValidOffset())
                     {
                         *reinterpret_cast<src_vector_t*>(&p_long_vector[buffer_offset]) =
                             *reinterpret_cast<const src_vector_t*>(&p_src[src_coord.GetOffset()]);

@@ -1561,14 +1553,10 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
                     const index_t dst_linear_offset =
                         dst_coord.GetOffset() - dst_nonlinear_coord.GetOffset();

                     // Check dst vector's padding situation, only check the first data in
                     // this dst vector. It's user's responsiblity to make sure all data in
                     // the dst vector has the same padding situation
-                    // TODO: not sure a dedicated IsAnyLevelIndexInPaddingArea() function is
-                    // neccessary
-#if 0 // tuning
-                    if(!dst_coord.IsAnyLevelIndexInPaddingArea())
-#endif
+                    if(dst_coord.IsUpperIndexMappedToValidOffset())
                     {
                         static_if<DstAddressSpace == address_space_t::global>{}([&](auto) {
 #if CK_USE_AMD_INTRINSIC && CK_USE_AMD_INTRINSIC_BUFFER_LOAD_STORE
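Note (simplified, host-side illustration): after this change every vectorized access in the threadwise copy is guarded by the coordinate's IsUpperIndexMappedToValidOffset() -- per the comments above, a source vector that falls into padding reads as zero and a destination vector that falls into padding is not written. The scalar sketch below shows that behaviour; the validity callbacks and the element-wise loop are invented (the real code checks once per vector and uses vector loads/stores).

    #include <array>
    #include <cstddef>
    #include <iostream>

    template <std::size_t N>
    void guarded_copy(const std::array<float, N>& src,
                      std::array<float, N>& dst,
                      bool (*src_is_valid)(std::size_t),
                      bool (*dst_is_valid)(std::size_t))
    {
        for(std::size_t i = 0; i < N; ++i)
        {
            // Read 0 if the source index falls into the padding area.
            const float v = src_is_valid(i) ? src[i] : 0.0f;

            // No write if the destination index falls into the padding area.
            if(dst_is_valid(i))
                dst[i] = v;
        }
    }

    int main()
    {
        std::array<float, 4> src = {1, 2, 3, 4};
        std::array<float, 4> dst = {-1, -1, -1, -1};

        guarded_copy(src, dst,
                     [](std::size_t i) { return i != 0; },  // pretend element 0 is src padding
                     [](std::size_t i) { return i != 3; }); // pretend element 3 is dst padding

        for(float v : dst)
            std::cout << v << ' ';
        std::cout << '\n'; // prints: 0 2 3 -1
        return 0;
    }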
composable_kernel/include/utility/array.hpp
@@ -110,8 +110,7 @@ struct ArrayElementPicker
     __host__ __device__ explicit constexpr ArrayElementPicker(Arr& array) : mArray{array}
     {
-        constexpr index_t imax =
-            accumulate_on_sequence(Picks{}, math::maxer<index_t>{}, Number<0>{});
+        constexpr index_t imax = reduce_on_sequence(Picks{}, math::maxer<index_t>{}, Number<0>{});

         static_assert(imax < Arr::Size(), "wrong! exceeding # array element");
     }
composable_kernel/include/utility/functional.hpp
@@ -25,6 +25,12 @@ struct swallow
     }
 };

+template <typename T>
+struct logical_not
+{
+    constexpr bool operator()(const T& x) const { return !x; }
+};
+
 // Emulate if constexpr
 template <bool>
 struct static_if;
composable_kernel/include/utility/sequence.hpp
@@ -764,12 +764,12 @@ __host__ __device__ constexpr auto pick_sequence_elements_by_mask(Seq, Mask)
 #endif

 template <typename Seq, typename Reduce>
-struct lambda_accumulate_on_sequence
+struct lambda_reduce_on_sequence
 {
     const Reduce& f;
     index_t& result;

-    __host__ __device__ constexpr lambda_accumulate_on_sequence(const Reduce& f_, index_t& result_)
+    __host__ __device__ constexpr lambda_reduce_on_sequence(const Reduce& f_, index_t& result_)
         : f(f_), result(result_)
     {
     }

@@ -783,14 +783,42 @@ struct lambda_accumulate_on_sequence
 template <typename Seq, typename Reduce, index_t Init>
 __host__ __device__ constexpr index_t
-accumulate_on_sequence(Seq, Reduce f, Number<Init> /*initial_value*/)
+reduce_on_sequence(Seq, Reduce f, Number<Init> /*initial_value*/)
 {
     index_t result = Init;

-    static_for<0, Seq::mSize, 1>{}(lambda_accumulate_on_sequence<Seq, Reduce>(f, result));
+    static_for<0, Seq::Size(), 1>{}(lambda_reduce_on_sequence<Seq, Reduce>(f, result));

     return result;
 }

+// TODO: a generic any_of for any container
+template <typename Seq, typename F>
+__host__ __device__ constexpr bool sequence_any_of(Seq, F f /*initial_value*/)
+{
+    bool flag = false;
+
+    for(index_t i = 0; i < Seq::Size(); ++i)
+    {
+        flag = flag || f(Seq::At(i));
+    }
+
+    return flag;
+}
+
+// TODO: a generic all_of for any container
+template <typename Seq, typename F>
+__host__ __device__ constexpr bool sequence_all_of(Seq, F f /*initial_value*/)
+{
+    bool flag = true;
+
+    for(index_t i = 0; i < Seq::Size(); ++i)
+    {
+        flag = flag && f(Seq::At(i));
+    }
+
+    return flag;
+}
+
 } // namespace ck

 #endif
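Note (standard-C++ illustration, not the library code): reduce_on_sequence is a left fold over a compile-time integer sequence, and sequence_all_of / sequence_any_of are the matching all/any reductions. The snippet below reimplements the same three operations over std::integer_sequence under that assumption, so the names mirror the functions above but the bodies are independent.

    #include <utility>

    // reduce_on_sequence(Seq, f, init): fold f over the pack, starting from init.
    template <typename Reduce, int Init, int... Is>
    constexpr int reduce_on_sequence(std::integer_sequence<int, Is...>,
                                     Reduce f,
                                     std::integral_constant<int, Init>)
    {
        int result = Init;
        ((result = f(result, Is)), ...); // fold expression instead of the library's static_for
        return result;
    }

    // sequence_any_of(Seq, pred) / sequence_all_of(Seq, pred)
    template <typename F, int... Is>
    constexpr bool sequence_any_of(std::integer_sequence<int, Is...>, F f)
    {
        return (f(Is) || ...);
    }

    template <typename F, int... Is>
    constexpr bool sequence_all_of(std::integer_sequence<int, Is...>, F f)
    {
        return (f(Is) && ...);
    }

    constexpr auto lengths = std::integer_sequence<int, 2, 8, 34, 34>{};

    static_assert(reduce_on_sequence(lengths,
                                     [](int a, int b) { return a * b; },
                                     std::integral_constant<int, 1>{}) == 2 * 8 * 34 * 34);
    static_assert(sequence_all_of(lengths, [](int x) { return x > 0; }));
    static_assert(!sequence_any_of(lengths, [](int x) { return x < 0; }));

    int main() { return 0; }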