Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
5e88414a
Commit
5e88414a
authored
Oct 02, 2020
by
Chao Liu
Browse files
use statically index array for all existing kernels
parent
a578ff93
Changes
17
Show whitespace changes
Inline
Side-by-side
Showing
17 changed files
with
122 additions
and
98 deletions
+122
-98
composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw_lds_double_buffer.hpp
...a_implicit_gemm_v1r2_nchw_kcyx_nkhw_lds_double_buffer.hpp
+11
-11
composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v5r1_nhwc_kyxc_nhwk.hpp
...ution_backward_data_implicit_gemm_v5r1_nhwc_kyxc_nhwk.hpp
+1
-1
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer.hpp
...n_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer.hpp
+1
-1
composable_kernel/include/tensor_description/dynamic_tensor_descriptor.hpp
.../include/tensor_description/dynamic_tensor_descriptor.hpp
+9
-7
composable_kernel/include/tensor_description/dynamic_tensor_descriptor_v2.hpp
...clude/tensor_description/dynamic_tensor_descriptor_v2.hpp
+20
-20
composable_kernel/include/tensor_description/tensor_descriptor.hpp
...e_kernel/include/tensor_description/tensor_descriptor.hpp
+9
-7
composable_kernel/include/tensor_operation/gridwise_gemm.hpp
composable_kernel/include/tensor_operation/gridwise_gemm.hpp
+28
-26
composable_kernel/include/utility/common_header.hpp
composable_kernel/include/utility/common_header.hpp
+2
-2
composable_kernel/include/utility/container_element_picker.hpp
...sable_kernel/include/utility/container_element_picker.hpp
+12
-12
composable_kernel/include/utility/container_helper.hpp
composable_kernel/include/utility/container_helper.hpp
+15
-4
composable_kernel/include/utility/functional3.hpp
composable_kernel/include/utility/functional3.hpp
+2
-2
composable_kernel/include/utility/print.hpp
composable_kernel/include/utility/print.hpp
+1
-1
composable_kernel/include/utility/tuple.hpp
composable_kernel/include/utility/tuple.hpp
+6
-0
composable_kernel/include/utility/tuple_helper.hpp
composable_kernel/include/utility/tuple_helper.hpp
+2
-1
driver/include/device_convolution_backward_data_implicit_gemm_v5r1_nhwc_kyxc_nhwk.hpp
...ution_backward_data_implicit_gemm_v5r1_nhwc_kyxc_nhwk.hpp
+1
-1
driver/src/conv_bwd_data_driver.cpp
driver/src/conv_bwd_data_driver.cpp
+1
-1
driver/src/conv_driver.cpp
driver/src/conv_driver.cpp
+1
-1
No files found.
composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw_lds_double_buffer.hpp
View file @
5e88414a
...
@@ -107,8 +107,8 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r2_nchw_kcyx_nkhw_lds_doubl
...
@@ -107,8 +107,8 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r2_nchw_kcyx_nkhw_lds_doubl
const
auto
block_work_id
=
block_work_desc
.
CalculateClusterIndex
(
get_block_1d_id
());
const
auto
block_work_id
=
block_work_desc
.
CalculateClusterIndex
(
get_block_1d_id
());
const
index_t
e_block_data_on_global
=
block_work_id
[
0
]
*
EPerBlock
;
const
index_t
e_block_data_on_global
=
block_work_id
[
Number
<
0
>
{}
]
*
EPerBlock
;
const
index_t
b_block_data_on_global
=
block_work_id
[
1
]
*
BPerBlock
;
const
index_t
b_block_data_on_global
=
block_work_id
[
Number
<
1
>
{}
]
*
BPerBlock
;
// output tensor
// output tensor
// global tensor in global memory, src of blockwise copy
// global tensor in global memory, src of blockwise copy
...
@@ -151,7 +151,7 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r2_nchw_kcyx_nkhw_lds_doubl
...
@@ -151,7 +151,7 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r2_nchw_kcyx_nkhw_lds_doubl
AddressSpace
::
Vgpr
,
AddressSpace
::
Vgpr
,
AddressSpace
::
Lds
,
AddressSpace
::
Lds
,
InMemoryDataOperation
::
Set
>
(
InMemoryDataOperation
::
Set
>
(
{
0
,
b_block_data_on_global
,
0
}
,
{
0
,
0
,
0
}
);
make_multi_index
(
0
,
b_block_data_on_global
,
0
)
,
make_multi_index
(
0
,
0
,
0
)
);
// weight tensor
// weight tensor
// global tensor in global memory, src of blockwise copy
// global tensor in global memory, src of blockwise copy
...
@@ -191,7 +191,7 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r2_nchw_kcyx_nkhw_lds_doubl
...
@@ -191,7 +191,7 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r2_nchw_kcyx_nkhw_lds_doubl
AddressSpace
::
Vgpr
,
AddressSpace
::
Vgpr
,
AddressSpace
::
Lds
,
AddressSpace
::
Lds
,
InMemoryDataOperation
::
Set
>
(
InMemoryDataOperation
::
Set
>
(
{
0
,
e_block_data_on_global
,
0
}
,
{
0
,
0
,
0
}
);
make_multi_index
(
0
,
e_block_data_on_global
,
0
)
,
make_multi_index
(
0
,
0
,
0
)
);
// GEMM definition
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// c_mtx += transpose(a_mtx) * b_mtx
...
@@ -434,13 +434,13 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r2_nchw_kcyx_nkhw_lds_doubl
...
@@ -434,13 +434,13 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r2_nchw_kcyx_nkhw_lds_doubl
InThreadCopyDstDataPerWrite_B
,
InThreadCopyDstDataPerWrite_B
,
AddressSpace
::
Vgpr
,
AddressSpace
::
Vgpr
,
AddressSpace
::
Global
,
AddressSpace
::
Global
,
in_memory_op
>
(
{
0
,
0
,
0
,
0
,
0
,
0
}
,
in_memory_op
>
(
make_multi_index
(
0
,
0
,
0
,
0
,
0
,
0
)
,
{
e_thread_data_on_global
/
E1
,
make_multi_index
(
e_thread_data_on_global
/
E1
,
e_thread_data_on_global
%
E1
,
e_thread_data_on_global
%
E1
,
0
,
0
,
b_thread_data_on_global
/
B1
,
b_thread_data_on_global
/
B1
,
b_thread_data_on_global
%
B1
,
b_thread_data_on_global
%
B1
,
0
}
)
0
)
)
.
Run
(
p_in_thread
,
p_in_global
);
.
Run
(
p_in_thread
,
p_in_global
);
}
}
}
}
...
...
composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v5r1_nhwc_kyxc_nhwk.hpp
View file @
5e88414a
...
@@ -125,7 +125,7 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v5r1_nhwc_kyxc_nhwk
...
@@ -125,7 +125,7 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v5r1_nhwc_kyxc_nhwk
index_t
GemmK1
=
XDotSlice
;
index_t
GemmK1
=
XDotSlice
;
index_t
GemmK2
=
K
;
index_t
GemmK2
=
K
;
return
Array
<
index_t
,
5
>
{
GemmM
,
GemmN
,
GemmK0
,
GemmK1
,
GemmK2
}
;
return
make_multi_index
(
GemmM
,
GemmN
,
GemmK0
,
GemmK1
,
GemmK2
)
;
}
}
__host__
__device__
static
constexpr
auto
GetGemmSize
(
index_t
gemm_id
)
__host__
__device__
static
constexpr
auto
GetGemmSize
(
index_t
gemm_id
)
...
...
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer.hpp
View file @
5e88414a
...
@@ -226,7 +226,7 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
...
@@ -226,7 +226,7 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
AddressSpace
::
Vgpr
,
AddressSpace
::
Vgpr
,
AddressSpace
::
Lds
,
AddressSpace
::
Lds
,
InMemoryDataOperation
::
Set
>
(
InMemoryDataOperation
::
Set
>
(
{
0
,
k_block_data_on_global
}
,
{
0
,
0
}
);
make_multi_index
(
0
,
k_block_data_on_global
)
,
make_multi_index
(
0
,
0
)
);
// GEMM definition
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// c_mtx += transpose(a_mtx) * b_mtx
...
...
composable_kernel/include/tensor_description/dynamic_tensor_descriptor.hpp
View file @
5e88414a
...
@@ -242,8 +242,8 @@ struct DynamicTransformedTensorDescriptor
...
@@ -242,8 +242,8 @@ struct DynamicTransformedTensorDescriptor
static_for
<
0
,
NTransform
,
1
>
{}([
&
](
auto
itran
)
constexpr
{
static_for
<
0
,
NTransform
,
1
>
{}([
&
](
auto
itran
)
constexpr
{
const
auto
tran
=
transforms_
.
At
(
itran
);
const
auto
tran
=
transforms_
.
At
(
itran
);
const
auto
idx_up_part
=
pick_
array
_element
(
idx_up
,
UpDimensionIds
{}.
At
(
itran
));
const
auto
idx_up_part
=
pick_
container
_element
(
idx_up
,
UpDimensionIds
{}.
At
(
itran
));
auto
idx_low_part
=
pick_
array
_element
(
idx_low
,
LowDimensionIds
{}.
At
(
itran
));
auto
idx_low_part
=
pick_
container
_element
(
idx_low
,
LowDimensionIds
{}.
At
(
itran
));
tran
.
CalculateLowerIndex
(
idx_low_part
,
idx_up_part
);
tran
.
CalculateLowerIndex
(
idx_low_part
,
idx_up_part
);
});
});
...
@@ -259,14 +259,16 @@ struct DynamicTransformedTensorDescriptor
...
@@ -259,14 +259,16 @@ struct DynamicTransformedTensorDescriptor
const
auto
tran
=
transforms_
.
At
(
itran
);
const
auto
tran
=
transforms_
.
At
(
itran
);
const
auto
idx_up_diff_part
=
const
auto
idx_up_diff_part
=
pick_
array
_element
(
idx_up_diff
,
UpDimensionIds
{}.
At
(
itran
));
pick_
container
_element
(
idx_up_diff
,
UpDimensionIds
{}.
At
(
itran
));
const
auto
idx_up_old_part
=
pick_array_element
(
idx_up_old
,
UpDimensionIds
{}.
At
(
itran
));
const
auto
idx_up_old_part
=
pick_container_element
(
idx_up_old
,
UpDimensionIds
{}.
At
(
itran
));
const
auto
idx_low_old_part
=
const
auto
idx_low_old_part
=
pick_
array
_element
(
idx_low_old
,
LowDimensionIds
{}.
At
(
itran
));
pick_
container
_element
(
idx_low_old
,
LowDimensionIds
{}.
At
(
itran
));
auto
idx_low_diff_part
=
pick_array_element
(
idx_low_diff
,
LowDimensionIds
{}.
At
(
itran
));
auto
idx_low_diff_part
=
pick_container_element
(
idx_low_diff
,
LowDimensionIds
{}.
At
(
itran
));
tran
.
CalculateLowerIndexDiff
(
tran
.
CalculateLowerIndexDiff
(
idx_low_diff_part
,
idx_up_diff_part
,
idx_low_old_part
,
idx_up_old_part
);
idx_low_diff_part
,
idx_up_diff_part
,
idx_low_old_part
,
idx_up_old_part
);
...
@@ -325,7 +327,7 @@ struct DynamicTransformedTensorDescriptor
...
@@ -325,7 +327,7 @@ struct DynamicTransformedTensorDescriptor
if
constexpr
(
!
is_valid_up_always_mapped_to_valid_low
)
if
constexpr
(
!
is_valid_up_always_mapped_to_valid_low
)
{
{
const
auto
up_dims_part
=
UpDimensionIds
{}.
At
(
itran
);
const
auto
up_dims_part
=
UpDimensionIds
{}.
At
(
itran
);
const
auto
idx_up_part
=
pick_
array
_element
(
idx_up
,
up_dims_part
);
const
auto
idx_up_part
=
pick_
container
_element
(
idx_up
,
up_dims_part
);
flag
=
flag
&&
tran
.
IsValidUpperIndexMappedToValidLowerIndex
(
idx_up_part
);
flag
=
flag
&&
tran
.
IsValidUpperIndexMappedToValidLowerIndex
(
idx_up_part
);
}
}
...
...
composable_kernel/include/tensor_description/dynamic_tensor_descriptor_v2.hpp
View file @
5e88414a
...
@@ -140,7 +140,7 @@ struct DynamicTensorDescriptor_v2
...
@@ -140,7 +140,7 @@ struct DynamicTensorDescriptor_v2
MultiIndex
<
ndim_hidden
>
idx_hidden
;
MultiIndex
<
ndim_hidden
>
idx_hidden
;
// initialize visible index
// initialize visible index
auto
idx_hidden_pick_visible
=
pick_
array
_element
(
idx_hidden
,
visible_dim_ids
);
auto
idx_hidden_pick_visible
=
pick_
container
_element
(
idx_hidden
,
visible_dim_ids
);
idx_hidden_pick_visible
=
idx
;
idx_hidden_pick_visible
=
idx
;
// calculate hidden index
// calculate hidden index
...
@@ -149,8 +149,8 @@ struct DynamicTensorDescriptor_v2
...
@@ -149,8 +149,8 @@ struct DynamicTensorDescriptor_v2
constexpr
auto
dims_low
=
GetLowerDimensionIdss
().
At
(
itran
);
constexpr
auto
dims_low
=
GetLowerDimensionIdss
().
At
(
itran
);
constexpr
auto
dims_up
=
GetUpperDimensionIdss
().
At
(
itran
);
constexpr
auto
dims_up
=
GetUpperDimensionIdss
().
At
(
itran
);
const
auto
idx_up
=
pick_
array
_element
(
idx_hidden
,
dims_up
);
const
auto
idx_up
=
pick_
container
_element
(
idx_hidden
,
dims_up
);
auto
idx_low
=
pick_
array
_element
(
idx_hidden
,
dims_low
);
auto
idx_low
=
pick_
container
_element
(
idx_hidden
,
dims_low
);
tran
.
CalculateLowerIndex
(
idx_low
,
idx_up
);
tran
.
CalculateLowerIndex
(
idx_low
,
idx_up
);
});
});
...
@@ -193,7 +193,7 @@ struct DynamicTensorDescriptor_v2
...
@@ -193,7 +193,7 @@ struct DynamicTensorDescriptor_v2
constexpr
auto
up_dim_ids
=
UpperDimensionIdss
{}.
At
(
itran
);
constexpr
auto
up_dim_ids
=
UpperDimensionIdss
{}.
At
(
itran
);
// lengths_hidden_pick_up contains a reference to lengths_hidden
// lengths_hidden_pick_up contains a reference to lengths_hidden
auto
hidden_lengths_pick_up
=
pick_
array
_element
(
hidden_lengths
,
up_dim_ids
);
auto
hidden_lengths_pick_up
=
pick_
container
_element
(
hidden_lengths
,
up_dim_ids
);
hidden_lengths_pick_up
=
tran
.
GetUpperLengths
();
hidden_lengths_pick_up
=
tran
.
GetUpperLengths
();
});
});
...
@@ -207,7 +207,7 @@ struct DynamicTensorDescriptor_v2
...
@@ -207,7 +207,7 @@ struct DynamicTensorDescriptor_v2
// variable lengths_) to save space on stack?
// variable lengths_) to save space on stack?
const
HiddenIndex
hidden_lengths_
;
const
HiddenIndex
hidden_lengths_
;
// visible_lenths_ contains a reference to hidden_lengths_
// visible_lenths_ contains a reference to hidden_lengths_
const
Array
ElementPicker
<
const
HiddenIndex
,
VisibleDimensionIds
>
visible_lengths_
;
const
Container
ElementPicker
<
const
HiddenIndex
,
VisibleDimensionIds
>
visible_lengths_
;
#if 0
#if 0
// friend class
// friend class
...
@@ -283,7 +283,7 @@ struct DynamicTensorCoordinate_v2
...
@@ -283,7 +283,7 @@ struct DynamicTensorCoordinate_v2
// private member variables
// private member variables
HiddenIndex
idx_hidden_
;
HiddenIndex
idx_hidden_
;
// idx_visible_ contains a reference to idx_hidden_
// idx_visible_ contains a reference to idx_hidden_
Array
ElementPicker
<
HiddenIndex
,
VisibleDimensionIds
>
idx_visible_
;
Container
ElementPicker
<
HiddenIndex
,
VisibleDimensionIds
>
idx_visible_
;
#if 0
#if 0
// friend functions for making and updating tensor coordinate
// friend functions for making and updating tensor coordinate
...
@@ -441,7 +441,7 @@ make_dynamic_tensor_coordinate_v2(const TensorDesc& tensor_desc, const VisibleIn
...
@@ -441,7 +441,7 @@ make_dynamic_tensor_coordinate_v2(const TensorDesc& tensor_desc, const VisibleIn
MultiIndex
<
ndim_hidden
>
idx_hidden
;
MultiIndex
<
ndim_hidden
>
idx_hidden
;
// initialize visible index
// initialize visible index
auto
idx_hidden_pick_visible
=
pick_
array
_element
(
idx_hidden
,
visible_dim_ids
);
auto
idx_hidden_pick_visible
=
pick_
container
_element
(
idx_hidden
,
visible_dim_ids
);
idx_hidden_pick_visible
=
idx_visible
;
idx_hidden_pick_visible
=
idx_visible
;
// calculate hidden index
// calculate hidden index
...
@@ -451,8 +451,8 @@ make_dynamic_tensor_coordinate_v2(const TensorDesc& tensor_desc, const VisibleIn
...
@@ -451,8 +451,8 @@ make_dynamic_tensor_coordinate_v2(const TensorDesc& tensor_desc, const VisibleIn
constexpr
auto
dims_low
=
TensorDesc
::
GetLowerDimensionIdss
().
At
(
itran
);
constexpr
auto
dims_low
=
TensorDesc
::
GetLowerDimensionIdss
().
At
(
itran
);
constexpr
auto
dims_up
=
TensorDesc
::
GetUpperDimensionIdss
().
At
(
itran
);
constexpr
auto
dims_up
=
TensorDesc
::
GetUpperDimensionIdss
().
At
(
itran
);
const
auto
idx_up
=
pick_
array
_element
(
idx_hidden
,
dims_up
);
const
auto
idx_up
=
pick_
container
_element
(
idx_hidden
,
dims_up
);
auto
idx_low
=
pick_
array
_element
(
idx_hidden
,
dims_low
);
auto
idx_low
=
pick_
container
_element
(
idx_hidden
,
dims_low
);
tran
.
CalculateLowerIndex
(
idx_low
,
idx_up
);
tran
.
CalculateLowerIndex
(
idx_low
,
idx_up
);
});
});
...
@@ -477,7 +477,7 @@ make_dynamic_tensor_coordinate_step_v2(const TensorDesc&, const VisibleIndex& id
...
@@ -477,7 +477,7 @@ make_dynamic_tensor_coordinate_step_v2(const TensorDesc&, const VisibleIndex& id
Array
<
bool
,
ndim_hidden
>
non_zero_diff
{
false
};
Array
<
bool
,
ndim_hidden
>
non_zero_diff
{
false
};
auto
non_zero_diff_pick_visible
=
pick_
array
_element
(
non_zero_diff
,
visible_dim_ids
);
auto
non_zero_diff_pick_visible
=
pick_
container
_element
(
non_zero_diff
,
visible_dim_ids
);
static_for
<
0
,
ndim_visible
,
1
>
{}([
&
non_zero_diff_pick_visible
,
&
idx_diff_visible
](
auto
i
)
{
static_for
<
0
,
ndim_visible
,
1
>
{}([
&
non_zero_diff_pick_visible
,
&
idx_diff_visible
](
auto
i
)
{
non_zero_diff_pick_visible
(
i
)
=
(
idx_diff_visible
[
i
]
!=
0
);
non_zero_diff_pick_visible
(
i
)
=
(
idx_diff_visible
[
i
]
!=
0
);
...
@@ -487,8 +487,8 @@ make_dynamic_tensor_coordinate_step_v2(const TensorDesc&, const VisibleIndex& id
...
@@ -487,8 +487,8 @@ make_dynamic_tensor_coordinate_step_v2(const TensorDesc&, const VisibleIndex& id
constexpr
auto
dims_low
=
TensorDesc
::
GetLowerDimensionIdss
().
At
(
itran
);
constexpr
auto
dims_low
=
TensorDesc
::
GetLowerDimensionIdss
().
At
(
itran
);
constexpr
auto
dims_up
=
TensorDesc
::
GetUpperDimensionIdss
().
At
(
itran
);
constexpr
auto
dims_up
=
TensorDesc
::
GetUpperDimensionIdss
().
At
(
itran
);
const
auto
non_zero_diff_pick_up
=
pick_
array
_element
(
non_zero_diff
,
dims_up
);
const
auto
non_zero_diff_pick_up
=
pick_
container
_element
(
non_zero_diff
,
dims_up
);
auto
non_zero_diff_pick_low
=
pick_
array
_element
(
non_zero_diff
,
dims_low
);
auto
non_zero_diff_pick_low
=
pick_
container
_element
(
non_zero_diff
,
dims_low
);
// if any of upper index diff components is non-zero, then
// if any of upper index diff components is non-zero, then
// 1) Need to do this transform
// 1) Need to do this transform
...
@@ -526,7 +526,7 @@ __host__ __device__ void move_dynamic_tensor_coordinate_v2(const TensorDesc& ten
...
@@ -526,7 +526,7 @@ __host__ __device__ void move_dynamic_tensor_coordinate_v2(const TensorDesc& ten
// initialize visible index diff
// initialize visible index diff
// idx_diff_hidden_pick_visible contains reference to idx_diff_hidden
// idx_diff_hidden_pick_visible contains reference to idx_diff_hidden
auto
idx_diff_hidden_pick_visible
=
auto
idx_diff_hidden_pick_visible
=
pick_
array
_element
(
idx_diff_hidden
,
TensorDesc
::
GetVisibleDimensionIds
());
pick_
container
_element
(
idx_diff_hidden
,
TensorDesc
::
GetVisibleDimensionIds
());
idx_diff_hidden_pick_visible
=
coord_step
.
GetVisibleIndexDiff
();
idx_diff_hidden_pick_visible
=
coord_step
.
GetVisibleIndexDiff
();
...
@@ -535,7 +535,7 @@ __host__ __device__ void move_dynamic_tensor_coordinate_v2(const TensorDesc& ten
...
@@ -535,7 +535,7 @@ __host__ __device__ void move_dynamic_tensor_coordinate_v2(const TensorDesc& ten
// update visible index
// update visible index
auto
idx_hidden_pick_visible
=
auto
idx_hidden_pick_visible
=
pick_
array
_element
(
idx_hidden
,
TensorDesc
::
GetVisibleDimensionIds
());
pick_
container
_element
(
idx_hidden
,
TensorDesc
::
GetVisibleDimensionIds
());
idx_hidden_pick_visible
+=
coord_step
.
GetIndexDiff
();
idx_hidden_pick_visible
+=
coord_step
.
GetIndexDiff
();
// update rest of hidden index
// update rest of hidden index
...
@@ -546,12 +546,12 @@ __host__ __device__ void move_dynamic_tensor_coordinate_v2(const TensorDesc& ten
...
@@ -546,12 +546,12 @@ __host__ __device__ void move_dynamic_tensor_coordinate_v2(const TensorDesc& ten
constexpr
auto
dims_low
=
TensorDesc
::
GetLowerDimensionIdss
().
At
(
itran
);
constexpr
auto
dims_low
=
TensorDesc
::
GetLowerDimensionIdss
().
At
(
itran
);
constexpr
auto
dims_up
=
TensorDesc
::
GetUpperDimensionIdss
().
At
(
itran
);
constexpr
auto
dims_up
=
TensorDesc
::
GetUpperDimensionIdss
().
At
(
itran
);
// this const is for
Array
ElementPicker, Array itself may not be const
// this const is for
Container
ElementPicker, Array itself may not be const
const
auto
idx_up
=
pick_
array
_element
(
idx_hidden
,
dims_up
);
const
auto
idx_up
=
pick_
container
_element
(
idx_hidden
,
dims_up
);
auto
idx_low
=
pick_
array
_element
(
idx_hidden
,
dims_low
);
auto
idx_low
=
pick_
container
_element
(
idx_hidden
,
dims_low
);
const
auto
idx_diff_up
=
pick_
array
_element
(
idx_diff_hidden
,
dims_up
);
const
auto
idx_diff_up
=
pick_
container
_element
(
idx_diff_hidden
,
dims_up
);
auto
idx_diff_low
=
pick_
array
_element
(
idx_diff_hidden
,
dims_low
);
auto
idx_diff_low
=
pick_
container
_element
(
idx_diff_hidden
,
dims_low
);
tran
.
CalculateLowerIndexDiff
(
idx_diff_low
,
idx_diff_up
,
idx_low
,
idx_up
);
tran
.
CalculateLowerIndexDiff
(
idx_diff_low
,
idx_diff_up
,
idx_low
,
idx_up
);
...
@@ -579,7 +579,7 @@ coordinate_has_valid_offset_assuming_visible_index_is_valid(const TensorDesc& te
...
@@ -579,7 +579,7 @@ coordinate_has_valid_offset_assuming_visible_index_is_valid(const TensorDesc& te
if
constexpr
(
!
decltype
(
tran
)
::
IsValidUpperIndexAlwaysMappedToValidLowerIndex
())
if
constexpr
(
!
decltype
(
tran
)
::
IsValidUpperIndexAlwaysMappedToValidLowerIndex
())
{
{
const
auto
idx_up
=
const
auto
idx_up
=
pick_
array
_element
(
idx_hidden
,
TensorDesc
::
GetUpperDimensionIdss
().
At
(
itran
));
pick_
container
_element
(
idx_hidden
,
TensorDesc
::
GetUpperDimensionIdss
().
At
(
itran
));
valid
=
valid
&&
tran
.
IsValidUpperIndexMappedToValidLowerIndex
(
idx_up
);
valid
=
valid
&&
tran
.
IsValidUpperIndexMappedToValidLowerIndex
(
idx_up
);
}
}
...
...
composable_kernel/include/tensor_description/tensor_descriptor.hpp
View file @
5e88414a
...
@@ -311,8 +311,8 @@ struct TransformedTensorDescriptor
...
@@ -311,8 +311,8 @@ struct TransformedTensorDescriptor
static_for
<
0
,
nTransform
,
1
>
{}([
&
](
auto
itran
)
{
static_for
<
0
,
nTransform
,
1
>
{}([
&
](
auto
itran
)
{
constexpr
auto
tran
=
Transforms
{}.
At
(
itran
);
constexpr
auto
tran
=
Transforms
{}.
At
(
itran
);
const
auto
idx_up_part
=
pick_
array
_element
(
idx_up
,
UpDimensionIds
{}.
At
(
itran
));
const
auto
idx_up_part
=
pick_
container
_element
(
idx_up
,
UpDimensionIds
{}.
At
(
itran
));
auto
idx_low_part
=
pick_
array
_element
(
idx_low
,
LowDimensionIds
{}.
At
(
itran
));
auto
idx_low_part
=
pick_
container
_element
(
idx_low
,
LowDimensionIds
{}.
At
(
itran
));
// this assume each lower (single) index is only assocaited with one transformation,
// this assume each lower (single) index is only assocaited with one transformation,
// which is required for index transformation, and has been checked during constructor
// which is required for index transformation, and has been checked during constructor
...
@@ -333,14 +333,16 @@ struct TransformedTensorDescriptor
...
@@ -333,14 +333,16 @@ struct TransformedTensorDescriptor
constexpr
auto
tran
=
Transforms
{}.
At
(
itran
);
constexpr
auto
tran
=
Transforms
{}.
At
(
itran
);
const
auto
idx_up_diff_part
=
const
auto
idx_up_diff_part
=
pick_
array
_element
(
idx_up_diff
,
UpDimensionIds
{}.
At
(
itran
));
pick_
container
_element
(
idx_up_diff
,
UpDimensionIds
{}.
At
(
itran
));
const
auto
idx_up_old_part
=
pick_array_element
(
idx_up_old
,
UpDimensionIds
{}.
At
(
itran
));
const
auto
idx_up_old_part
=
pick_container_element
(
idx_up_old
,
UpDimensionIds
{}.
At
(
itran
));
const
auto
idx_low_old_part
=
const
auto
idx_low_old_part
=
pick_
array
_element
(
idx_low_old
,
LowDimensionIds
{}.
At
(
itran
));
pick_
container
_element
(
idx_low_old
,
LowDimensionIds
{}.
At
(
itran
));
auto
idx_low_diff_part
=
pick_array_element
(
idx_low_diff
,
LowDimensionIds
{}.
At
(
itran
));
auto
idx_low_diff_part
=
pick_container_element
(
idx_low_diff
,
LowDimensionIds
{}.
At
(
itran
));
// this assume each lower (single) index is associated with only one transformation,
// this assume each lower (single) index is associated with only one transformation,
// which is required for index transformation, and has been checked during constructor
// which is required for index transformation, and has been checked during constructor
...
@@ -508,7 +510,7 @@ struct TransformedTensorDescriptor
...
@@ -508,7 +510,7 @@ struct TransformedTensorDescriptor
constexpr
auto
low_lengths_part
=
constexpr
auto
low_lengths_part
=
GetLowerTensorDescriptor
().
GetLengths
(
low_dims_part
);
GetLowerTensorDescriptor
().
GetLengths
(
low_dims_part
);
const
auto
idx_low_part
=
const
auto
idx_low_part
=
to_multi_index
(
pick_
array
_element
(
idx_low
,
low_dims_part
));
to_multi_index
(
pick_
container
_element
(
idx_low
,
low_dims_part
));
static_for
<
0
,
decltype
(
low_dims_part
)
::
Size
(),
1
>
{}([
&
](
auto
i
)
{
static_for
<
0
,
decltype
(
low_dims_part
)
::
Size
(),
1
>
{}([
&
](
auto
i
)
{
flag
=
flag
&&
idx_low_part
[
i
]
>=
0
&&
idx_low_part
[
i
]
<
low_lengths_part
[
i
];
flag
=
flag
&&
idx_low_part
[
i
]
>=
0
&&
idx_low_part
[
i
]
<
low_lengths_part
[
i
];
...
...
composable_kernel/include/tensor_operation/gridwise_gemm.hpp
View file @
5e88414a
...
@@ -116,8 +116,8 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
...
@@ -116,8 +116,8 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
const
auto
block_work_id
=
block_work_desc
.
CalculateClusterIndex
(
get_block_1d_id
());
const
auto
block_work_id
=
block_work_desc
.
CalculateClusterIndex
(
get_block_1d_id
());
const
index_t
m_block_data_on_global
=
block_work_id
[
0
]
*
MPerBlock
;
const
index_t
m_block_data_on_global
=
block_work_id
[
Number
<
0
>
{}
]
*
MPerBlock
;
const
index_t
n_block_data_on_global
=
block_work_id
[
1
]
*
NPerBlock
;
const
index_t
n_block_data_on_global
=
block_work_id
[
Number
<
1
>
{}
]
*
NPerBlock
;
// A matrix in LDS memory, dst of blockwise copy
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
// be careful of LDS alignment
...
@@ -143,7 +143,7 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
...
@@ -143,7 +143,7 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
AddressSpace
::
Vgpr
,
AddressSpace
::
Vgpr
,
AddressSpace
::
Lds
,
AddressSpace
::
Lds
,
InMemoryDataOperation
::
Set
>
(
InMemoryDataOperation
::
Set
>
(
{
0
,
m_block_data_on_global
}
,
{
0
,
0
}
);
make_multi_index
(
0
,
m_block_data_on_global
)
,
make_multi_index
(
0
,
0
)
);
// B matrix in LDS memory, dst of blockwise copy
// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
// be careful of LDS alignment
...
@@ -169,7 +169,7 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
...
@@ -169,7 +169,7 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
AddressSpace
::
Vgpr
,
AddressSpace
::
Vgpr
,
AddressSpace
::
Lds
,
AddressSpace
::
Lds
,
InMemoryDataOperation
::
Set
>
(
InMemoryDataOperation
::
Set
>
(
{
0
,
n_block_data_on_global
}
,
{
0
,
0
}
);
make_multi_index
(
0
,
n_block_data_on_global
)
,
make_multi_index
(
0
,
0
)
);
// GEMM definition
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// c_mtx += transpose(a_mtx) * b_mtx
...
@@ -355,11 +355,11 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
...
@@ -355,11 +355,11 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
AddressSpace
::
Vgpr
,
AddressSpace
::
Vgpr
,
AddressSpace
::
Global
,
AddressSpace
::
Global
,
CGlobalMemoryDataOperation
>
(
CGlobalMemoryDataOperation
>
(
{
0
,
0
,
0
,
0
}
,
make_multi_index
(
0
,
0
,
0
,
0
)
,
{
m_thread_data_on_global
/
M1
,
make_multi_index
(
m_thread_data_on_global
/
M1
,
m_thread_data_on_global
%
M1
,
m_thread_data_on_global
%
M1
,
n_thread_data_on_global
/
N1
,
n_thread_data_on_global
/
N1
,
n_thread_data_on_global
%
N1
}
)
n_thread_data_on_global
%
N1
)
)
.
Run
(
p_c_thread
,
p_c_global
);
.
Run
(
p_c_thread
,
p_c_global
);
}
}
}
}
...
@@ -447,21 +447,23 @@ struct GridwiseGemmTransposedANormalBNormalC_v2
...
@@ -447,21 +447,23 @@ struct GridwiseGemmTransposedANormalBNormalC_v2
Float
*
__restrict__
p_c_global
,
Float
*
__restrict__
p_c_global
,
Float
*
__restrict__
p_shared_block
)
const
Float
*
__restrict__
p_shared_block
)
const
{
{
constexpr
auto
True
=
integral_constant
<
bool
,
true
>
{};
constexpr
auto
False
=
integral_constant
<
bool
,
false
>
{};
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
constexpr
auto
True
=
integral_constant
<
bool
,
true
>
{};
constexpr
auto
False
=
integral_constant
<
bool
,
false
>
{};
constexpr
auto
a_k0_k1_k2_m_global_desc
=
AGlobalDesc
{};
constexpr
auto
a_k0_k1_k2_m_global_desc
=
AGlobalDesc
{};
constexpr
auto
b_k0_k1_k2_n_global_desc
=
BGlobalDesc
{};
constexpr
auto
b_k0_k1_k2_n_global_desc
=
BGlobalDesc
{};
constexpr
auto
c_m_n_global_desc
=
CGlobalDesc
{};
constexpr
auto
c_m_n_global_desc
=
CGlobalDesc
{};
constexpr
auto
K0
=
a_k0_k1_k2_m_global_desc
.
GetLengths
()[
0
];
constexpr
auto
K0
=
a_k0_k1_k2_m_global_desc
.
GetLengths
()[
I
0
];
constexpr
auto
K1
=
a_k0_k1_k2_m_global_desc
.
GetLengths
()[
1
];
constexpr
auto
K1
=
a_k0_k1_k2_m_global_desc
.
GetLengths
()[
I
1
];
constexpr
auto
K
=
a_k0_k1_k2_m_global_desc
.
GetLengths
()[
2
];
constexpr
auto
K
=
a_k0_k1_k2_m_global_desc
.
GetLengths
()[
I
2
];
constexpr
auto
M
=
c_m_n_global_desc
.
GetLengths
()[
0
];
constexpr
auto
M
=
c_m_n_global_desc
.
GetLengths
()[
I
0
];
constexpr
auto
N
=
c_m_n_global_desc
.
GetLengths
()[
1
];
constexpr
auto
N
=
c_m_n_global_desc
.
GetLengths
()[
I
1
];
// don't do anything if K == 0
// don't do anything if K == 0
if
(
K
==
0
)
if
(
K
==
0
)
...
@@ -487,8 +489,8 @@ struct GridwiseGemmTransposedANormalBNormalC_v2
...
@@ -487,8 +489,8 @@ struct GridwiseGemmTransposedANormalBNormalC_v2
const
auto
block_work_id
=
block_work_desc
.
CalculateClusterIndex
(
get_block_1d_id
());
const
auto
block_work_id
=
block_work_desc
.
CalculateClusterIndex
(
get_block_1d_id
());
const
index_t
m_block_data_on_global
=
block_work_id
[
0
]
*
MPerBlock
;
const
index_t
m_block_data_on_global
=
block_work_id
[
I
0
]
*
MPerBlock
;
const
index_t
n_block_data_on_global
=
block_work_id
[
1
]
*
NPerBlock
;
const
index_t
n_block_data_on_global
=
block_work_id
[
I
1
]
*
NPerBlock
;
// A matrix in LDS memory, dst of blockwise copy
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
// be careful of LDS alignment
...
@@ -514,7 +516,7 @@ struct GridwiseGemmTransposedANormalBNormalC_v2
...
@@ -514,7 +516,7 @@ struct GridwiseGemmTransposedANormalBNormalC_v2
AddressSpace
::
Vgpr
,
AddressSpace
::
Vgpr
,
AddressSpace
::
Lds
,
AddressSpace
::
Lds
,
InMemoryDataOperation
::
Set
>
(
InMemoryDataOperation
::
Set
>
(
{
0
,
0
,
0
,
m_block_data_on_global
}
,
{
0
,
0
,
0
,
0
}
);
make_multi_index
(
0
,
0
,
0
,
m_block_data_on_global
)
,
make_multi_index
(
0
,
0
,
0
,
0
)
);
// B matrix in LDS memory, dst of blockwise copy
// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
// be careful of LDS alignment
...
@@ -540,7 +542,7 @@ struct GridwiseGemmTransposedANormalBNormalC_v2
...
@@ -540,7 +542,7 @@ struct GridwiseGemmTransposedANormalBNormalC_v2
AddressSpace
::
Vgpr
,
AddressSpace
::
Vgpr
,
AddressSpace
::
Lds
,
AddressSpace
::
Lds
,
InMemoryDataOperation
::
Set
>
(
InMemoryDataOperation
::
Set
>
(
{
0
,
0
,
0
,
n_block_data_on_global
}
,
{
0
,
0
,
0
,
0
}
);
make_multi_index
(
0
,
0
,
0
,
n_block_data_on_global
)
,
make_multi_index
(
0
,
0
,
0
,
0
)
);
// GEMM definition
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// c_mtx += transpose(a_mtx) * b_mtx
...
@@ -750,11 +752,11 @@ struct GridwiseGemmTransposedANormalBNormalC_v2
...
@@ -750,11 +752,11 @@ struct GridwiseGemmTransposedANormalBNormalC_v2
AddressSpace
::
Vgpr
,
AddressSpace
::
Vgpr
,
AddressSpace
::
Global
,
AddressSpace
::
Global
,
CGlobalMemoryDataOperation
>
(
CGlobalMemoryDataOperation
>
(
{
0
,
0
,
0
,
0
}
,
make_multi_index
(
0
,
0
,
0
,
0
)
,
{
m_thread_data_on_global
/
M1
,
make_multi_index
(
m_thread_data_on_global
/
M1
,
m_thread_data_on_global
%
M1
,
m_thread_data_on_global
%
M1
,
n_thread_data_on_global
/
N1
,
n_thread_data_on_global
/
N1
,
n_thread_data_on_global
%
N1
}
)
n_thread_data_on_global
%
N1
)
)
.
Run
(
p_c_thread
,
p_c_global
);
.
Run
(
p_c_thread
,
p_c_global
);
}
}
}
}
...
...
composable_kernel/include/utility/common_header.hpp
View file @
5e88414a
...
@@ -2,9 +2,9 @@
...
@@ -2,9 +2,9 @@
#define CK_COMMON_HEADER_HPP
#define CK_COMMON_HEADER_HPP
#include "array.hpp"
#include "array.hpp"
#include "
array
_helper.hpp"
#include "
container
_helper.hpp"
#include "statically_indexed_array.hpp"
#include "statically_indexed_array.hpp"
#include "
array
_element_picker.hpp"
#include "
container
_element_picker.hpp"
#include "config.hpp"
#include "config.hpp"
#include "float_type.hpp"
#include "float_type.hpp"
#include "functional.hpp"
#include "functional.hpp"
...
...
composable_kernel/include/utility/
array
_element_picker.hpp
→
composable_kernel/include/utility/
container
_element_picker.hpp
View file @
5e88414a
#ifndef CK_
ARRAY
_ELEMENT_PICKER_HPP
#ifndef CK_
CONTAINER
_ELEMENT_PICKER_HPP
#define CK_
ARRAY
_ELEMENT_PICKER_HPP
#define CK_
CONTAINER
_ELEMENT_PICKER_HPP
#include "functional2.hpp"
#include "functional2.hpp"
#include "sequence.hpp"
#include "sequence.hpp"
...
@@ -9,16 +9,16 @@ namespace ck {
...
@@ -9,16 +9,16 @@ namespace ck {
// Arr: Array or StaticallyIndexedArray
// Arr: Array or StaticallyIndexedArray
// Picks: Sequence<...>
// Picks: Sequence<...>
template
<
typename
Arr
,
typename
Picks
>
template
<
typename
Arr
,
typename
Picks
>
struct
Array
ElementPicker
struct
Container
ElementPicker
{
{
using
type
=
Array
ElementPicker
;
using
type
=
Container
ElementPicker
;
#if 0
#if 0
using data_type = typename Arr::data_type;
using data_type = typename Arr::data_type;
#endif
#endif
__host__
__device__
constexpr
Array
ElementPicker
()
=
delete
;
__host__
__device__
constexpr
Container
ElementPicker
()
=
delete
;
__host__
__device__
explicit
constexpr
Array
ElementPicker
(
Arr
&
array
)
:
mArray
{
array
}
__host__
__device__
explicit
constexpr
Container
ElementPicker
(
Arr
&
array
)
:
mArray
{
array
}
{
{
constexpr
index_t
imax
=
reduce_on_sequence
(
Picks
{},
math
::
maxer
<
index_t
>
{},
Number
<
0
>
{});
constexpr
index_t
imax
=
reduce_on_sequence
(
Picks
{},
math
::
maxer
<
index_t
>
{},
Number
<
0
>
{});
...
@@ -72,9 +72,9 @@ struct ArrayElementPicker
...
@@ -72,9 +72,9 @@ struct ArrayElementPicker
};
};
template
<
typename
Arr
,
typename
Picks
,
typename
X
>
template
<
typename
Arr
,
typename
Picks
,
typename
X
>
__host__
__device__
constexpr
auto
operator
+=
(
Array
ElementPicker
<
Arr
,
Picks
>&
y
,
const
X
&
x
)
__host__
__device__
constexpr
auto
operator
+=
(
Container
ElementPicker
<
Arr
,
Picks
>&
y
,
const
X
&
x
)
{
{
using
Y
=
Array
ElementPicker
<
Arr
,
Picks
>
;
using
Y
=
Container
ElementPicker
<
Arr
,
Picks
>
;
constexpr
index_t
nsize
=
Y
::
Size
();
constexpr
index_t
nsize
=
Y
::
Size
();
static_assert
(
nsize
==
X
::
Size
(),
"wrong! size not the same"
);
static_assert
(
nsize
==
X
::
Size
(),
"wrong! size not the same"
);
...
@@ -85,9 +85,9 @@ __host__ __device__ constexpr auto operator+=(ArrayElementPicker<Arr, Picks>& y,
...
@@ -85,9 +85,9 @@ __host__ __device__ constexpr auto operator+=(ArrayElementPicker<Arr, Picks>& y,
}
}
template
<
typename
Arr
,
typename
Picks
,
typename
X
>
template
<
typename
Arr
,
typename
Picks
,
typename
X
>
__host__
__device__
constexpr
auto
operator
-=
(
Array
ElementPicker
<
Arr
,
Picks
>&
y
,
const
X
&
x
)
__host__
__device__
constexpr
auto
operator
-=
(
Container
ElementPicker
<
Arr
,
Picks
>&
y
,
const
X
&
x
)
{
{
using
Y
=
Array
ElementPicker
<
Arr
,
Picks
>
;
using
Y
=
Container
ElementPicker
<
Arr
,
Picks
>
;
constexpr
index_t
nsize
=
Y
::
Size
();
constexpr
index_t
nsize
=
Y
::
Size
();
static_assert
(
nsize
==
X
::
Size
(),
"wrong! size not the same"
);
static_assert
(
nsize
==
X
::
Size
(),
"wrong! size not the same"
);
...
@@ -98,9 +98,9 @@ __host__ __device__ constexpr auto operator-=(ArrayElementPicker<Arr, Picks>& y,
...
@@ -98,9 +98,9 @@ __host__ __device__ constexpr auto operator-=(ArrayElementPicker<Arr, Picks>& y,
}
}
template
<
typename
Arr
,
typename
Picks
>
template
<
typename
Arr
,
typename
Picks
>
__host__
__device__
constexpr
auto
pick_
array
_element
(
Arr
&
a
,
Picks
)
__host__
__device__
constexpr
auto
pick_
container
_element
(
Arr
&
a
,
Picks
)
{
{
return
Array
ElementPicker
<
Arr
,
Picks
>
(
a
);
return
Container
ElementPicker
<
Arr
,
Picks
>
(
a
);
}
}
}
// namespace ck
}
// namespace ck
...
...
composable_kernel/include/utility/
array
_helper.hpp
→
composable_kernel/include/utility/
container
_helper.hpp
View file @
5e88414a
#ifndef CK_
ARRAY
_HELPER_HPP
#ifndef CK_
CONTAINER
_HELPER_HPP
#define CK_
ARRAY
_HELPER_HPP
#define CK_
CONTAINER
_HELPER_HPP
#include "sequence.hpp"
#include "sequence.hpp"
#include "sequence_helper.hpp"
#include "sequence_helper.hpp"
#include "array.hpp"
#include "array.hpp"
#include "array_helper.hpp"
#include "tuple.hpp"
#include "tuple.hpp"
#include "tuple_helper.hpp"
#include "tuple_helper.hpp"
#include "statically_indexed_array.hpp"
#include "statically_indexed_array.hpp"
#include "
array
_element_picker.hpp"
#include "
container
_element_picker.hpp"
namespace
ck
{
namespace
ck
{
...
@@ -24,6 +23,18 @@ __host__ __device__ constexpr auto container_push_back(const Array<TData, NSize>
...
@@ -24,6 +23,18 @@ __host__ __device__ constexpr auto container_push_back(const Array<TData, NSize>
return
r
;
return
r
;
}
}
template
<
typename
...
Ts
,
typename
T
>
__host__
__device__
constexpr
auto
container_push_back
(
const
Tuple
<
Ts
...
>&
a
,
const
T
&
x
)
{
Tuple
<
Ts
...,
T
>
r
;
static_for
<
0
,
sizeof
...(
Ts
),
1
>
{}([
&
r
,
&
a
](
auto
i
)
constexpr
{
r
(
i
)
=
a
[
i
];
});
r
(
Number
<
sizeof
...(
Ts
)
>
{})
=
x
;
return
r
;
}
template
<
typename
TData
,
index_t
NSize
,
index_t
...
IRs
>
template
<
typename
TData
,
index_t
NSize
,
index_t
...
IRs
>
__host__
__device__
constexpr
auto
__host__
__device__
constexpr
auto
container_reorder_given_new2old
(
const
Array
<
TData
,
NSize
>&
old_array
,
Sequence
<
IRs
...
>
/*new2old*/
)
container_reorder_given_new2old
(
const
Array
<
TData
,
NSize
>&
old_array
,
Sequence
<
IRs
...
>
/*new2old*/
)
...
...
composable_kernel/include/utility/functional3.hpp
View file @
5e88414a
#ifndef CK_FUNCTIONAL3_HPP
#ifndef CK_FUNCTIONAL3_HPP
#define CK_FUNCTIONAL3_HPP
#define CK_FUNCTIONAL3_HPP
#include "array.hpp"
#include "functional.hpp"
#include "functional.hpp"
#include "functional2.hpp"
#include "functional2.hpp"
#include "sequence.hpp"
#include "sequence.hpp"
#include "multi_index.hpp"
namespace
ck
{
namespace
ck
{
...
@@ -133,7 +133,7 @@ struct ford
...
@@ -133,7 +133,7 @@ struct ford
for
(
index_t
i
=
0
;
i
<
ordered_lengths
.
Front
();
++
i
)
for
(
index_t
i
=
0
;
i
<
ordered_lengths
.
Front
();
++
i
)
{
{
detail
::
ford_impl
<
decltype
(
ordered_lengths
.
PopFront
()),
Orders
>
{}(
f
,
detail
::
ford_impl
<
decltype
(
ordered_lengths
.
PopFront
()),
Orders
>
{}(
f
,
Array
<
index_t
,
1
>
{
i
}
);
make_multi_index
(
i
)
);
}
}
}
}
};
};
...
...
composable_kernel/include/utility/print.hpp
View file @
5e88414a
...
@@ -3,7 +3,7 @@
...
@@ -3,7 +3,7 @@
#include "array.hpp"
#include "array.hpp"
#include "statically_indexed_array.hpp"
#include "statically_indexed_array.hpp"
#include "
array
_helper.hpp"
#include "
container
_helper.hpp"
#include "sequence.hpp"
#include "sequence.hpp"
namespace
ck
{
namespace
ck
{
...
...
composable_kernel/include/utility/tuple.hpp
View file @
5e88414a
...
@@ -19,9 +19,11 @@ struct TupleElement
...
@@ -19,9 +19,11 @@ struct TupleElement
{
{
__host__
__device__
explicit
constexpr
TupleElement
()
:
mData
()
{}
__host__
__device__
explicit
constexpr
TupleElement
()
:
mData
()
{}
#if 0
__host__ __device__ explicit constexpr TupleElement(const TupleElement&) = default;
__host__ __device__ explicit constexpr TupleElement(const TupleElement&) = default;
__host__ __device__ explicit constexpr TupleElement(TupleElement&&) = default;
__host__ __device__ explicit constexpr TupleElement(TupleElement&&) = default;
#endif
template
<
typename
UData
>
template
<
typename
UData
>
__host__
__device__
explicit
constexpr
TupleElement
(
const
TupleElement
<
Key
,
UData
>&
te
)
__host__
__device__
explicit
constexpr
TupleElement
(
const
TupleElement
<
Key
,
UData
>&
te
)
...
@@ -73,9 +75,11 @@ struct TupleImpl<Sequence<Is...>, Xs...> : TupleElement<TupleElementKey<Is>, Xs>
...
@@ -73,9 +75,11 @@ struct TupleImpl<Sequence<Is...>, Xs...> : TupleElement<TupleElementKey<Is>, Xs>
static_assert
(
sizeof
...(
Is
)
==
sizeof
...(
Xs
),
"wrong! inconsistent size"
);
static_assert
(
sizeof
...(
Is
)
==
sizeof
...(
Xs
),
"wrong! inconsistent size"
);
}
}
#if 0
__host__ __device__ explicit constexpr TupleImpl(const TupleImpl&) = default;
__host__ __device__ explicit constexpr TupleImpl(const TupleImpl&) = default;
__host__ __device__ explicit constexpr TupleImpl(TupleImpl&&) = default;
__host__ __device__ explicit constexpr TupleImpl(TupleImpl&&) = default;
#endif
template
<
index_t
...
Js
,
typename
...
Ys
>
template
<
index_t
...
Js
,
typename
...
Ys
>
__host__
__device__
explicit
constexpr
TupleImpl
(
const
TupleImpl
<
Sequence
<
Js
...
>
,
Ys
...
>&
y
)
__host__
__device__
explicit
constexpr
TupleImpl
(
const
TupleImpl
<
Sequence
<
Js
...
>
,
Ys
...
>&
y
)
...
@@ -124,9 +128,11 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X
...
@@ -124,9 +128,11 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X
__host__
__device__
explicit
constexpr
Tuple
()
:
base
()
{}
__host__
__device__
explicit
constexpr
Tuple
()
:
base
()
{}
#if 0
__host__ __device__ constexpr Tuple(const Tuple&) = default;
__host__ __device__ constexpr Tuple(const Tuple&) = default;
__host__ __device__ constexpr Tuple(Tuple&&) = default;
__host__ __device__ constexpr Tuple(Tuple&&) = default;
#endif
template
<
typename
...
Ys
,
template
<
typename
...
Ys
,
typename
std
::
enable_if
<
sizeof
...(
Ys
)
==
sizeof
...(
Xs
),
bool
>
::
type
=
false
>
typename
std
::
enable_if
<
sizeof
...(
Ys
)
==
sizeof
...(
Xs
),
bool
>
::
type
=
false
>
...
...
composable_kernel/include/utility/tuple_helper.hpp
View file @
5e88414a
#ifndef CK_TUPLE_HELPER_HPP
#ifndef CK_TUPLE_HELPER_HPP
#define CK_TUPLE_HELPER_HPP
#define CK_TUPLE_HELPER_HPP
#include "tuple_helper.hpp"
#include "functional4.hpp"
#include "tuple.hpp"
namespace
ck
{
namespace
ck
{
...
...
driver/include/device_convolution_backward_data_implicit_gemm_v5r1_nhwc_kyxc_nhwk.hpp
View file @
5e88414a
...
@@ -222,7 +222,7 @@ void device_convolution_backward_data_implicit_gemm_v5r1_nhwc_kyxc_nhwk(InDesc i
...
@@ -222,7 +222,7 @@ void device_convolution_backward_data_implicit_gemm_v5r1_nhwc_kyxc_nhwk(InDesc i
static_for
<
0
,
GridwiseConvBwdData
::
GetNumberOfGemm
(),
1
>
{}([
&
](
auto
gemm_id
)
{
static_for
<
0
,
GridwiseConvBwdData
::
GetNumberOfGemm
(),
1
>
{}([
&
](
auto
gemm_id
)
{
constexpr
auto
gemm_sizes
=
GridwiseConvBwdData
::
GetGemmSize
(
gemm_id
);
constexpr
auto
gemm_sizes
=
GridwiseConvBwdData
::
GetGemmSize
(
gemm_id
);
constexpr
index_t
gemm_k2
=
gemm_sizes
.
At
(
4
)
;
constexpr
index_t
gemm_k2
=
gemm_sizes
[
Number
<
4
>
{}]
;
constexpr
bool
is_gemm_not_empty
=
gemm_k2
>
0
;
constexpr
bool
is_gemm_not_empty
=
gemm_k2
>
0
;
// only compile and run if GEMM is no empty
// only compile and run if GEMM is no empty
...
...
driver/src/conv_bwd_data_driver.cpp
View file @
5e88414a
...
@@ -245,7 +245,7 @@ int main(int argc, char* argv[])
...
@@ -245,7 +245,7 @@ int main(int argc, char* argv[])
device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw
device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw
#elif
0
#elif
0
device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw
device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw
#elif
1
#elif
0
device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw
device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw
#elif 1
#elif 1
device_convolution_backward_data_implicit_gemm_v5r1_nhwc_kyxc_nhwk
device_convolution_backward_data_implicit_gemm_v5r1_nhwc_kyxc_nhwk
...
...
driver/src/conv_driver.cpp
View file @
5e88414a
...
@@ -561,7 +561,7 @@ int main(int argc, char* argv[])
...
@@ -561,7 +561,7 @@ int main(int argc, char* argv[])
LeftPads{},
LeftPads{},
RightPads{},
RightPads{},
nrepeat);
nrepeat);
#elif
0
#elif
1
device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw
(
in_nchw_desc
,
device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw
(
in_nchw_desc
,
in_nchw
,
in_nchw
,
wei_kcyx_desc
,
wei_kcyx_desc
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment