Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
25bad0f3
Commit
25bad0f3
authored
Jun 20, 2019
by
Paul
Browse files
Improve visit_all to handle shapes as well
parent
ddb6356b
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
39 additions
and
13 deletions
+39
-13
src/targets/gpu/device/include/migraphx/gpu/device/shape.hpp
src/targets/gpu/device/include/migraphx/gpu/device/shape.hpp
+1
-1
src/targets/gpu/device/include/migraphx/gpu/device/tensor_view.hpp
...ts/gpu/device/include/migraphx/gpu/device/tensor_view.hpp
+2
-2
src/targets/gpu/device/include/migraphx/gpu/device/visit.hpp
src/targets/gpu/device/include/migraphx/gpu/device/visit.hpp
+36
-10
No files found.
src/targets/gpu/device/include/migraphx/gpu/device/shape.hpp
View file @
25bad0f3
...
...
@@ -76,7 +76,7 @@ struct hip_shape
};
template
<
std
::
size_t
N
>
hip_shape
<
N
>
make_hip
(
const
shape
&
x
)
hip_shape
<
N
>
make_hip
_shape
(
const
shape
&
x
)
{
return
x
;
}
...
...
src/targets/gpu/device/include/migraphx/gpu/device/tensor_view.hpp
View file @
25bad0f3
...
...
@@ -39,9 +39,9 @@ struct hip_tensor_view
};
template
<
std
::
size_t
N
,
class
T
>
hip_tensor_view
<
T
,
N
>
make_hip
(
tensor_view
<
T
>
x
)
hip_tensor_view
<
T
,
N
>
make_hip
_view
(
const
shape
&
s
,
T
*
x
)
{
return
x
;
return
{
x
,
s
}
;
}
}
// namespace device
...
...
src/targets/gpu/device/include/migraphx/gpu/device/visit.hpp
View file @
25bad0f3
...
...
@@ -43,21 +43,50 @@ void visit_tensor_size(std::size_t n, F f)
}
}
inline
s
td
::
size_t
tensor_siz
e
(
const
shape
&
x
)
{
return
x
.
lens
().
size
()
;
}
inline
s
hape
get_shap
e
(
const
shape
&
x
)
{
return
x
;
}
template
<
class
T
>
auto
tensor_siz
e
(
const
T
&
x
)
->
decltype
(
x
.
get_shape
()
.
lens
().
size
()
)
auto
get_shap
e
(
const
T
&
x
)
->
decltype
(
x
.
get_shape
())
{
return
x
.
get_shape
().
lens
().
size
();
return
x
.
get_shape
();
}
template
<
class
V
,
class
F
,
class
...
Ts
>
void
hip_visit_all_impl
(
const
shape
&
s
,
F
f
,
V
&&
v
,
Ts
&&
...
xs
)
{
visit_tensor_size
(
s
.
lens
().
size
(),
[
&
](
auto
ndim
)
{
s
.
visit_type
([
&
](
auto
as
)
{
v
(
f
(
xs
,
ndim
,
as
)...);
});
});
}
template
<
class
F
>
struct
hip_convert
{
F
f
;
template
<
class
RawData
,
class
N
,
class
As
>
auto
operator
()(
RawData
x
,
N
ndim
,
As
as
)
const
->
decltype
(
make_hip_view
<
ndim
>
(
x
.
get_shape
(),
f
(
as
.
from
(
x
.
data
()))))
{
return
make_hip_view
<
ndim
>
(
x
.
get_shape
(),
f
(
as
.
from
(
x
.
data
())));
}
template
<
class
N
,
class
As
>
auto
operator
()(
const
shape
&
s
,
N
ndim
,
As
)
const
{
return
make_hip_shape
<
ndim
>
(
s
);
}
};
template
<
class
F
>
hip_convert
<
F
>
make_hip_convert
(
F
f
)
{
return
{
f
};
}
template
<
class
T
,
class
...
Ts
>
auto
hip_visit_all
(
T
&&
x
,
Ts
&&
...
xs
)
{
return
[
&
](
auto
f
)
{
visit_tensor_size
(
tensor_size
(
x
),
[
&
](
auto
dim
)
{
visit_all
(
x
,
xs
...)([
&
](
auto
...
vs
)
{
f
(
make_hip
<
dim
>
(
device_cast
(
vs
))...);
});
});
hip_visit_all_impl
(
get_shape
(
x
),
make_hip_convert
([](
auto
*
p
)
{
return
device_cast
(
p
);}),
f
,
x
,
xs
...);
};
}
...
...
@@ -65,10 +94,7 @@ template <std::size_t N, class T, class... Ts>
auto
hip_vec_visit_all
(
T
&&
x
,
Ts
&&
...
xs
)
{
return
[
&
](
auto
f
)
{
visit_tensor_size
(
tensor_size
(
x
),
[
&
](
auto
dim
)
{
visit_all
(
x
,
xs
...)([
&
](
auto
...
vs
)
{
f
(
make_hip
<
dim
>
(
as_vec
<
N
>
(
device_cast
(
vs
)))...);
});
});
hip_visit_all_impl
(
get_shape
(
x
),
make_hip_convert
([](
auto
*
p
)
{
return
as_vec
<
N
>
(
device_cast
(
p
));}),
f
,
x
,
xs
...);
};
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment