Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
db32635c
Commit
db32635c
authored
Sep 06, 2022
by
Po-Yen, Chen
Browse files
Generalize transpose utility functions
parent
98498486
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
154 additions
and
45 deletions
+154
-45
example/36_elementwise_permute/common.hpp
example/36_elementwise_permute/common.hpp
+151
-23
example/36_elementwise_permute/run_elementwise_permute_example.inc
...6_elementwise_permute/run_elementwise_permute_example.inc
+3
-22
No files found.
example/36_elementwise_permute/common.hpp
View file @
db32635c
...
@@ -3,6 +3,7 @@
...
@@ -3,6 +3,7 @@
#pragma once
#pragma once
#include <cassert>
#include <cstddef>
#include <cstddef>
#include <cstdlib>
#include <cstdlib>
#include <iostream>
#include <iostream>
...
@@ -64,7 +65,7 @@ struct Placeholder final
...
@@ -64,7 +65,7 @@ struct Placeholder final
constexpr
inline
operator
T
()
const
noexcept
;
constexpr
inline
operator
T
()
const
noexcept
;
};
};
template
<
typename
T
,
typename
=
void
>
template
<
typename
Iterator
,
typename
=
void
>
struct
is_output_iterator
:
std
::
false_type
struct
is_output_iterator
:
std
::
false_type
{
{
};
};
...
@@ -80,6 +81,23 @@ struct is_output_iterator<
...
@@ -80,6 +81,23 @@ struct is_output_iterator<
template
<
typename
T
>
template
<
typename
T
>
inline
constexpr
bool
is_output_iterator_v
=
is_output_iterator
<
T
>::
value
;
inline
constexpr
bool
is_output_iterator_v
=
is_output_iterator
<
T
>::
value
;
template
<
typename
Iterator
,
typename
=
void
>
struct
is_bidirectional_iterator
:
std
::
false_type
{
};
template
<
typename
Iterator
>
struct
is_bidirectional_iterator
<
Iterator
,
std
::
void_t
<
decltype
(
--
std
::
declval
<
std
::
add_lvalue_reference_t
<
Iterator
>>
()),
decltype
(
std
::
declval
<
std
::
add_lvalue_reference_t
<
Iterator
>>
()
--
)
>>
:
std
::
bool_constant
<
is_iterator_v
<
Iterator
>>
{
};
template
<
typename
Iterator
>
inline
constexpr
bool
is_bidirectional_iterator_v
=
is_bidirectional_iterator
<
Iterator
>::
value
;
template
<
typename
Iterator
,
typename
=
void
>
template
<
typename
Iterator
,
typename
=
void
>
struct
is_random_access_iterator
:
std
::
false_type
struct
is_random_access_iterator
:
std
::
false_type
{
{
...
@@ -126,6 +144,22 @@ struct is_sized_range<Range, std::void_t<decltype(size(std::declval<Range>()))>>
...
@@ -126,6 +144,22 @@ struct is_sized_range<Range, std::void_t<decltype(size(std::declval<Range>()))>>
template
<
typename
Range
>
template
<
typename
Range
>
inline
constexpr
bool
is_sized_range_v
=
is_sized_range
<
Range
>::
value
;
inline
constexpr
bool
is_sized_range_v
=
is_sized_range
<
Range
>::
value
;
template
<
typename
Range
,
typename
=
void
>
struct
is_bidirectional_range
:
std
::
false_type
{
};
template
<
typename
Range
>
struct
is_bidirectional_range
<
Range
,
std
::
void_t
<>>
:
std
::
bool_constant
<
is_range_v
<
Range
>
&&
is_bidirectional_iterator_v
<
ck
::
remove_cvref_t
<
decltype
(
begin
(
std
::
declval
<
Range
>
()))
>>>
{
};
template
<
typename
Range
>
inline
constexpr
bool
is_bidirectional_range_v
=
is_bidirectional_range
<
Range
>::
value
;
template
<
typename
Range
,
typename
=
void
>
template
<
typename
Range
,
typename
=
void
>
struct
is_random_access_range
:
std
::
false_type
struct
is_random_access_range
:
std
::
false_type
{
{
...
@@ -155,30 +189,13 @@ is_valid_axes(const Axes& axes)
...
@@ -155,30 +189,13 @@ is_valid_axes(const Axes& axes)
}
}
using
std
::
begin
,
std
::
end
;
using
std
::
begin
,
std
::
end
;
std
::
vector
<
std
::
size_t
>
copy
(
begin
(
axes
),
end
(
axes
));
std
::
vector
<
std
::
size_t
>
sorted_axes
(
begin
(
axes
),
end
(
axes
));
std
::
sort
(
begin
(
copy
),
end
(
copy
));
std
::
sort
(
begin
(
sorted_axes
),
end
(
sorted_axes
));
const
auto
last
=
std
::
unique
(
begin
(
copy
),
end
(
copy
));
const
auto
last
=
std
::
unique
(
begin
(
sorted_axes
),
end
(
sorted_axes
));
return
(
last
==
end
(
copy
))
&&
(
*
begin
(
copy
)
==
0
)
&&
(
*
std
::
prev
(
last
)
==
size
(
axes
)
-
1
);
return
(
last
==
end
(
sorted_axes
))
&&
(
*
begin
(
sorted_axes
)
==
0
)
&&
}
(
*
std
::
prev
(
last
)
==
size
(
axes
)
-
1
);
template
<
typename
Shape
,
typename
Axes
,
typename
OutputIterator
>
inline
std
::
enable_if_t
<
detail
::
is_random_access_range_v
<
Shape
>
&&
detail
::
is_sized_range_v
<
Shape
>
&&
detail
::
is_sized_range_v
<
Axes
>
&&
detail
::
is_output_iterator_v
<
OutputIterator
>
,
OutputIterator
>
transpose_shape
(
const
Shape
&
shape
,
const
Axes
&
axes
,
OutputIterator
iter
)
{
using
std
::
size
;
assert
(
size
(
shape
)
==
size
(
axes
)
&&
is_valid_axes
(
axes
));
for
(
const
auto
axis
:
axes
)
{
*
iter
++
=
shape
[
axis
];
}
return
iter
;
}
}
inline
bool
parse_cmd_args
(
int
argc
,
char
*
argv
[],
ExecutionConfig
&
config
,
Problem
&
problem
)
inline
bool
parse_cmd_args
(
int
argc
,
char
*
argv
[],
ExecutionConfig
&
config
,
Problem
&
problem
)
...
@@ -235,3 +252,114 @@ inline bool parse_cmd_args(int argc, char* argv[], ExecutionConfig& config, Prob
...
@@ -235,3 +252,114 @@ inline bool parse_cmd_args(int argc, char* argv[], ExecutionConfig& config, Prob
return
true
;
return
true
;
}
}
template
<
typename
Shape
>
inline
std
::
enable_if_t
<
detail
::
is_range_v
<
Shape
>
,
bool
>
is_valid_shape
(
const
Shape
&
shape
)
{
using
std
::
begin
,
std
::
end
;
using
std
::
empty
;
return
!
empty
(
shape
)
&&
std
::
all_of
(
begin
(
shape
),
end
(
shape
),
[](
auto
dim
)
{
return
0
<
dim
;
});
}
template
<
typename
Shape
,
typename
Indices
>
inline
std
::
enable_if_t
<
detail
::
is_sized_range_v
<
Shape
>
&&
detail
::
is_sized_range_v
<
Indices
>
,
bool
>
is_valid_indices
(
const
Shape
&
shape
,
const
Indices
&
indices
)
{
assert
(
is_valid_shape
(
shape
));
using
std
::
empty
;
if
(
empty
(
indices
))
{
return
false
;
}
using
std
::
size
;
if
(
size
(
shape
)
!=
size
(
indices
))
{
return
false
;
}
using
std
::
begin
,
std
::
end
;
auto
dim
=
begin
(
shape
);
auto
idx
=
begin
(
indices
);
for
(;
dim
!=
end
(
shape
)
&&
idx
!=
end
(
indices
);
++
dim
,
++
idx
)
{
if
(
*
dim
<=
*
idx
)
{
return
false
;
}
}
return
true
;
}
template
<
typename
Shape
,
typename
Axes
,
typename
OutputIterator
>
inline
std
::
enable_if_t
<
detail
::
is_random_access_range_v
<
Shape
>
&&
detail
::
is_sized_range_v
<
Shape
>
&&
detail
::
is_sized_range_v
<
Axes
>
&&
detail
::
is_output_iterator_v
<
OutputIterator
>
,
OutputIterator
>
transpose_shape
(
const
Shape
&
shape
,
const
Axes
&
axes
,
OutputIterator
iter
)
{
using
std
::
size
;
assert
(
size
(
shape
)
==
size
(
axes
)
&&
);
assert
(
is_valid_shape
(
shape
)
&&
is_valid_axes
(
axes
));
for
(
const
auto
axis
:
axes
)
{
*
iter
++
=
shape
[
axis
];
}
return
iter
;
}
template
<
typename
Shape
,
typename
Indices
>
std
::
enable_if_t
<
detail
::
is_bidirectional_range_v
<
Shape
>
&&
detail
::
is_sized_range_v
<
Shape
>
&&
detail
::
is_bidirectional_range_v
<
Indices
>
&&
detail
::
is_sized_range_v
<
Indices
>
,
bool
>
advance_indices
(
const
Shape
&
shape
,
Indices
&
indices
)
{
assert
(
is_valid_shape
(
shape
));
assert
(
is_valid_indices
(
indices
));
assert
(
size
(
shape
)
==
size
(
indices
));
bool
carry
=
true
;
using
std
::
rbegin
,
std
::
rend
;
auto
dim
=
rbegin
(
shape
);
auto
idx
=
rbegin
(
indices
);
for
(;
carry
&&
dim
!=
rend
(
shape
)
&&
idx
!=
rend
(
indices
);
++
dim
,
++
idx
)
{
assert
(
*
idx
<
*
dim
);
*
idx
=
(
*
idx
+
carry
);
carry
=
((
*
idx
==
*
dim
)
?
(
*
idx
=
0
,
true
)
:
false
);
}
return
!
carry
;
}
template
<
typename
Src
,
typename
Functor
,
typename
Dest
>
std
::
enable_if_t
<
std
::
is_invocable_v
<
Functor
,
std
::
add_lvalue_reference_t
<
Dest
>
,
std
::
add_lvalue_reference_t
<
Src
>>>
host_elementwise_permute
(
const
Tensor
<
Src
>&
src
,
Functor
functor
,
Tensor
<
Dest
>&
dest
)
{
const
auto
&
shape
=
src
.
mDesc
.
GetLengths
();
const
auto
&
transposed_shape
=
dest
.
mDesc
.
GetLengths
();
assert
(
is_valid_shape
(
shape
)
&&
is_valid_shape
(
transposed_shape
));
static_assert
(
detail
::
is_sized_range_v
<
ck
::
remove_cvref_t
<
decltype
(
shape
)
>>
&&
detail
::
is_sized_range_v
<
ck
::
remove_cvref_t
<
decltype
(
transposed_shape
)
>>
);
using
std
::
size
;
assert
(
size
(
shape
)
==
4
&&
size
(
transposed_shape
)
==
4
);
std
::
array
<
std
::
size_t
,
4
>
dims
{};
do
{
Dest
b_val
=
0
;
functor
(
b_val
,
src
(
dims
[
0
],
dims
[
1
],
dims
[
2
],
dims
[
3
]));
dest
(
dims
[
0
],
dims
[
2
],
dims
[
3
],
dims
[
1
])
=
b_val
;
}
while
(
advance_indices
(
shape
,
dims
));
}
example/36_elementwise_permute/run_elementwise_permute_example.inc
View file @
db32635c
...
@@ -3,25 +3,6 @@
...
@@ -3,25 +3,6 @@
#pragma once
#pragma once
template
<
typename
HostTensorA
,
typename
HostTensorB
,
typename
Functor
>
void
host_elementwise4D
(
HostTensorB
&
B
,
const
HostTensorA
&
A
,
const
std
::
vector
<
std
::
size_t
>&
shape
,
Functor
functor
)
{
using
btype
=
ck
::
remove_reference_t
<
decltype
(
B
(
0
,
0
,
0
,
0
))
>
;
for
(
std
::
size_t
n
=
0
;
n
<
shape
[
0
];
++
n
)
for
(
std
::
size_t
c
=
0
;
c
<
shape
[
1
];
++
c
)
for
(
std
::
size_t
h
=
0
;
h
<
shape
[
2
];
++
h
)
for
(
std
::
size_t
w
=
0
;
w
<
shape
[
3
];
++
w
)
{
auto
a_val
=
A
(
n
,
c
,
h
,
w
);
btype
b_val
=
0
;
functor
(
b_val
,
a_val
);
B
(
n
,
h
,
w
,
c
)
=
b_val
;
}
}
bool
run_elementwise_permute
(
const
ExecutionConfig
&
config
,
const
Problem
&
problem
)
bool
run_elementwise_permute
(
const
ExecutionConfig
&
config
,
const
Problem
&
problem
)
{
{
const
auto
&
nchw
=
problem
.
shape
;
const
auto
&
nchw
=
problem
.
shape
;
...
@@ -67,10 +48,10 @@ bool run_elementwise_permute(const ExecutionConfig& config, const Problem& probl
...
@@ -67,10 +48,10 @@ bool run_elementwise_permute(const ExecutionConfig& config, const Problem& probl
if
(
config
.
do_verification
)
if
(
config
.
do_verification
)
{
{
b_device_buf
.
FromDevice
(
b
.
mData
.
data
());
Tensor
<
BDataType
>
host_b
(
nhwc
);
Tensor
<
BDataType
>
host_b
(
nhwc
);
host_elementwise4D
<
Tensor
<
ADataType
>
,
Tensor
<
BDataType
>
,
PassThrough
>
(
host_elementwise_permute
(
a
,
PassThrough
{},
host_b
);
host_b
,
a
,
nhwc
,
PassThrough
{});
b_device_buf
.
FromDevice
(
b
.
mData
.
data
());
return
ck
::
utils
::
check_err
(
return
ck
::
utils
::
check_err
(
b
.
mData
,
host_b
.
mData
,
"Error: incorrect results in tensor B"
,
1
e
-
10
,
1
e
-
10
);
b
.
mData
,
host_b
.
mData
,
"Error: incorrect results in tensor B"
,
1
e
-
10
,
1
e
-
10
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment