Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
d15edcb6
Commit
d15edcb6
authored
Jun 21, 2019
by
Paul
Browse files
Formatting
parent
ce3048d4
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
20 additions
and
17 deletions
+20
-17
src/targets/gpu/device/gather.cpp
src/targets/gpu/device/gather.cpp
+7
-7
src/targets/gpu/device/include/migraphx/gpu/device/visit.hpp
src/targets/gpu/device/include/migraphx/gpu/device/visit.hpp
+12
-7
src/targets/gpu/gather.cpp
src/targets/gpu/gather.cpp
+1
-3
No files found.
src/targets/gpu/device/gather.cpp
View file @
d15edcb6
...
@@ -13,22 +13,22 @@ namespace device {
...
@@ -13,22 +13,22 @@ namespace device {
argument
gather
(
hipStream_t
stream
,
argument
result
,
argument
arg1
,
argument
arg2
,
int
axis
)
argument
gather
(
hipStream_t
stream
,
argument
result
,
argument
arg1
,
argument
arg2
,
int
axis
)
{
{
auto
axis_index
=
(
axis
<
0
)
?
(
axis
+
arg1
.
get_shape
().
lens
().
size
())
:
axis
;
auto
axis_index
=
(
axis
<
0
)
?
(
axis
+
arg1
.
get_shape
().
lens
().
size
())
:
axis
;
auto
&
input_shape
=
arg1
.
get_shape
();
auto
&
input_shape
=
arg1
.
get_shape
();
auto
lens
=
input_shape
.
lens
();
auto
lens
=
input_shape
.
lens
();
lens
[
axis_index
]
=
arg2
.
get_shape
().
elements
();
lens
[
axis_index
]
=
arg2
.
get_shape
().
elements
();
shape
out_comp_shape
{
result
.
get_shape
().
type
(),
lens
};
shape
out_comp_shape
{
result
.
get_shape
().
type
(),
lens
};
std
::
size_t
nelements
=
result
.
get_shape
().
elements
();
std
::
size_t
nelements
=
result
.
get_shape
().
elements
();
visit_all
(
result
,
arg1
)([
&
](
auto
output
,
auto
input_v
)
{
visit_all
(
result
,
arg1
)([
&
](
auto
output
,
auto
input_v
)
{
hip_visit_views
(
input_v
,
out_comp_shape
)([
&
](
auto
input
,
auto
out_comp
)
{
hip_visit_views
(
input_v
,
out_comp_shape
)([
&
](
auto
input
,
auto
out_comp
)
{
arg2
.
visit
([
&
](
auto
indices
)
{
arg2
.
visit
([
&
](
auto
indices
)
{
const
auto
*
indices_ptr
=
device_cast
(
indices
.
data
());
const
auto
*
indices_ptr
=
device_cast
(
indices
.
data
());
auto
*
output_ptr
=
device_cast
(
output
.
data
());
auto
*
output_ptr
=
device_cast
(
output
.
data
());
gs_launch
(
stream
,
nelements
)([
=
](
auto
i
)
{
gs_launch
(
stream
,
nelements
)([
=
](
auto
i
)
{
auto
idx
=
out_comp
.
multi
(
i
);
auto
idx
=
out_comp
.
multi
(
i
);
idx
[
axis_index
]
=
indices_ptr
[
idx
[
axis_index
]];
idx
[
axis_index
]
=
indices_ptr
[
idx
[
axis_index
]];
output_ptr
[
i
]
=
input
[
idx
];
output_ptr
[
i
]
=
input
[
idx
];
});
});
});
});
});
});
...
...
src/targets/gpu/device/include/migraphx/gpu/device/visit.hpp
View file @
d15edcb6
...
@@ -55,10 +55,12 @@ template <class V, class F, class... Ts>
...
@@ -55,10 +55,12 @@ template <class V, class F, class... Ts>
void
hip_visit_all_impl
(
const
shape
&
s
,
F
f
,
V
&&
v
,
Ts
&&
...
xs
)
void
hip_visit_all_impl
(
const
shape
&
s
,
F
f
,
V
&&
v
,
Ts
&&
...
xs
)
{
{
std
::
initializer_list
<
migraphx
::
shape
::
type_t
>
types
=
{
get_shape
(
xs
).
type
()...};
std
::
initializer_list
<
migraphx
::
shape
::
type_t
>
types
=
{
get_shape
(
xs
).
type
()...};
if
(
!
std
::
all_of
(
types
.
begin
(),
types
.
end
(),
[
&
](
migraphx
::
shape
::
type_t
t
)
{
return
t
==
s
.
type
();
}))
if
(
!
std
::
all_of
(
types
.
begin
(),
types
.
end
(),
[
&
](
migraphx
::
shape
::
type_t
t
)
{
return
t
==
s
.
type
();
}))
MIGRAPHX_THROW
(
"Types must be the same"
);
MIGRAPHX_THROW
(
"Types must be the same"
);
std
::
initializer_list
<
std
::
size_t
>
ranks
=
{
get_shape
(
xs
).
lens
().
size
()...};
std
::
initializer_list
<
std
::
size_t
>
ranks
=
{
get_shape
(
xs
).
lens
().
size
()...};
if
(
!
std
::
all_of
(
ranks
.
begin
(),
ranks
.
end
(),
[
&
](
std
::
size_t
r
)
{
return
r
==
s
.
lens
().
size
();
}))
if
(
!
std
::
all_of
(
ranks
.
begin
(),
ranks
.
end
(),
[
&
](
std
::
size_t
r
)
{
return
r
==
s
.
lens
().
size
();
}))
MIGRAPHX_THROW
(
"Ranks must be the same"
);
MIGRAPHX_THROW
(
"Ranks must be the same"
);
visit_tensor_size
(
s
.
lens
().
size
(),
visit_tensor_size
(
s
.
lens
().
size
(),
[
&
](
auto
ndim
)
{
s
.
visit_type
([
&
](
auto
as
)
{
v
(
f
(
xs
,
ndim
,
as
)...);
});
});
[
&
](
auto
ndim
)
{
s
.
visit_type
([
&
](
auto
as
)
{
v
(
f
(
xs
,
ndim
,
as
)...);
});
});
...
@@ -68,10 +70,10 @@ template <class V, class F, class... Ts>
...
@@ -68,10 +70,10 @@ template <class V, class F, class... Ts>
void
hip_visit_views_impl
(
const
shape
&
s
,
F
f
,
V
&&
v
,
Ts
&&
...
xs
)
void
hip_visit_views_impl
(
const
shape
&
s
,
F
f
,
V
&&
v
,
Ts
&&
...
xs
)
{
{
std
::
initializer_list
<
std
::
size_t
>
ranks
=
{
get_shape
(
xs
).
lens
().
size
()...};
std
::
initializer_list
<
std
::
size_t
>
ranks
=
{
get_shape
(
xs
).
lens
().
size
()...};
if
(
!
std
::
all_of
(
ranks
.
begin
(),
ranks
.
end
(),
[
&
](
std
::
size_t
r
)
{
return
r
==
s
.
lens
().
size
();
}))
if
(
!
std
::
all_of
(
ranks
.
begin
(),
ranks
.
end
(),
[
&
](
std
::
size_t
r
)
{
return
r
==
s
.
lens
().
size
();
}))
MIGRAPHX_THROW
(
"Ranks must be the same"
);
MIGRAPHX_THROW
(
"Ranks must be the same"
);
visit_tensor_size
(
s
.
lens
().
size
(),
visit_tensor_size
(
s
.
lens
().
size
(),
[
&
](
auto
ndim
)
{
v
(
f
(
xs
,
ndim
)...);
});
[
&
](
auto
ndim
)
{
v
(
f
(
xs
,
ndim
)...);
});
}
}
template
<
class
F
>
template
<
class
F
>
...
@@ -152,8 +154,11 @@ template <class T, class... Ts>
...
@@ -152,8 +154,11 @@ template <class T, class... Ts>
auto
hip_visit_views
(
T
&&
x
,
Ts
&&
...
xs
)
auto
hip_visit_views
(
T
&&
x
,
Ts
&&
...
xs
)
{
{
return
[
&
](
auto
f
)
{
return
[
&
](
auto
f
)
{
hip_visit_views_impl
(
hip_visit_views_impl
(
get_shape
(
x
),
get_shape
(
x
),
make_hip_convert_view
([](
auto
v
)
{
return
device_cast
(
v
);
}),
f
,
x
,
xs
...);
make_hip_convert_view
([](
auto
v
)
{
return
device_cast
(
v
);
}),
f
,
x
,
xs
...);
};
};
}
}
...
...
src/targets/gpu/gather.cpp
View file @
d15edcb6
...
@@ -12,9 +12,7 @@ shape hip_gather::compute_shape(std::vector<shape> inputs) const
...
@@ -12,9 +12,7 @@ shape hip_gather::compute_shape(std::vector<shape> inputs) const
return
op
.
compute_shape
(
inputs
);
return
op
.
compute_shape
(
inputs
);
}
}
argument
hip_gather
::
compute
(
context
&
ctx
,
argument
hip_gather
::
compute
(
context
&
ctx
,
const
shape
&
,
const
std
::
vector
<
argument
>&
args
)
const
const
shape
&
,
const
std
::
vector
<
argument
>&
args
)
const
{
{
return
device
::
gather
(
ctx
.
get_stream
().
get
(),
args
.
back
(),
args
[
0
],
args
[
1
],
op
.
axis
);
return
device
::
gather
(
ctx
.
get_stream
().
get
(),
args
.
back
(),
args
[
0
],
args
[
1
],
op
.
axis
);
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment