Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
674c405f
Commit
674c405f
authored
Sep 25, 2020
by
Chao Liu
Browse files
refactor
parent
ffa7e4be
Changes
12
Show whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
574 additions
and
351 deletions
+574
-351
composable_kernel/include/kernel_algorithm/dummy_dynamic_transform_v2.hpp
...l/include/kernel_algorithm/dummy_dynamic_transform_v2.hpp
+24
-19
composable_kernel/include/tensor_description/dynamic_multi_index_transform.hpp
...lude/tensor_description/dynamic_multi_index_transform.hpp
+33
-47
composable_kernel/include/tensor_description/dynamic_tensor_descriptor_v2.hpp
...clude/tensor_description/dynamic_tensor_descriptor_v2.hpp
+3
-2
composable_kernel/include/utility/array.hpp
composable_kernel/include/utility/array.hpp
+18
-71
composable_kernel/include/utility/array_element_picker.hpp
composable_kernel/include/utility/array_element_picker.hpp
+99
-0
composable_kernel/include/utility/array_helper.hpp
composable_kernel/include/utility/array_helper.hpp
+2
-24
composable_kernel/include/utility/common_header.hpp
composable_kernel/include/utility/common_header.hpp
+2
-0
composable_kernel/include/utility/print.hpp
composable_kernel/include/utility/print.hpp
+6
-186
composable_kernel/include/utility/statically_indexed_array.hpp
...sable_kernel/include/utility/statically_indexed_array.hpp
+362
-0
composable_kernel/include/utility/tuple.hpp
composable_kernel/include/utility/tuple.hpp
+24
-0
driver/src/conv_bwd_data_driver.cu
driver/src/conv_bwd_data_driver.cu
+0
-1
driver/src/conv_driver.cpp
driver/src/conv_driver.cpp
+1
-1
No files found.
composable_kernel/include/kernel_algorithm/dummy_dynamic_transform_v2.hpp
View file @
674c405f
...
...
@@ -17,29 +17,34 @@ map_convolution_into_gemm_v2(const WeiDesc& wei_k_c_y_x_global_desc,
const
Array
<
index_t
,
2
>
in_left_pads
,
const
Array
<
index_t
,
2
>
in_right_pads
)
{
const
index_t
N
=
in_n_c_hi_wi_global_desc
.
GetLength
(
0
);
const
index_t
C
=
in_n_c_hi_wi_global_desc
.
GetLength
(
1
);
const
index_t
K
=
out_n_k_ho_wo_global_desc
.
GetLength
(
1
);
constexpr
auto
i0
=
Number
<
0
>
{};
constexpr
auto
i1
=
Number
<
1
>
{};
constexpr
auto
i2
=
Number
<
2
>
{};
constexpr
auto
i3
=
Number
<
3
>
{};
const
index_t
Y
=
wei_k_c_y_x_global_desc
.
GetLength
(
2
);
const
index_t
X
=
wei_k_c_y_x_global_desc
.
GetLength
(
3
);
const
index_t
N
=
in_n_c_hi_wi_global_desc
.
GetLength
(
i0
);
const
index_t
C
=
in_n_c_hi_wi_global_desc
.
GetLength
(
i1
);
const
index_t
K
=
out_n_k_ho_wo_global_desc
.
GetLength
(
i1
);
const
index_t
Hi
=
in_n_c_hi_wi
_global_desc
.
GetLength
(
2
);
const
index_t
Wi
=
in_n_c_hi_wi
_global_desc
.
GetLength
(
3
);
const
index_t
Y
=
wei_k_c_y_x
_global_desc
.
GetLength
(
i
2
);
const
index_t
X
=
wei_k_c_y_x
_global_desc
.
GetLength
(
i
3
);
const
index_t
H
o
=
out
_n_
k
_h
o
_w
o
_global_desc
.
GetLength
(
2
);
const
index_t
W
o
=
out
_n_
k
_h
o
_w
o
_global_desc
.
GetLength
(
3
);
const
index_t
H
i
=
in
_n_
c
_h
i
_w
i
_global_desc
.
GetLength
(
i
2
);
const
index_t
W
i
=
in
_n_
c
_h
i
_w
i
_global_desc
.
GetLength
(
i
3
);
const
index_t
ConvStrideH
=
conv_strides
[
0
]
;
const
index_t
ConvStrideW
=
conv_strides
[
1
]
;
const
index_t
Ho
=
out_n_k_ho_wo_global_desc
.
GetLength
(
i2
)
;
const
index_t
Wo
=
out_n_k_ho_wo_global_desc
.
GetLength
(
i3
)
;
const
index_t
Conv
Dilation
H
=
conv_
dilation
s
[
0
];
const
index_t
Conv
Dilation
W
=
conv_
dilation
s
[
1
];
const
index_t
Conv
Stride
H
=
conv_
stride
s
[
i
0
];
const
index_t
Conv
Stride
W
=
conv_
stride
s
[
i
1
];
const
index_t
InLeftPadH
=
in_left_pads
[
0
];
const
index_t
InLeftPadW
=
in_left_pads
[
1
];
const
index_t
InRightPadH
=
in_right_pads
[
0
];
const
index_t
InRightPadW
=
in_right_pads
[
1
];
const
index_t
ConvDilationH
=
conv_dilations
[
i0
];
const
index_t
ConvDilationW
=
conv_dilations
[
i1
];
const
index_t
InLeftPadH
=
in_left_pads
[
i0
];
const
index_t
InLeftPadW
=
in_left_pads
[
i1
];
const
index_t
InRightPadH
=
in_right_pads
[
i0
];
const
index_t
InRightPadW
=
in_right_pads
[
i1
];
// input tensor
const
auto
in_n_c_hip_wip_global_desc
=
transform_dynamic_tensor_descriptor_v2
(
...
...
@@ -58,8 +63,8 @@ map_convolution_into_gemm_v2(const WeiDesc& wei_k_c_y_x_global_desc,
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
const
index_t
Hip
=
in_n_c_hip_wip_global_desc
.
GetLength
(
2
);
const
index_t
Wip
=
in_n_c_hip_wip_global_desc
.
GetLength
(
3
);
const
index_t
Hip
=
in_n_c_hip_wip_global_desc
.
GetLength
(
i
2
);
const
index_t
Wip
=
in_n_c_hip_wip_global_desc
.
GetLength
(
i
3
);
const
auto
in_n_c_y_ho_x_wo_global_desc
=
transform_dynamic_tensor_descriptor_v2
(
in_n_c_hip_wip_global_desc
,
...
...
composable_kernel/include/tensor_description/dynamic_multi_index_transform.hpp
View file @
674c405f
...
...
@@ -31,7 +31,7 @@ struct DynamicPassThrough
static_assert
(
LowIdx
::
Size
()
==
1
&&
UpIdx
::
Size
()
==
1
,
"wrong! inconsistent # of dimension"
);
idx_low
(
0
)
=
idx_up
[
0
];
idx_low
(
Number
<
0
>
{})
=
idx_up
[
Number
<
0
>
{}
];
}
template
<
typename
LowIdxDiff
,
typename
UpIdxDiff
,
typename
LowIdx
,
typename
UpIdx
>
...
...
@@ -44,7 +44,7 @@ struct DynamicPassThrough
UpIdx
::
Size
()
==
1
,
"wrong! inconsistent # of dimension"
);
idx_diff_low
(
0
)
=
idx_diff_up
[
0
];
idx_diff_low
(
Number
<
0
>
{}
)
=
idx_diff_up
[
Number
<
0
>
{}
];
}
__host__
__device__
static
constexpr
bool
IsLinearTransform
()
{
return
true
;
}
...
...
@@ -92,7 +92,7 @@ struct DynamicLeftPad
static_assert
(
LowIdx
::
Size
()
==
1
&&
UpIdx
::
Size
()
==
1
,
"wrong! inconsistent # of dimension"
);
idx_low
(
0
)
=
idx_up
[
0
]
-
left_pad_
;
idx_low
(
Number
<
0
>
{})
=
idx_up
[
Number
<
0
>
{}
]
-
left_pad_
;
}
template
<
typename
LowIdxDiff
,
typename
UpIdxDiff
,
typename
LowIdx
,
typename
UpIdx
>
...
...
@@ -106,7 +106,7 @@ struct DynamicLeftPad
UpIdx
::
Size
()
==
1
,
"wrong! inconsistent # of dimension"
);
idx_diff_low
(
0
)
=
idx_diff_up
[
0
];
idx_diff_low
(
Number
<
0
>
{}
)
=
idx_diff_up
[
Number
<
0
>
{}
];
}
__host__
__device__
static
constexpr
bool
IsLinearTransform
()
{
return
true
;
}
...
...
@@ -120,7 +120,7 @@ struct DynamicLeftPad
__host__
__device__
constexpr
bool
IsValidUpperIndexMappedToValidLowerIndex
(
const
UpIdx
&
idx_up
)
const
{
return
SkipIsValidCheck
||
(
idx_up
[
0
]
>=
left_pad_
);
return
SkipIsValidCheck
||
(
idx_up
[
Number
<
0
>
{}
]
>=
left_pad_
);
}
};
...
...
@@ -158,7 +158,7 @@ struct DynamicRightPad
static_assert
(
LowIdx
::
Size
()
==
1
&&
UpIdx
::
Size
()
==
1
,
"wrong! inconsistent # of dimension"
);
idx_low
(
0
)
=
idx_up
[
0
];
idx_low
(
Number
<
0
>
{})
=
idx_up
[
Number
<
0
>
{}
];
}
template
<
typename
LowIdxDiff
,
typename
UpIdxDiff
,
typename
LowIdx
,
typename
UpIdx
>
...
...
@@ -172,7 +172,7 @@ struct DynamicRightPad
UpIdx
::
Size
()
==
1
,
"wrong! inconsistent # of dimension"
);
idx_diff_low
(
0
)
=
idx_diff_up
[
0
];
idx_diff_low
(
Number
<
0
>
{}
)
=
idx_diff_up
[
Number
<
0
>
{}
];
}
__host__
__device__
static
constexpr
bool
IsLinearTransform
()
{
return
true
;
}
...
...
@@ -186,7 +186,7 @@ struct DynamicRightPad
__host__
__device__
constexpr
bool
IsValidUpperIndexMappedToValidLowerIndex
(
const
UpIdx
&
idx_up
)
const
{
return
SkipIsValidCheck
||
(
idx_up
[
0
]
<
low_length_
);
return
SkipIsValidCheck
||
(
idx_up
[
Number
<
0
>
{}
]
<
low_length_
);
}
};
...
...
@@ -228,13 +228,11 @@ struct DynamicEmbed
static_assert
(
LowIdx
::
Size
()
==
1
&&
UpIdx
::
Size
()
==
NDimUp
,
"wrong! inconsistent # of dimension"
);
idx_low
(
0
)
=
coefficients_
[
NDimUp
];
idx_low
(
Number
<
0
>
{}
)
=
coefficients_
[
Number
<
NDimUp
>
{}
];
#pragma unroll
for
(
index_t
i
=
0
;
i
<
NDimUp
;
++
i
)
{
idx_low
(
0
)
+=
idx_up
[
i
]
*
coefficients_
[
i
];
}
static_for
<
0
,
NDimUp
,
1
>
{}([
&
idx_low
,
&
idx_up
,
this
](
auto
i
)
{
idx_low
(
Number
<
0
>
{})
+=
idx_up
[
i
]
*
this
->
coefficients_
[
i
];
});
}
template
<
typename
LowIdxDiff
,
typename
UpIdxDiff
,
typename
LowIdx
,
typename
UpIdx
>
...
...
@@ -247,13 +245,10 @@ struct DynamicEmbed
LowIdx
::
Size
()
==
1
&&
UpIdx
::
Size
()
==
NDimUp
,
"wrong! inconsistent # of dimension"
);
idx_diff_low
(
0
)
=
0
;
idx_diff_low
(
Number
<
0
>
{}
)
=
0
;
#pragma unroll
for
(
index_t
i
=
0
;
i
<
NDimUp
;
++
i
)
{
idx_diff_low
(
0
)
+=
idx_diff_up
[
i
]
*
coefficients_
[
i
];
}
static_for
<
0
,
NDimUp
,
1
>
{}(
[
&
](
auto
i
)
{
idx_diff_low
(
Number
<
0
>
{})
+=
idx_diff_up
[
i
]
*
coefficients_
[
i
];
});
}
__host__
__device__
static
constexpr
bool
IsLinearTransform
()
{
return
true
;
}
...
...
@@ -310,16 +305,14 @@ struct DynamicMerge
static_assert
(
LowIdx
::
Size
()
==
NDimLow
&&
UpIdx
::
Size
()
==
1
,
"wrong! inconsistent # of dimension"
);
index_t
tmp
=
idx_up
[
0
];
index_t
tmp
=
idx_up
[
Number
<
0
>
{}
];
#pragma unroll
for
(
index_t
i
=
0
;
i
<
NDimLow
-
1
;
++
i
)
{
idx_low
(
i
)
=
tmp
/
low_lengths_scan_
[
i
];
tmp
-=
idx_low
[
i
]
*
low_lengths_scan_
[
i
];
}
static_for
<
0
,
NDimLow
-
1
,
1
>
{}([
&
idx_low
,
&
tmp
,
this
](
auto
i
)
{
idx_low
(
i
)
=
tmp
/
this
->
low_lengths_scan_
[
i
];
tmp
-=
idx_low
[
i
]
*
this
->
low_lengths_scan_
[
i
];
});
idx_low
(
NDimLow
-
1
)
=
tmp
;
idx_low
(
Number
<
NDimLow
-
1
>
{}
)
=
tmp
;
}
// idx_diff_low depends on idx_low_old, so idx_low need to be up-to-date
...
...
@@ -336,15 +329,13 @@ struct DynamicMerge
LowIdx
::
Size
()
==
NDimLow
&&
UpIdx
::
Size
()
==
1
,
"wrong! inconsistent # of dimension"
);
#if
1
#if
0
// I only want to do this check, if idx_diff_up is know at compile-time
if
(
idx_diff_up
[
0
]
==
0
)
{
#pragma unroll
for
(
index_t
i
=
0
;
i
<
NDimLow
;
++
i
)
if(idx_diff_up[Number<0>{}] == 0)
{
static_for<0, NDimLow, 1>{}([&idx_diff_low](auto i){
idx_diff_low(i) = 0;
}
}
);
return;
}
...
...
@@ -370,9 +361,7 @@ struct DynamicMerge
// do not need to check the first dimension
index_t
carry
=
0
;
#pragma unroll
for
(
index_t
i
=
NDimLow
-
1
;
i
>
0
;
--
i
)
{
static_for
<
NDimLow
-
1
,
0
,
-
1
>
{}([
&
](
auto
i
)
{
// this should be saved in SGPR as well
index_t
idx_low_length_minus_idx_diff_low_const
=
low_lengths_
[
i
]
-
idx_diff_low_const
[
i
];
...
...
@@ -401,9 +390,9 @@ struct DynamicMerge
#if 0
carry = do_borrow ? -1 : carry;
#endif
}
}
);
idx_diff_low
(
0
)
=
idx_diff_low_const
[
0
]
+
carry
;
idx_diff_low
(
Number
<
0
>
{}
)
=
idx_diff_low_const
[
Number
<
0
>
{}
]
+
carry
;
}
__host__
__device__
static
constexpr
bool
IsLinearTransform
()
{
return
false
;
}
...
...
@@ -453,13 +442,10 @@ struct DynamicUnMerge
__host__
__device__
constexpr
void
CalculateLowerIndex
(
LowIdx
&
idx_low
,
const
UpIdx
&
idx_up
)
const
{
idx_low
(
0
)
=
idx_up
[
NDimUp
];
idx_low
(
Number
<
0
>
{}
)
=
idx_up
[
Number
<
NDimUp
>
{}
];
#pragma unroll
for
(
index_t
i
=
0
;
i
<
NDimUp
-
1
;
++
i
)
{
idx_low
(
0
)
+=
idx_up
[
i
]
*
up_lengths_scan_
[
i
];
}
static_for
<
0
,
NDimUp
-
1
,
1
>
{}(
[
&
](
auto
i
)
{
idx_low
(
Number
<
0
>
{})
+=
idx_up
[
i
]
*
up_lengths_scan_
[
i
];
});
}
template
<
typename
LowIdxDiff
,
typename
UpIdxDiff
,
typename
LowIdx
,
typename
UpIdx
>
...
...
@@ -512,7 +498,7 @@ struct DynamicFreeze
static_assert
(
LowIdx
::
Size
()
==
1
&&
UpIdx
::
Size
()
==
1
,
"wrong! inconsistent # of dimension"
);
idx_low
(
0
)
=
low_idx_
;
idx_low
(
Number
<
0
>
{}
)
=
low_idx_
;
}
template
<
typename
LowIdxDiff
,
typename
UpIdxDiff
,
typename
LowIdx
,
typename
UpIdx
>
...
...
@@ -521,7 +507,7 @@ struct DynamicFreeze
const
LowIdx
&
/* idx_low_old */
,
const
UpIdx
&
/* idx_up_old */
)
{
idx_diff_low
(
0
)
=
index_t
{
0
};
idx_diff_low
(
Number
<
0
>
{}
)
=
index_t
{
Number
<
0
>
{}
};
}
__host__
__device__
static
constexpr
bool
IsLinearTransform
()
{
return
true
;
}
...
...
composable_kernel/include/tensor_description/dynamic_tensor_descriptor_v2.hpp
View file @
674c405f
...
...
@@ -105,9 +105,10 @@ struct DynamicTensorDescriptor_v2
return
GetNumOfVisibleDimension
();
}
__host__
__device__
constexpr
index_t
GetLength
(
index_t
idim
)
const
template
<
index_t
IDim
>
__host__
__device__
constexpr
index_t
GetLength
(
Number
<
IDim
>
)
const
{
return
visible_lengths_
[
idim
];
return
visible_lengths_
[
Number
<
IDim
>
{}
];
}
__host__
__device__
constexpr
const
auto
&
GetLengths
()
const
{
return
visible_lengths_
;
}
...
...
composable_kernel/include/utility/array.hpp
View file @
674c405f
...
...
@@ -161,82 +161,29 @@ struct Array
}
};
// Arr: Array
// Picks: Sequence<...>
template
<
typename
Arr
,
typename
Picks
>
struct
ArrayElementPicker
template
<
typename
X
,
typename
...
Xs
>
__host__
__device__
constexpr
auto
make_array
(
const
X
&
x
,
const
Xs
&
...
xs
)
{
using
type
=
ArrayElementPicker
;
using
data_type
=
typename
Arr
::
data_type
;
return
Array
<
X
,
sizeof
...(
xs
)
+
1
>
{{
x
,
xs
...}}
;
}
__host__
__device__
constexpr
ArrayElementPicker
()
=
delete
;
__host__
__device__
explicit
constexpr
ArrayElementPicker
(
Arr
&
array
)
:
mArray
{
array
}
{
constexpr
index_t
imax
=
reduce_on_sequence
(
Picks
{},
math
::
maxer
<
index_t
>
{},
Number
<
0
>
{});
static_assert
(
imax
<
Arr
::
Size
(),
"wrong! exceeding # array element"
);
}
__host__
__device__
static
constexpr
auto
Size
()
{
return
Picks
::
Size
();
}
template
<
index_t
I
>
__host__
__device__
constexpr
const
data_type
&
At
(
Number
<
I
>
)
const
{
static_assert
(
I
<
Size
(),
"wrong!"
);
constexpr
auto
IP
=
Picks
{}[
I
];
return
mArray
[
IP
];
}
template
<
index_t
I
>
__host__
__device__
constexpr
data_type
&
At
(
Number
<
I
>
)
{
static_assert
(
I
<
Size
(),
"wrong!"
);
constexpr
auto
IP
=
Picks
{}[
I
];
return
mArray
(
IP
);
}
__host__
__device__
constexpr
const
data_type
&
operator
[](
index_t
i
)
const
{
index_t
ip
=
Picks
{}[
i
];
return
mArray
[
ip
];
}
__host__
__device__
constexpr
data_type
&
operator
()(
index_t
i
)
{
index_t
ip
=
Picks
{}[
i
];
return
mArray
(
ip
);
}
template
<
typename
T
>
__host__
__device__
constexpr
auto
operator
=
(
const
T
&
a
)
{
static_for
<
0
,
Size
(),
1
>
{}([
&
](
auto
i
)
{
operator
()(
i
)
=
a
[
i
];
});
return
*
this
;
}
template
<
typename
T
>
__host__
__device__
constexpr
auto
operator
+=
(
const
T
&
a
)
{
static_for
<
0
,
Size
(),
1
>
{}([
&
](
auto
i
)
{
operator
()(
i
)
+=
a
[
i
];
});
return
*
this
;
}
template
<
typename
T
>
__host__
__device__
constexpr
auto
to_array
(
const
T
&
x
)
{
Array
<
typename
T
::
data_type
,
T
::
Size
()
>
y
;
template
<
typename
T
>
__host__
__device__
constexpr
auto
operator
-=
(
const
T
&
a
)
{
static_for
<
0
,
Size
(),
1
>
{}([
&
](
auto
i
)
{
operator
()(
i
)
-=
a
[
i
];
});
static_for
<
0
,
T
::
Size
(),
1
>
{}([
&
](
auto
i
)
{
y
.
At
(
i
)
=
x
.
At
(
i
);
});
return
*
this
;
}
return
y
;
}
private:
Arr
&
mArray
;
};
template
<
typename
TData
,
index_t
NSize
>
__host__
__device__
constexpr
auto
make_zero_array
()
{
constexpr
auto
zero_sequence
=
typename
uniform_sequence_gen
<
NSize
,
0
>::
type
{};
constexpr
auto
zero_array
=
to_array
(
zero_sequence
);
return
zero_array
;
}
}
// namespace ck
#endif
composable_kernel/include/utility/array_element_picker.hpp
0 → 100644
View file @
674c405f
#ifndef CK_ARRAY_ELEMENT_PICKER_HPP
#define CK_ARRAY_ELEMENT_PICKER_HPP
#include "functional2.hpp"
#include "sequence.hpp"
namespace
ck
{
// Arr: Array or StaticallyIndexedArray
// Picks: Sequence<...>
template
<
typename
Arr
,
typename
Picks
>
struct
ArrayElementPicker
{
using
type
=
ArrayElementPicker
;
using
data_type
=
typename
Arr
::
data_type
;
__host__
__device__
constexpr
ArrayElementPicker
()
=
delete
;
__host__
__device__
explicit
constexpr
ArrayElementPicker
(
Arr
&
array
)
:
mArray
{
array
}
{
constexpr
index_t
imax
=
reduce_on_sequence
(
Picks
{},
math
::
maxer
<
index_t
>
{},
Number
<
0
>
{});
static_assert
(
imax
<
Arr
::
Size
(),
"wrong! exceeding # array element"
);
}
__host__
__device__
static
constexpr
auto
Size
()
{
return
Picks
::
Size
();
}
template
<
index_t
I
>
__host__
__device__
constexpr
const
data_type
&
At
(
Number
<
I
>
)
const
{
static_assert
(
I
<
Size
(),
"wrong!"
);
constexpr
auto
IP
=
Picks
{}[
I
];
return
mArray
[
IP
];
}
template
<
index_t
I
>
__host__
__device__
constexpr
data_type
&
At
(
Number
<
I
>
)
{
static_assert
(
I
<
Size
(),
"wrong!"
);
constexpr
auto
IP
=
Picks
{}[
I
];
return
mArray
(
IP
);
}
template
<
index_t
I
>
__host__
__device__
constexpr
const
auto
&
operator
[](
Number
<
I
>
i
)
const
{
return
At
(
i
);
}
template
<
index_t
I
>
__host__
__device__
constexpr
auto
&
operator
()(
Number
<
I
>
i
)
{
return
At
(
i
);
}
template
<
typename
T
>
__host__
__device__
constexpr
auto
operator
=
(
const
T
&
a
)
{
static_assert
(
T
::
Size
()
==
Size
(),
"wrong! size not the same"
);
static_for
<
0
,
Size
(),
1
>
{}([
&
](
auto
i
)
{
operator
()(
i
)
=
a
[
i
];
});
return
*
this
;
}
private:
Arr
&
mArray
;
};
template
<
typename
Arr
,
typename
Picks
,
typename
X
>
__host__
__device__
constexpr
auto
operator
+=
(
ArrayElementPicker
<
Arr
,
Picks
>&
y
,
const
X
&
x
)
{
using
Y
=
ArrayElementPicker
<
Arr
,
Picks
>
;
constexpr
index_t
nsize
=
Y
::
Size
();
static_assert
(
nsize
==
X
::
Size
(),
"wrong! size not the same"
);
static_for
<
0
,
nsize
,
1
>
{}([
&
](
auto
i
)
{
y
(
i
)
+=
x
[
i
];
});
return
y
;
}
template
<
typename
Arr
,
typename
Picks
,
typename
X
>
__host__
__device__
constexpr
auto
operator
-=
(
ArrayElementPicker
<
Arr
,
Picks
>&
y
,
const
X
&
x
)
{
using
Y
=
ArrayElementPicker
<
Arr
,
Picks
>
;
constexpr
index_t
nsize
=
Y
::
Size
();
static_assert
(
nsize
==
X
::
Size
(),
"wrong! size not the same"
);
static_for
<
0
,
nsize
,
1
>
{}([
&
](
auto
i
)
{
y
(
i
)
-=
x
[
i
];
});
return
y
;
}
}
// namespace ck
#endif
composable_kernel/include/utility/array_helper.hpp
View file @
674c405f
...
...
@@ -2,39 +2,17 @@
#define CK_ARRAY_HELPER_HPP
#include "array.hpp"
#include "statically_indexed_array.hpp"
#include "array_element_picker.hpp"
namespace
ck
{
template
<
typename
X
,
typename
...
Xs
>
__host__
__device__
constexpr
auto
make_array
(
const
X
&
x
,
const
Xs
&
...
xs
)
{
return
Array
<
X
,
sizeof
...(
xs
)
+
1
>
{{
x
,
xs
...}};
}
template
<
typename
Arr
,
typename
Picks
>
__host__
__device__
constexpr
auto
pick_array_element
(
Arr
&
a
,
Picks
)
{
return
ArrayElementPicker
<
Arr
,
Picks
>
(
a
);
}
template
<
typename
T
>
__host__
__device__
constexpr
auto
to_array
(
const
T
&
x
)
{
Array
<
typename
T
::
data_type
,
T
::
Size
()
>
y
;
static_for
<
0
,
T
::
Size
(),
1
>
{}([
&
](
auto
i
)
{
y
.
At
(
i
)
=
x
.
At
(
i
);
});
return
y
;
}
template
<
typename
TData
,
index_t
NSize
>
__host__
__device__
constexpr
auto
make_zero_array
()
{
constexpr
auto
zero_sequence
=
typename
uniform_sequence_gen
<
NSize
,
0
>::
type
{};
constexpr
auto
zero_array
=
to_array
(
zero_sequence
);
return
zero_array
;
}
template
<
typename
TData
,
index_t
NSize
,
index_t
...
IRs
>
__host__
__device__
constexpr
auto
reorder_array_given_new2old
(
const
Array
<
TData
,
NSize
>&
old_array
,
Sequence
<
IRs
...
>
/*new2old*/
)
...
...
composable_kernel/include/utility/common_header.hpp
View file @
674c405f
...
...
@@ -3,6 +3,8 @@
#include "array.hpp"
#include "array_helper.hpp"
#include "statically_indexed_array.hpp"
#include "array_element_picker.hpp"
#include "config.hpp"
#include "float_type.hpp"
#include "functional.hpp"
...
...
composable_kernel/include/utility/print.hpp
View file @
674c405f
...
...
@@ -14,197 +14,17 @@ __host__ __device__ void print_array(const char* s, T a)
constexpr
index_t
nsize
=
a
.
Size
();
if
constexpr
(
is_same
<
data_type
,
uint32_t
>
{})
{
if
constexpr
(
nsize
==
0
)
{
printf
(
"%s size %u
\n
"
,
s
,
nsize
);
}
else
if
constexpr
(
nsize
==
1
)
{
printf
(
"%s size %u, {%u}
\n
"
,
s
,
nsize
,
a
[
0
]);
}
else
if
constexpr
(
nsize
==
2
)
{
printf
(
"%s size %u, {%u %u}
\n
"
,
s
,
nsize
,
a
[
0
],
a
[
1
]);
}
else
if
constexpr
(
nsize
==
3
)
{
printf
(
"%s size %u, {%u %u %u}
\n
"
,
s
,
nsize
,
a
[
0
],
a
[
1
],
a
[
2
]);
}
else
if
constexpr
(
nsize
==
4
)
{
printf
(
"%s size %u, {%u %u %u %u}
\n
"
,
s
,
nsize
,
a
[
0
],
a
[
1
],
a
[
2
],
a
[
3
]);
}
else
if
constexpr
(
nsize
==
5
)
{
printf
(
"%s size %u, {%u %u %u %u %u}
\n
"
,
s
,
nsize
,
a
[
0
],
a
[
1
],
a
[
2
],
a
[
3
],
a
[
4
]);
}
else
if
constexpr
(
nsize
==
6
)
{
printf
(
"%s size %u, {%u %u %u %u %u %u}
\n
"
,
s
,
nsize
,
a
[
0
],
a
[
1
],
a
[
2
],
a
[
3
],
a
[
4
],
a
[
5
]);
}
else
if
constexpr
(
nsize
==
7
)
{
printf
(
"%s size %u, {%u %u %u %u %u %u %u}
\n
"
,
s
,
nsize
,
a
[
0
],
a
[
1
],
a
[
2
],
a
[
3
],
a
[
4
],
a
[
5
],
a
[
6
]);
}
else
if
constexpr
(
nsize
==
8
)
{
printf
(
"%s size %u, {%u %u %u %u %u %u %u %u}
\n
"
,
s
,
nsize
,
a
[
0
],
a
[
1
],
a
[
2
],
a
[
3
],
a
[
4
],
a
[
5
],
a
[
6
],
a
[
7
]);
}
else
if
constexpr
(
nsize
==
9
)
{
printf
(
"%s size %u, {%u %u %u %u %u %u %u %u %u}
\n
"
,
s
,
nsize
,
a
[
0
],
a
[
1
],
a
[
2
],
a
[
3
],
a
[
4
],
a
[
5
],
a
[
6
],
a
[
7
],
a
[
8
]);
}
else
if
constexpr
(
nsize
==
10
)
{
printf
(
"%s size %u, {%u %u %u %u %u %u %u %u %u %u}
\n
"
,
s
,
nsize
,
a
[
0
],
a
[
1
],
a
[
2
],
a
[
3
],
a
[
4
],
a
[
5
],
a
[
6
],
a
[
7
],
a
[
8
],
a
[
9
]);
}
else
{
printf
(
"%s size %u, {"
,
s
,
nsize
);
static_for
<
0
,
nsize
,
1
>
{}([
&
a
](
auto
i
)
constexpr
{
printf
(
"%u, "
,
a
[
i
]);
});
printf
(
"}
\n
"
);
}
}
else
if
constexpr
(
is_same
<
data_type
,
int32_t
>
{})
{
if
constexpr
(
nsize
==
0
)
{
printf
(
"%s size %d
\n
"
,
s
,
nsize
);
}
else
if
constexpr
(
nsize
==
1
)
{
printf
(
"%s size %d, {%d}
\n
"
,
s
,
nsize
,
a
[
0
]);
}
else
if
constexpr
(
nsize
==
2
)
{
printf
(
"%s size %d, {%d %d}
\n
"
,
s
,
nsize
,
a
[
0
],
a
[
1
]);
}
else
if
constexpr
(
nsize
==
3
)
{
printf
(
"%s size %d, {%d %d %d}
\n
"
,
s
,
nsize
,
a
[
0
],
a
[
1
],
a
[
2
]);
}
else
if
constexpr
(
nsize
==
4
)
{
printf
(
"%s size %d, {%d %d %d %d}
\n
"
,
s
,
nsize
,
a
[
0
],
a
[
1
],
a
[
2
],
a
[
3
]);
}
else
if
constexpr
(
nsize
==
5
)
{
printf
(
"%s size %d, {%d %d %d %d %d}
\n
"
,
s
,
nsize
,
a
[
0
],
a
[
1
],
a
[
2
],
a
[
3
],
a
[
4
]);
}
else
if
constexpr
(
nsize
==
6
)
{
printf
(
"%s size %d, {%d %d %d %d %d %d}
\n
"
,
s
,
nsize
,
a
[
0
],
a
[
1
],
a
[
2
],
a
[
3
],
a
[
4
],
a
[
5
]);
}
else
if
constexpr
(
nsize
==
7
)
{
printf
(
"%s size %d, {%d %d %d %d %d %d %d}
\n
"
,
s
,
nsize
,
a
[
0
],
a
[
1
],
a
[
2
],
a
[
3
],
a
[
4
],
a
[
5
],
a
[
6
]);
}
else
if
constexpr
(
nsize
==
8
)
{
printf
(
"%s size %d, {%d %d %d %d %d %d %d %d}
\n
"
,
s
,
nsize
,
a
[
0
],
a
[
1
],
a
[
2
],
a
[
3
],
a
[
4
],
a
[
5
],
a
[
6
],
a
[
7
]);
}
else
if
constexpr
(
nsize
==
9
)
{
printf
(
"%s size %d, {%d %d %d %d %d %d %d %d %d}
\n
"
,
s
,
nsize
,
a
[
0
],
a
[
1
],
a
[
2
],
a
[
3
],
a
[
4
],
a
[
5
],
a
[
6
],
a
[
7
],
a
[
8
]);
}
else
if
constexpr
(
nsize
==
10
)
{
printf
(
"%s size %d, {%d %d %d %d %d %d %d %d %d %d}
\n
"
,
s
,
nsize
,
a
[
0
],
a
[
1
],
a
[
2
],
a
[
3
],
a
[
4
],
a
[
5
],
a
[
6
],
a
[
7
],
a
[
8
],
a
[
9
]);
}
else
{
printf
(
"%s size %d, {"
,
s
,
nsize
);
static_for
<
0
,
nsize
,
1
>
{}([
&
a
](
auto
i
)
constexpr
{
printf
(
"%d, "
,
a
[
i
]);
});
printf
(
"}
\n
"
);
}
}
}
template
<
typename
T
>
...
...
composable_kernel/include/utility/statically_indexed_array.hpp
0 → 100644
View file @
674c405f
#ifndef CK_STATICALLY_INDEXED_ARRAY_HPP
#define CK_STATICALLY_INDEXED_ARRAY_HPP
#include "functional2.hpp"
#include "sequence.hpp"
#include "tuple.hpp"
namespace
ck
{
template
<
typename
TData
,
index_t
NSize
>
struct
StaticallyIndexedArray
{
};
template
<
typename
TData
>
struct
StaticallyIndexedArray
<
TData
,
0
>
:
Tuple
<>
{
using
data_type
=
TData
;
};
template
<
typename
TData
>
struct
StaticallyIndexedArray
<
TData
,
1
>
:
Tuple
<
TData
>
{
using
data_type
=
TData
;
};
template
<
typename
TData
>
struct
StaticallyIndexedArray
<
TData
,
2
>
:
Tuple
<
TData
,
TData
>
{
using
data_type
=
TData
;
};
template
<
typename
TData
>
struct
StaticallyIndexedArray
<
TData
,
3
>
:
Tuple
<
TData
,
TData
,
TData
>
{
using
data_type
=
TData
;
};
template
<
typename
TData
>
struct
StaticallyIndexedArray
<
TData
,
4
>
:
Tuple
<
TData
,
TData
,
TData
,
TData
>
{
using
data_type
=
TData
;
};
template
<
typename
TData
>
struct
StaticallyIndexedArray
<
TData
,
5
>
:
Tuple
<
TData
,
TData
,
TData
,
TData
,
TData
>
{
using
data_type
=
TData
;
};
template
<
typename
TData
>
struct
StaticallyIndexedArray
<
TData
,
6
>
:
Tuple
<
TData
,
TData
,
TData
,
TData
,
TData
,
TData
>
{
using
data_type
=
TData
;
};
template
<
typename
TData
>
struct
StaticallyIndexedArray
<
TData
,
7
>
:
Tuple
<
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
>
{
using
data_type
=
TData
;
};
template
<
typename
TData
>
struct
StaticallyIndexedArray
<
TData
,
8
>
:
Tuple
<
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
>
{
using
data_type
=
TData
;
};
template
<
typename
TData
>
struct
StaticallyIndexedArray
<
TData
,
9
>
:
Tuple
<
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
>
{
using
data_type
=
TData
;
};
template
<
typename
TData
>
struct
StaticallyIndexedArray
<
TData
,
10
>
:
Tuple
<
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
>
{
using
data_type
=
TData
;
};
template
<
typename
TData
>
struct
StaticallyIndexedArray
<
TData
,
11
>
:
Tuple
<
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
>
{
using
data_type
=
TData
;
};
template
<
typename
TData
>
struct
StaticallyIndexedArray
<
TData
,
12
>
:
Tuple
<
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
>
{
using
data_type
=
TData
;
};
template
<
typename
TData
>
struct
StaticallyIndexedArray
<
TData
,
13
>
:
Tuple
<
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
>
{
using
data_type
=
TData
;
};
template
<
typename
TData
>
struct
StaticallyIndexedArray
<
TData
,
14
>
:
Tuple
<
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
>
{
using
data_type
=
TData
;
};
template
<
typename
TData
>
struct
StaticallyIndexedArray
<
TData
,
15
>
:
Tuple
<
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
>
{
using
data_type
=
TData
;
};
template
<
typename
TData
>
struct
StaticallyIndexedArray
<
TData
,
16
>
:
Tuple
<
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
>
{
using
data_type
=
TData
;
};
template
<
typename
TData
>
struct
StaticallyIndexedArray
<
TData
,
17
>
:
Tuple
<
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
>
{
using
data_type
=
TData
;
};
template
<
typename
TData
>
struct
StaticallyIndexedArray
<
TData
,
18
>
:
Tuple
<
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
>
{
using
data_type
=
TData
;
};
template
<
typename
TData
>
struct
StaticallyIndexedArray
<
TData
,
19
>
:
Tuple
<
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
>
{
using
data_type
=
TData
;
};
template
<
typename
TData
>
struct
StaticallyIndexedArray
<
TData
,
20
>
:
Tuple
<
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
>
{
using
data_type
=
TData
;
};
template
<
typename
TData
>
struct
StaticallyIndexedArray
<
TData
,
21
>
:
Tuple
<
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
>
{
using
data_type
=
TData
;
};
template
<
typename
TData
>
struct
StaticallyIndexedArray
<
TData
,
22
>
:
Tuple
<
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
,
TData
>
{
using
data_type
=
TData
;
};
template
<
typename
TData
,
index_t
NSize
,
typename
X
>
__host__
__device__
constexpr
auto
operator
+=
(
StaticallyIndexedArray
<
TData
,
NSize
>&
y
,
const
X
&
x
)
{
static_assert
(
X
::
Size
()
==
NSize
,
"wrong! size not the same"
);
static_for
<
0
,
NSize
,
1
>
{}([
&
](
auto
i
)
{
y
(
i
)
+=
x
[
i
];
});
return
y
;
}
template
<
typename
TData
,
index_t
NSize
,
typename
X
>
__host__
__device__
constexpr
auto
operator
-=
(
StaticallyIndexedArray
<
TData
,
NSize
>&
y
,
const
X
&
x
)
{
static_assert
(
X
::
Size
()
==
NSize
,
"wrong! size not the same"
);
static_for
<
0
,
NSize
,
1
>
{}([
&
](
auto
i
)
{
y
(
i
)
-=
x
[
i
];
});
return
y
;
}
template
<
typename
TData
,
index_t
NSize
,
typename
T
>
__host__
__device__
constexpr
auto
operator
+
(
const
StaticallyIndexedArray
<
TData
,
NSize
>&
a
,
const
T
&
b
)
{
using
type
=
StaticallyIndexedArray
<
TData
,
NSize
>
;
static_assert
(
T
::
Size
()
==
NSize
,
"wrong! size not the same"
);
type
r
;
static_for
<
0
,
NSize
,
1
>
{}([
&
](
auto
i
)
{
r
(
i
)
=
a
[
i
]
+
b
[
i
];
});
return
r
;
}
template
<
typename
TData
,
index_t
NSize
,
typename
T
>
__host__
__device__
constexpr
auto
operator
-
(
const
StaticallyIndexedArray
<
TData
,
NSize
>&
a
,
const
T
&
b
)
{
using
type
=
StaticallyIndexedArray
<
TData
,
NSize
>
;
static_assert
(
T
::
Size
()
==
NSize
,
"wrong! size not the same"
);
type
r
;
static_for
<
0
,
NSize
,
1
>
{}([
&
](
auto
i
)
{
r
(
i
)
=
a
[
i
]
-
b
[
i
];
});
return
r
;
}
}
// namespace ck
#endif
composable_kernel/include/utility/tuple.hpp
View file @
674c405f
...
...
@@ -89,6 +89,8 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X
{
}
__host__
__device__
static
constexpr
index_t
Size
()
{
return
sizeof
...(
Xs
);
}
template
<
index_t
I
>
__host__
__device__
constexpr
const
auto
&
At
(
Number
<
I
>
)
const
{
...
...
@@ -102,6 +104,28 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X
static_assert
(
I
<
base
::
Size
(),
"wrong! out of range"
);
return
base
::
GetElementByKey
(
detail
::
TupleElementKey
<
I
>
{});
}
template
<
index_t
I
>
__host__
__device__
constexpr
const
auto
&
operator
[](
Number
<
I
>
i
)
const
{
return
At
(
i
);
}
template
<
index_t
I
>
__host__
__device__
constexpr
auto
&
operator
()(
Number
<
I
>
i
)
{
return
At
(
i
);
}
template
<
typename
T
>
__host__
__device__
constexpr
auto
operator
=
(
const
T
&
a
)
{
static_assert
(
T
::
Size
()
==
Size
(),
"wrong! size not the same"
);
static_for
<
0
,
Size
(),
1
>
{}([
&
](
auto
i
)
{
operator
()(
i
)
=
a
[
i
];
});
return
*
this
;
}
};
}
// namespace ck
...
...
driver/src/conv_bwd_data_driver.cu
deleted
120000 → 0
View file @
ffa7e4be
conv_bwd_data_driver
.
cpp
\ No newline at end of file
driver/src/conv_driver.cpp
View file @
674c405f
...
...
@@ -561,7 +561,7 @@ int main(int argc, char* argv[])
LeftPads{},
RightPads{},
nrepeat);
#elif
0
#elif
1
device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw
(
in_nchw_desc
,
in_nchw
,
wei_kcyx_desc
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment