Project: gaoqiong / composable_kernel
Commit 625838de, authored Sep 06, 2019 by Chao Liu

    added tuple

parent 12da8154
Showing 10 changed files with 368 additions and 105 deletions (+368 -105)
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn_padded.hpp  +64 -0
composable_kernel/include/tensor_description/dimension.hpp  +6 -4
composable_kernel/include/tensor_description/multi_index_transform.hpp  +32 -35
composable_kernel/include/tensor_description/tensor_descriptor.hpp  +66 -13
composable_kernel/include/tensor_description/tensor_descriptor_helper.hpp  +117 -0
composable_kernel/include/tensor_description/tensor_visit.hpp  +1 -0
composable_kernel/include/utility/Array.hpp  +8 -6
composable_kernel/include/utility/tuple.hpp  +73 -40
driver/include/host_conv.hpp  +0 -6
driver/src/driver.cpp  +1 -1
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn_padded.hpp
@@ -4,6 +4,8 @@
 #include "common_header.hpp"
 #include "ConstantTensorDescriptor.hpp"
 #include "ConstantMatrixDescriptor.hpp"
+#include "tensor_descriptor.hpp"
+#include "tensor_descriptor_helper.hpp"
 #include "blockwise_generic_tensor_slice_copy.hpp"
 #include "threadwise_generic_tensor_slice_copy.hpp"
 #include "blockwise_batched_gemm.hpp"
@@ -45,6 +47,7 @@ template <index_t GridSize,
           index_t OutThreadCopyDataPerAccess_N>
 struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_padded
 {
+#if 0
     __device__ void Run(const Float* const __restrict__ p_in_global,
                         const Float* const __restrict__ p_wei_global,
                         Float* const __restrict__ p_out_global) const
@@ -478,6 +481,67 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_padded
 #endif
         });
     }
+#else
+    __device__ void Run(const Float* const __restrict__ p_in_global,
+                        const Float* const __restrict__ p_wei_global,
+                        Float* const __restrict__ p_out_global) const
+    {
+#if 0
+        constexpr auto tmp  = std::tuple<bool>{};
+        constexpr auto flag = std::get<0>(tmp);
+#else
+        constexpr auto a = Tuple<bool, Sequence<1>, index_t>(true, Sequence<1>{}, 99);
+
+        if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
+        {
+            printf("adsas %d\n", a.At(Number<0>{}));
+            print_Sequence("seq", a.At(Number<1>{}));
+            printf("adsas %lu\n", a.At(Number<2>{}));
+        }
+
+        auto b = Tuple<bool, Sequence<1>, index_t>(true, Sequence<1>{}, 99);
+
+        b.At(Number<0>{}) = false;
+
+        if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
+        {
+            printf("adsas %d\n", b.At(Number<0>{}));
+            print_Sequence("seq", b.At(Number<1>{}));
+            printf("adsas %lu\n", b.At(Number<2>{}));
+        }
+
+        if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
+        {
+            printf("adsas %d\n",
+                   Tuple<bool, Sequence<1>, index_t>(true, Sequence<1>(), 99).At(Number<0>{}));
+            print_Sequence("seq",
+                           Tuple<bool, Sequence<1>, index_t>(true, Sequence<1>(), 99).At(Number<1>{}));
+            printf("adsas %d\n",
+                   Tuple<bool, Sequence<1>, index_t>(true, Sequence<1>(), 99).At(Number<2>{}));
+        }
+#endif
+
+#if 0
+        constexpr auto I0 = Number<0>{};
+        constexpr auto I1 = Number<1>{};
+        constexpr auto I2 = Number<2>{};
+        constexpr auto I3 = Number<3>{};
+
+        // create a native tensor descriptor
+        constexpr auto in_n_c_h_w_global_desc =
+            make_NativeTensorDescriptor(InGlobalDesc::GetLengths(), InGlobalDesc::GetStrides());
+
+        if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
+        {
+            print_tensor_descriptor("in_n_c_h_w_global_desc", in_n_c_h_w_global_desc);
+        }
+
+        // transform the tensor descriptor once
+        //
+        // calculate the offset of some entry
+#endif
+    }
+#endif
 };
 } // namespace ck
composable_kernel/include/tensor_description/dimension.hpp
@@ -12,15 +12,17 @@ struct Dimension
 };

 template <index_t Length, index_t Stride>
-struct NativeDimension : Dimension<Length>
+struct NativeDimension
 {
+    __host__ __device__ static constexpr auto GetLength() { return Number<Length>{}; }
+
     __host__ __device__ static constexpr auto GetStride() { return Number<Stride>{}; }

-    __host__ __device__ static constexpr index_t GetOffset(index_t id) { return id * Stride; }
+    __host__ __device__ static constexpr index_t GetOffset(index_t i) { return i * Stride; }

-    __host__ __device__ static constexpr index_t GetOffsetDiff(index_t id_diff)
+    __host__ __device__ static constexpr index_t GetOffsetDiff(index_t i_diff)
     {
-        return id_diff * Stride;
+        return i_diff * Stride;
     }
 };
composable_kernel/include/tensor_description/multi_index_transform.hpp
@@ -8,25 +8,19 @@ namespace ck {
 template <index_t N>
 using MultiIndex = Array<index_t, N>;

-template <index_t Length>
+// LowLengths: Sequence<...>
+template <class LowLengths>
 struct PassThrough
 {
-    using LowerIndex = MultiIndex<1>;
-    using UpperIndex = MultiIndex<1>;
+    static constexpr index_t nDim = LowLengths::GetSize();

-    __host__ __device__ static constexpr auto GetNumOfLowerDimension() { return Number<1>{}; }
+    using LowerIndex = MultiIndex<nDim>;
+    using UpperIndex = LowerIndex;

-    __host__ __device__ static constexpr auto GetNumOfUpperDimension() { return Number<1>{}; }
+    __host__ __device__ static constexpr auto GetNumOfLowerDimension() { return Number<nDim>{}; }

-    __host__ __device__ static constexpr auto GetLowerLengths() { return Sequence<Length>{}; }
+    __host__ __device__ static constexpr auto GetNumOfUpperDimension()
+    {
+        return GetNumOfLowerDimension();
+    }

-    __host__ __device__ static constexpr auto GetUpperLengths() { return Sequence<Length>{}; }
+    __host__ __device__ static constexpr auto GetLowerLengths() { return LowLengths{}; }
+
+    __host__ __device__ static constexpr auto GetUpperLengths() { return GetLowerLengths(); }

     __host__ __device__ static constexpr auto GetLowerIndex(UpperIndex idx_up) { return idx_up; }
@@ -35,7 +29,7 @@ struct PassThrough
         return idx_up_diff;
     }

-    __host__ __device__ static constexpr bool IsIndexTransformLinear() { return true; }
+    __host__ __device__ static constexpr bool IsLinearTransform() { return true; }
 };

 // LowLengths: Sequence<...>
@@ -45,25 +39,22 @@ struct Pad
     static constexpr index_t nDim = LowLengths::GetSize();

     using LowerIndex = MultiIndex<nDim>;
-    using UpperIndex = MultiIndex<nDim>;
+    using UpperIndex = LowerIndex;

     __host__ __device__ static constexpr auto GetNumOfLowerDimension() { return Number<nDim>{}; }

-    __host__ __device__ static constexpr auto GetNumOfUpperDimension() { return Number<nDim>{}; }
+    __host__ __device__ static constexpr auto GetNumOfUpperDimension()
+    {
+        return GetNumOfLowerDimension();
+    }

     __host__ __device__ static constexpr auto GetLowerLengths() { return LowLengths{}; }

     __host__ __device__ static constexpr auto GetUpperLengths()
     {
-        return GetLowerLengths() + LeftPads + RightPads;
+        return GetLowerLengths() + LeftPads{} + RightPads{};
     }

     __host__ __device__ static constexpr auto GetLowerIndex(UpperIndex idx_up)
     {
-        return idx_up - LeftPads;
+        return idx_up - LeftPads{};
     }

     __host__ __device__ static constexpr auto GetLowerIndexDiff(UpperIndex idx_up_diff)
@@ -71,9 +62,10 @@ struct Pad
         return idx_up_diff;
     }

-    __host__ __device__ static constexpr bool IsIndexTransformLinear() { return true; }
+    __host__ __device__ static constexpr bool IsLinearTransform() { return true; }
 };

+#if 0
 // LowLengths: Sequence<...>
 template <class LowLengths>
 struct Merge
@@ -116,8 +108,9 @@ struct Merge
         return idx_low_diff;
     }

-    __host__ __device__ static constexpr bool IsIndexTransformLinear() { return false; }
+    __host__ __device__ static constexpr bool IsLinearTransform() { return false; }
 };
+#endif

 // UpLengths: Sequence<...>
 template <index_t LowLength, class UpLengths>
@@ -126,6 +119,9 @@ struct Unmerge
     static constexpr index_t nDimLow = 1;
     static constexpr index_t nDimUp  = UpLengths::GetSize();

+    using UpperIndex = MultiIndex<nDimUp>;
+    using LowerIndex = MultiIndex<nDimLow>;
+
     __host__ __device__ constexpr Unmerge()
     {
         static_assert(LowLength == accumulate_on_sequence(
@@ -133,7 +129,7 @@ struct Unmerge
                       "wrong! UpLengths need to be ");
     }

-    __host__ __device__ static constexpr auto GetNumOfUpperDimension() { return Number<nDimUp>{} };
+    __host__ __device__ static constexpr auto GetNumOfUpperDimension() { return Number<nDimUp>{}; }

     __host__ __device__ static constexpr auto GetNumOfLowerDimension() { return Number<nDimLow>{}; }
@@ -149,7 +145,7 @@ struct Unmerge
         LowerIndex idx_low{0};

-        static_for<0, nDim, 1>{}([&](auto idim) { idx_low[0] += idx_up[idim] * scans[idim]; });
+        static_for<0, nDimUp, 1>{}([&](auto idim) { idx_low(0) += idx_up[idim] * scans[idim]; });

         return idx_low;
     }
@@ -159,7 +155,7 @@ struct Unmerge
         return GetLowerIndex(idx_up_diff);
     }

-    __host__ __device__ static constexpr bool IsIndexTransformLinear() { return true; }
+    __host__ __device__ static constexpr bool IsLinearTransform() { return true; }
 };

 // UpLengths: Sequence<...>
@@ -171,7 +167,8 @@ struct Embed
     static constexpr index_t nDimLow = 1;
     static constexpr index_t nDimUp  = UpLengths::GetSize();

-    static constexpr auto mCoefficients = Coefficients{};
+    using LowerIndex = MultiIndex<nDimLow>;
+    using UpperIndex = MultiIndex<nDimUp>;

     __host__ __device__ constexpr Embed()
     {
@@ -179,14 +176,14 @@ struct Embed
                       "wrong! # of dimensions not consistent");

         constexpr index_t low_id_max =
-            Coefficents.Back() + accumulate_on_sequence(UpLengths{} * Coefficients::PopBack(),
+            Coefficients::Back() + accumulate_on_sequence(UpLengths{} * Coefficients::PopBack(),
                                                         math::plus<index_t>{},
                                                         Number<0>{});

         static_assert(low_id_max < LowLength, "wrong! lower-id will go out of range");
     }

-    __host__ __device__ static constexpr auto GetNumOfUpperDimension() { return Number<nDimUp>{} };
+    __host__ __device__ static constexpr auto GetNumOfUpperDimension() { return Number<nDimUp>{}; }

     __host__ __device__ static constexpr auto GetNumOfLowerDimension() { return Number<nDimLow>{}; }
@@ -196,10 +193,10 @@ struct Embed
     __host__ __device__ static constexpr auto GetLowerIndex(UpperIndex idx_up)
     {
-        LowerIndex idx_low{mCoefficients[nDimUp]};
+        LowerIndex idx_low(Coefficients{}[nDimUp]);

         static_for<0, nDimUp, 1>{}(
-            [&](auto idim) { idx_low[0] += idx_up[idim] * mCoefficients[idim]; });
+            [&](auto idim) { idx_low[0] += idx_up[idim] * Coefficients{}[idim]; });

         return idx_low;
     }
@@ -209,12 +206,12 @@ struct Embed
         LowerIndex idx_low_diff{0};

         static_for<0, nDimUp, 1>{}(
-            [&](auto idim) { idx_low_diff[0] += idx_up_diff[idim] * mCoefficients[idim]; });
+            [&](auto idim) { idx_low_diff[0] += idx_up_diff[idim] * Coefficients{}[idim]; });

         return idx_low_diff;
     }

-    __host__ __device__ static constexpr bool IsIndexTransformLinear() { return true; }
+    __host__ __device__ static constexpr bool IsLinearTransform() { return true; }
 };

 } // namespace ck
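Side note (not part of the commit): a host-only sketch of what the two simplest transforms above compute. PassThrough maps an upper multi-index to an identical lower multi-index, and Pad subtracts the left padding so that upper indices falling inside the pad region map to out-of-range lower indices. Here std::array stands in for ck::MultiIndex and the pads are runtime values rather than Sequence<> template parameters; names like pad_lower_index are illustrative only.

    #include <array>
    #include <cstdio>

    template <std::size_t N>
    using MultiIndexSketch = std::array<int, N>;

    // PassThrough: identity map from upper index to lower index.
    template <std::size_t N>
    constexpr MultiIndexSketch<N> pass_through_lower_index(const MultiIndexSketch<N>& idx_up)
    {
        return idx_up;
    }

    // Pad: the lower (unpadded) index is the upper index shifted back by the left pads.
    template <std::size_t N>
    MultiIndexSketch<N> pad_lower_index(const MultiIndexSketch<N>& idx_up,
                                        const MultiIndexSketch<N>& left_pads)
    {
        MultiIndexSketch<N> idx_low{};
        for(std::size_t i = 0; i < N; ++i)
            idx_low[i] = idx_up[i] - left_pads[i]; // negative means "inside the padding"
        return idx_low;
    }

    int main()
    {
        auto low = pad_lower_index<2>({{0, 5}}, {{1, 1}});
        std::printf("%d %d\n", low[0], low[1]); // prints: -1 4
    }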
composable_kernel/include/tensor_description/tensor_descriptor.hpp
@@ -11,21 +11,39 @@ template <class... NativeDimensions>
 struct NativeTensorDescriptor
 {
     using type = NativeTensorDescriptor;
-    static constexpr auto mDimensions = Tuple<NativeDimensions...>;
-    static constexpr index_t nDim     = mDimensions::GetSize();
+    static constexpr auto mDimensions = Tuple<NativeDimensions...>{};
+    static constexpr index_t nDim     = mDimensions.GetSize();

     using Index = MultiIndex<nDim>;

     __host__ __device__ static constexpr auto GetNumOfDimension() { return Number<nDim>{}; }

+    struct lambda_GetLength
+    {
+        template <class IDim>
+        __host__ __device__ constexpr auto operator()(IDim) const
+        {
+            return GetLength(IDim{});
+        }
+    };
+
     __host__ __device__ static constexpr auto GetLengths()
     {
-        // not implemented
+        return typename sequence_gen<nDim, lambda_GetLength>::type{};
     }

+    struct lambda_GetStride
+    {
+        template <class IDim>
+        __host__ __device__ constexpr auto operator()(IDim) const
+        {
+            return GetStride(IDim{});
+        }
+    };
+
     __host__ __device__ static constexpr auto GetStrides()
     {
-        // not implemented
+        return typename sequence_gen<nDim, lambda_GetStride>::type{};
    }

     template <index_t IDim>
@@ -59,20 +77,26 @@ struct NativeTensorDescriptor
         return offset_diff;
     }

-    __host__ __device__ static constexpr auto AreUpperIndex2OffsetTransformLinear()
-    {
-        // TODO: re-implement "Sequence", so that it can take other data-type (including bool) as
-        // element
-        return uniform_sequence_gen<nDim, 1>{};
-    }
+    template <index_t IDim>
+    __host__ __device__ static constexpr bool IsLinearDimension(Number<IDim>)
+    {
+        return true;
+    }

-    __host__ __device__ static constexpr auto GetIndependentDimensionGroups()
-    {
-        // not implemented, should return Tuple<Sequence<0>, Sequence<1>, ...>
-        return xxx;
-    }
+    __host__ __device__ static constexpr auto GetLinearDimensions()
+    {
+        return typename arithmetic_sequence_gen<0, nDim, 1>::type{};
+    }
+
+    __host__ __device__ static constexpr auto GetNonLinearDimensions() { return Sequence<>{}; }
+
+    __host__ __device__ static constexpr auto GetNonLinearIndependentDimensionGroups()
+    {
+        return Tuple<>{};
+    }
 };

+#if 0
 // LowerTensorDescriptor
 // Transforms: std::tuple<DimensionTransforms...>
 // LowerDimensionIds: std::tuple<Sequence<...>>
@@ -213,16 +237,45 @@ struct TransformedTensorDescriptor
         return GetLowerTensorDescriptor().GetOffset(GetLowerIndex(idx_up));
     }

-    __host__ __device__ static constexpr auto AreUpperIndex2OffsetTransformLinear()
+    template <index_t IDim>
+    __host__ __device__ static constexpr bool IsLinearDimension(Number<IDim>)
+    {
+        // not implemented
+    }
+
+    __host__ __device__ static constexpr auto GetLinearDimensions()
+    {
+        // not implemented
+    }
+
+    __host__ __device__ static constexpr auto GetNonLinearDimensions()
     {
         // not implemented
     }

-    __host__ __device__ static constexpr auto GetIndependentDimensionGroups()
+    __host__ __device__ static constexpr auto GetNonLinearIndependentDimensionGroups()
     {
         // not implemented
     }
 };
+#endif
+
+template <index_t... Lengths, index_t... Strides>
+__host__ __device__ constexpr auto make_NativeTensorDescriptor(Sequence<Lengths...>,
+                                                               Sequence<Strides...>)
+{
+    return NativeTensorDescriptor<NativeDimension<Lengths, Strides>...>{};
+}
+
+template <class Lengths>
+__host__ __device__ constexpr auto make_NativeTensorDescriptor_packed(Lengths)
+{
+    constexpr index_t strides = reverse_inclusive_scan_sequence(
+                                    Lengths::PopFront(), math::multiplies<index_t>{}, Number<1>{})
+                                    .PushBack(Number<1>{});
+
+    return make_NativeTensorDescriptor(Lengths{}, strides);
+}
+
 } // namespace ck

 #endif
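Side note (not part of the commit): make_NativeTensorDescriptor_packed above derives packed strides from the lengths, i.e. stride[i] is the product of all lengths to the right of dimension i and the last stride is 1. A runtime sketch of that rule (the library does it at compile time with reverse_inclusive_scan_sequence over a Sequence<>; packed_strides is an illustrative name, not a library function):

    #include <cstdio>
    #include <vector>

    std::vector<int> packed_strides(const std::vector<int>& lengths)
    {
        std::vector<int> strides(lengths.size(), 1);
        for(int i = static_cast<int>(lengths.size()) - 2; i >= 0; --i)
            strides[i] = strides[i + 1] * lengths[i + 1]; // product of lengths to the right
        return strides;
    }

    int main()
    {
        // NCHW lengths {2, 3, 4, 5} give strides {60, 20, 5, 1}
        for(int s : packed_strides({2, 3, 4, 5}))
            std::printf("%d ", s);
        std::printf("\n");
    }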
composable_kernel/include/tensor_description/tensor_descriptor_helper.hpp
new file (0 → 100644)

#ifndef CK_TENSOR_DESCRIPTOR_HELPER_HPP
#define CK_TENSOR_DESCRIPTOR_HELPER_HPP

#include "common_header.hpp"
#include "tensor_descriptor.hpp"

namespace ck {

template <class... NativeDimensions>
__host__ __device__ void print_tensor_descriptor(const char* s,
                                                 NativeTensorDescriptor<NativeDimensions...> desc)
{
    print_tensor_descriptor_impl(s, desc.GetLengths(), desc.GetStrides());
}

template <index_t... Lengths, index_t... Strides>
__host__ __device__ void
print_tensor_descriptor_impl(const char* s, Sequence<Lengths...>, Sequence<Strides...>)
{
    constexpr index_t nDim = sizeof...(Lengths);

    static_assert(nDim > 0 && nDim <= 12, "wrong!");

    static_if<nDim == 1>{}([&](auto) {
        printf("%s dim %u, lengths {%u}, strides {%u}\n", s, nDim, Lengths..., Strides...);
    });

    static_if<nDim == 2>{}([&](auto) {
        printf("%s dim %u, lengths {%u %u}, strides {%u %u}\n", s, nDim, Lengths..., Strides...);
    });

    static_if<nDim == 3>{}([&](auto) {
        printf("%s dim %u, lengths {%u %u %u}, strides {%u %u %u}\n",
               s, nDim, Lengths..., Strides...);
    });

    static_if<nDim == 4>{}([&](auto) {
        printf("%s dim %u, lengths {%u %u %u %u}, strides {%u %u %u %u}\n",
               s, nDim, Lengths..., Strides...);
    });

    static_if<nDim == 5>{}([&](auto) {
        printf("%s dim %u, lengths {%u %u %u %u %u}, strides {%u %u %u %u %u}\n",
               s, nDim, Lengths..., Strides...);
    });

    static_if<nDim == 6>{}([&](auto) {
        printf("%s dim %u, lengths {%u %u %u %u %u %u}, strides {%u %u %u %u %u %u}\n",
               s, nDim, Lengths..., Strides...);
    });

    static_if<nDim == 7>{}([&](auto) {
        printf("%s dim %u, lengths {%u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u}\n",
               s, nDim, Lengths..., Strides...);
    });

    static_if<nDim == 8>{}([&](auto) {
        printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u %u}\n",
               s, nDim, Lengths..., Strides...);
    });

    static_if<nDim == 9>{}([&](auto) {
        printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u %u "
               "%u}\n",
               s, nDim, Lengths..., Strides...);
    });

    static_if<nDim == 10>{}([&](auto) {
        printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u "
               "%u %u %u}\n",
               s, nDim, Lengths..., Strides...);
    });

    static_if<nDim == 11>{}([&](auto) {
        printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u "
               "%u %u "
               "%u %u %u}\n",
               s, nDim, Lengths..., Strides...);
    });

    static_if<nDim == 12>{}([&](auto) {
        printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u "
               "%u %u %u %u "
               "%u %u %u}\n",
               s, nDim, Lengths..., Strides...);
    });
}

} // namespace ck
#endif
composable_kernel/include/tensor_description/tensor_visit.hpp
@@ -85,6 +85,7 @@ struct TensorVisit
     {
         constexpr auto nonlinear_independent_dimensions_igroup =
             nonlinear_independent_dimension_groups.Get(igroup);

         constexpr auto nonlinear_independent_lengths_igroup =
             lambda_HackLengths{}(lengths, nonlinear_independent_dimensions_igroup);
composable_kernel/include/utility/Array.hpp
@@ -82,9 +82,11 @@ struct Array
 // A: Array
 // Picks: Sequence<...>
 template <class Arr, class Picks>
-ArrayElementPicker
+struct ArrayElementPicker
 {
+    using data_type = typename Arr::data_type;
+
     __host__ __device__ constexpr ArrayElementPicker(Arr& array) : mData{array}
     {
         constexpr index_t imax =
             accumulate_on_sequence(Picks{}, math::maxer<index_t>{}, Number<0>{});
@@ -95,26 +97,26 @@ ArrayElementPicker
     __host__ __device__ static constexpr index_t GetSize() { return Picks::GetSize(); }

     template <index_t I>
-    __host__ __device__ constexpr TData operator[](Number<I>) const
+    __host__ __device__ constexpr data_type operator[](Number<I>) const
     {
         constexpr auto IP = Picks::Get(Number<I>{});
         return mData[IP];
     }

-    __host__ __device__ constexpr TData operator[](index_t i) const
+    __host__ __device__ constexpr data_type operator[](index_t i) const
     {
         constexpr index_t ip = Picks{}[i];
         return mData[ip];
     }

     template <index_t I>
-    __host__ __device__ TData& operator()(Number<I>)
+    __host__ __device__ data_type& operator()(Number<I>)
     {
         constexpr auto IP = Picks::Get(Number<I>{});
         return mData[IP];
     }

-    __host__ __device__ TData& operator()(index_t i)
+    __host__ __device__ data_type& operator()(index_t i)
     {
         constexpr index_t ip = Picks{}[i];
         return mData[ip];
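Side note (not part of the commit): ArrayElementPicker, fixed up above, is a non-owning view that exposes a chosen subset of an Array's elements. A plain-C++ sketch of the idea, with runtime pick indices instead of the library's compile-time Sequence<> of picks; ElementPickerSketch is an illustrative name only.

    #include <array>
    #include <cstdio>

    template <typename Arr, std::size_t NPick>
    struct ElementPickerSketch
    {
        Arr& mData;
        std::array<std::size_t, NPick> mPicks;

        // read access routed through the pick list
        const typename Arr::value_type& operator[](std::size_t i) const { return mData[mPicks[i]]; }

        // write access routed through the pick list
        typename Arr::value_type& operator()(std::size_t i) { return mData[mPicks[i]]; }
    };

    int main()
    {
        std::array<int, 5> a{{10, 20, 30, 40, 50}};
        ElementPickerSketch<std::array<int, 5>, 2> pick{a, {{1, 3}}};

        pick(0) = 99;                                       // writes through to a[1]
        std::printf("%d %d %d\n", pick[0], pick[1], a[1]);  // prints: 99 40 99
    }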
composable_kernel/include/utility/tuple.hpp
@@ -2,66 +2,99 @@
 #define CK_TUPLE_HPP

 #include "integral_constant.hpp"
+#include "Sequence.hpp"

 namespace ck {

-template <class... Ts>
-struct tuple : public std::tuple<Ts...>
-{
-    using type = tuple;
-
-    __host__ __device__ static constexpr index_t GetSize() { return std::tuple_size(tuple{}); }
-
-    template <index_t I>
-    __host__ __device__ constexpr auto Get(Number<I>) const
-    {
-        return std::get<I>(*this);
-    }
-
-    template <index_t I>
-    __host__ __device__ constexpr auto operator[](Number<I>) const
-    {
-        return Get(Number<I>{});
-    }
-};
-
-// merge tuple
-template <class... Tuples>
-__host__ __device__ constexpr auto merge_tuple(Tuples&&... xs)
-{
-    return std::tuple_cat(xs...);
-}
-
-// generate sequence
-template <index_t IBegin, index_t NRemain, class F>
-struct tuple_gen_impl
-{
-    static constexpr index_t NRemainLeft  = NRemain / 2;
-    static constexpr index_t NRemainRight = NRemain - NRemainLeft;
-    static constexpr index_t IMiddle      = IBegin + NRemainLeft;
-
-    using type =
-        typename tuple_merge<typename tuple_gen_impl<IBegin, NRemainLeft, F>::type,
-                             typename tuple_gen_impl<IMiddle, NRemainRight, F>::type>::type;
-};
-
-template <index_t I, class F>
-struct tuple_gen_impl<I, 1, F>
-{
-    static constexpr auto x = F{}(Number<I>{});
-
-    using type = tuple<Is>;
-};
-
-template <index_t I, class F>
-struct sequence_gen_impl<I, 0, F>
-{
-    using type = Sequence<>;
-};
-
-template <index_t NSize, class F>
-struct sequence_gen
-{
-    using type = typename sequence_gen_impl<0, NSize, F>::type;
-};
+namespace detail {
+
+template <index_t>
+struct TupleElementKey
+{
+};
+
+template <typename Key, typename Data>
+struct TupleElement
+{
+    template <typename T>
+    __host__ __device__ explicit constexpr TupleElement(T&& v) : mData(static_cast<T&&>(v))
+    {
+    }
+
+    Data mData;
+};
+
+template <typename Key, typename Data>
+__host__ __device__ constexpr const Data& get_tuple_element(const TupleElement<Key, Data>& x)
+{
+    return x.mData;
+}
+
+template <typename Key, typename Data>
+__host__ __device__ constexpr Data& get_tuple_element(TupleElement<Key, Data>& x)
+{
+    return x.mData;
+}
+
+template <typename Key, typename Data>
+__host__ __device__ constexpr Data&& get_tuple_element(TupleElement<Key, Data>&& x)
+{
+    return static_cast<Data&&>(x.mData);
+}
+
+template <typename Indices, typename... Xs>
+struct TupleImpl;
+
+template <index_t... Is, typename... Xs>
+struct TupleImpl<Sequence<Is...>, Xs...> : TupleElement<TupleElementKey<Is>, Xs>...
+{
+    template <typename... Ys>
+    __host__ __device__ explicit constexpr TupleImpl(Ys&&... ys)
+        : TupleElement<TupleElementKey<Is>, Xs>(static_cast<Ys&&>(ys))...
+    {
+    }
+
+    __host__ __device__ static constexpr index_t Size() { return sizeof...(Xs); }
+
+    template <index_t I>
+    __host__ __device__ constexpr const auto& GetElementByKey(TupleElementKey<I>) const
+    {
+        return get_tuple_element<TupleElementKey<I>>(*this);
+    }
+
+    template <index_t I>
+    __host__ __device__ constexpr auto& GetElementByKey(TupleElementKey<I>)
+    {
+        return get_tuple_element<TupleElementKey<I>>(*this);
+    }
+};
+
+} // namespace detail
+
+template <typename... Xs>
+struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(Xs), 1>::type, Xs...>
+{
+    using base = detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(Xs), 1>::type, Xs...>;
+
+    template <typename... Ys>
+    __host__ __device__ explicit constexpr Tuple(Ys&&... ys) : base(static_cast<Ys&&>(ys)...)
+    {
+    }
+
+    template <index_t I>
+    __host__ __device__ constexpr const auto& At(Number<I>) const
+    {
+        static_assert(I < base::Size(), "wrong! out of range");
+        return GetElementByKey(detail::TupleElementKey<I>{});
+    }
+
+    template <index_t I>
+    __host__ __device__ constexpr auto& At(Number<I>)
+    {
+        static_assert(I < base::Size(), "wrong! out of range");
+        return GetElementByKey(detail::TupleElementKey<I>{});
+    }
+};

 } // namespace ck
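Side note (not part of the commit): the new Tuple stores each element in a TupleElement base tagged by a TupleElementKey, and At(Number<I>{}) reaches the right element through overload resolution on that key. A self-contained sketch of the same technique in standard C++14, with std::index_sequence standing in for ck::Sequence and a member template At<I>() instead of Number<I> tags; all names here are illustrative, not the library's.

    #include <cstdio>
    #include <utility>

    template <std::size_t>
    struct Key
    {
    };

    template <typename K, typename Data>
    struct Element
    {
        Data mData;
    };

    // Fixing K and deducing the base class picks out exactly one element.
    template <typename K, typename Data>
    constexpr const Data& get_by_key(const Element<K, Data>& e)
    {
        return e.mData;
    }

    template <typename K, typename Data>
    constexpr Data& get_by_key(Element<K, Data>& e)
    {
        return e.mData;
    }

    template <typename Indices, typename... Xs>
    struct TupleSketchImpl;

    template <std::size_t... Is, typename... Xs>
    struct TupleSketchImpl<std::index_sequence<Is...>, Xs...> : Element<Key<Is>, Xs>...
    {
        constexpr TupleSketchImpl(Xs... xs) : Element<Key<Is>, Xs>{xs}... {}
    };

    template <typename... Xs>
    struct TupleSketch : TupleSketchImpl<std::index_sequence_for<Xs...>, Xs...>
    {
        using base = TupleSketchImpl<std::index_sequence_for<Xs...>, Xs...>;
        using base::base;

        template <std::size_t I>
        constexpr const auto& At() const
        {
            return get_by_key<Key<I>>(*this);
        }

        template <std::size_t I>
        constexpr auto& At()
        {
            return get_by_key<Key<I>>(*this);
        }
    };

    int main()
    {
        TupleSketch<bool, int> t{true, 99};
        t.At<0>() = false; // mutable access, like b.At(Number<0>{}) = false in the kernel above
        std::printf("%d %d\n", static_cast<int>(t.At<0>()), t.At<1>()); // prints: 0 99
    }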
driver/include/host_conv.hpp
@@ -65,9 +65,6 @@ void host_direct_convolution(const Tensor<TIn>& in_nchw,
     index_t h_pad_low = LowerPads{}.Get(Number<0>{});
     index_t w_pad_low = LowerPads{}.Get(Number<1>{});

-    index_t h_pad_up = UpperPads{}.Get(Number<0>{});
-    index_t w_pad_up = UpperPads{}.Get(Number<1>{});
-
     auto f = [&](auto n, auto k, auto ho, auto wo) {
         double v = 0;
         for(int c = 0; c < wei_kcyx.mDesc.GetLengths()[1]; ++c)
@@ -125,9 +122,6 @@ void host_winograd_3x3_convolution(const Tensor<TIn>& in_nchw,
     index_t h_pad_low = LowerPads{}.Get(Number<0>{});
     index_t w_pad_low = LowerPads{}.Get(Number<1>{});

-    index_t h_pad_up = UpperPads{}.Get(Number<0>{});
-    index_t w_pad_up = UpperPads{}.Get(Number<1>{});
-
     std::size_t HiPerTile = HoPerTile + Y - 1;
     std::size_t WiPerTile = WoPerTile + X - 1;
driver/src/driver.cpp
@@ -368,7 +368,7 @@ int main(int argc, char* argv[])
 #if 0
     device_convolution_direct_v2_nchw_kcyx_nkhw
     (in_nchw_desc, in_nchw, wei_kcyx_desc, wei_kcyx, out_nkhw_desc, out_nkhw_device, nrepeat);
-#elif 1
+#elif 0
     device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(
         in_nchw_desc, in_nchw, wei_kcyx_desc, wei_kcyx, out_nkhw_desc, out_nkhw_device, nrepeat);
 #elif 1