Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
ce3048d4
Commit
ce3048d4
authored
Jun 21, 2019
by
Paul
Browse files
Refactor gather operator
parent
72c188be
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
66 additions
and
14 deletions
+66
-14
src/targets/gpu/device/gather.cpp
src/targets/gpu/device/gather.cpp
+11
-13
src/targets/gpu/device/include/migraphx/gpu/device/tensor_view.hpp
...ts/gpu/device/include/migraphx/gpu/device/tensor_view.hpp
+6
-0
src/targets/gpu/device/include/migraphx/gpu/device/visit.hpp
src/targets/gpu/device/include/migraphx/gpu/device/visit.hpp
+48
-0
src/targets/gpu/gather.cpp
src/targets/gpu/gather.cpp
+1
-1
No files found.
src/targets/gpu/device/gather.cpp
View file @
ce3048d4
...
@@ -17,20 +17,18 @@ argument gather(hipStream_t stream, argument result, argument arg1, argument arg
...
@@ -17,20 +17,18 @@ argument gather(hipStream_t stream, argument result, argument arg1, argument arg
auto
&
input_shape
=
arg1
.
get_shape
();
auto
&
input_shape
=
arg1
.
get_shape
();
auto
lens
=
input_shape
.
lens
();
auto
lens
=
input_shape
.
lens
();
lens
[
axis_index
]
=
arg2
.
get_shape
().
elements
();
lens
[
axis_index
]
=
arg2
.
get_shape
().
elements
();
shape
out_comp_shape
{
result
.
get_shape
().
type
(),
lens
};
std
::
size_t
nelements
=
result
.
get_shape
().
elements
();
std
::
size_t
nelements
=
result
.
get_shape
().
elements
();
visit_all
(
result
,
arg1
)([
&
](
auto
output
,
auto
input
)
{
arg2
.
visit
([
&
](
auto
indices
)
{
visit_all
(
result
,
arg1
)([
&
](
auto
output
,
auto
input_v
)
{
const
auto
*
indices_ptr
=
device_cast
(
indices
.
data
());
hip_visit_views
(
input_v
,
out_comp_shape
)([
&
](
auto
input
,
auto
out_comp
)
{
auto
*
out_ptr
=
device_cast
(
output
.
data
());
arg2
.
visit
([
&
](
auto
indices
)
{
const
auto
*
in_ptr
=
device_cast
(
input
.
data
());
const
auto
*
indices_ptr
=
device_cast
(
indices
.
data
());
migraphx
::
shape
out_comp_shape
{
result
.
get_shape
().
type
(),
lens
};
auto
*
output_ptr
=
device_cast
(
output
.
data
());
visit_tensor_size
(
out_comp_shape
.
lens
().
size
(),
[
&
](
auto
n_out_dim
)
{
gs_launch
(
stream
,
nelements
)([
=
](
auto
i
)
{
hip_tensor_descriptor
<
n_out_dim
>
desc_input
(
input_shape
);
auto
idx
=
out_comp
.
multi
(
i
);
hip_tensor_descriptor
<
n_out_dim
>
desc_output
(
out_comp_shape
);
idx
[
axis_index
]
=
indices_ptr
[
idx
[
axis_index
]];
gs_launch
(
stream
,
nelements
)([
=
](
auto
ii
)
{
output_ptr
[
i
]
=
input
[
idx
];
auto
in_idx
=
desc_output
.
multi
(
ii
);
in_idx
[
axis_index
]
=
indices_ptr
[
in_idx
[
axis_index
]];
out_ptr
[
ii
]
=
in_ptr
[
desc_input
.
linear
(
in_idx
)];
});
});
});
});
});
});
...
...
src/targets/gpu/device/include/migraphx/gpu/device/tensor_view.hpp
View file @
ce3048d4
...
@@ -44,6 +44,12 @@ hip_tensor_view<T, N> make_hip_view(const shape& s, T* x)
...
@@ -44,6 +44,12 @@ hip_tensor_view<T, N> make_hip_view(const shape& s, T* x)
return
{
x
,
s
};
return
{
x
,
s
};
}
}
// Builds a hip_tensor_view<T, N> from a tensor_view<T>.
//
// N cannot be deduced from the argument and has to be supplied explicitly by
// the caller (e.g. make_hip_view<3>(view)); T is deduced from `x`.
// NOTE(review): N is presumably the tensor rank -- confirm against the
// definition of hip_tensor_view.
template <std::size_t N, class T>
hip_tensor_view<T, N> make_hip_view(tensor_view<T> x)
{
    // List-initialize the return value from the view alone.
    return {x};
}
}
// namespace device
}
// namespace device
}
// namespace gpu
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/targets/gpu/device/include/migraphx/gpu/device/visit.hpp
View file @
ce3048d4
...
@@ -54,10 +54,26 @@ auto get_shape(const T& x) -> decltype(x.get_shape())
...
@@ -54,10 +54,26 @@ auto get_shape(const T& x) -> decltype(x.get_shape())
template
<
class
V
,
class
F
,
class
...
Ts
>
template
<
class
V
,
class
F
,
class
...
Ts
>
void
hip_visit_all_impl
(
const
shape
&
s
,
F
f
,
V
&&
v
,
Ts
&&
...
xs
)
void
hip_visit_all_impl
(
const
shape
&
s
,
F
f
,
V
&&
v
,
Ts
&&
...
xs
)
{
{
std
::
initializer_list
<
migraphx
::
shape
::
type_t
>
types
=
{
get_shape
(
xs
).
type
()...};
if
(
!
std
::
all_of
(
types
.
begin
(),
types
.
end
(),
[
&
](
migraphx
::
shape
::
type_t
t
)
{
return
t
==
s
.
type
();
}))
MIGRAPHX_THROW
(
"Types must be the same"
);
std
::
initializer_list
<
std
::
size_t
>
ranks
=
{
get_shape
(
xs
).
lens
().
size
()...};
if
(
!
std
::
all_of
(
ranks
.
begin
(),
ranks
.
end
(),
[
&
](
std
::
size_t
r
)
{
return
r
==
s
.
lens
().
size
();
}))
MIGRAPHX_THROW
(
"Ranks must be the same"
);
visit_tensor_size
(
s
.
lens
().
size
(),
visit_tensor_size
(
s
.
lens
().
size
(),
[
&
](
auto
ndim
)
{
s
.
visit_type
([
&
](
auto
as
)
{
v
(
f
(
xs
,
ndim
,
as
)...);
});
});
[
&
](
auto
ndim
)
{
s
.
visit_type
([
&
](
auto
as
)
{
v
(
f
(
xs
,
ndim
,
as
)...);
});
});
}
}
// Checks that every argument has the same rank as the reference shape and then
// dispatches a visitor over the converted arguments.
//
//   s  - reference shape; only its rank (s.lens().size()) is consulted.
//   f  - conversion callable, invoked as f(x, ndim) for each argument x.
//   v  - visitor that receives all converted arguments in a single call.
//   xs - arguments to check and convert; each must be accepted by get_shape().
//
// Unlike a type-checked visit, no element-type comparison is made here.
// Throws (via MIGRAPHX_THROW) when any argument's rank differs from the
// reference rank.
template <class V, class F, class... Ts>
void hip_visit_views_impl(const shape& s, F f, V&& v, Ts&&... xs)
{
    std::initializer_list<std::size_t> arg_ranks = {get_shape(xs).lens().size()...};
    const auto expected_rank                     = s.lens().size();

    // Reject the call as soon as one argument disagrees with the reference rank.
    const bool rank_mismatch =
        std::any_of(arg_ranks.begin(), arg_ranks.end(), [&](std::size_t rank) {
            return rank != expected_rank;
        });
    if(rank_mismatch)
        MIGRAPHX_THROW("Ranks must be the same");

    // Resolve the runtime rank into the value handed to `f`, then pass every
    // converted argument to the visitor at once.
    visit_tensor_size(expected_rank, [&](auto ndim) { v(f(xs, ndim)...); });
}
template
<
class
F
>
template
<
class
F
>
struct
hip_convert
struct
hip_convert
{
{
...
@@ -82,6 +98,29 @@ hip_convert<F> make_hip_convert(F f)
...
@@ -82,6 +98,29 @@ hip_convert<F> make_hip_convert(F f)
return
{
f
};
return
{
f
};
}
}
// Function object that converts visit arguments into their hip counterparts.
//
// It holds a callable `f` and offers two call operators:
//   * tensor_view<T> arguments are first passed through `f`, and the result is
//     wrapped with make_hip_view<ndim>;
//   * shape arguments are turned into make_hip_shape<ndim>(s); `f` is not
//     applied to them.
//
// `ndim` is used as a non-type template argument, so N has to be usable in a
// constant expression there.
// NOTE(review): presumably an integral-constant-like type supplied by
// visit_tensor_size -- confirm.
template <class F>
struct hip_convert_view
{
    // Conversion applied to tensor views before they are wrapped.
    F f;

    // Tensor view overload: convert with `f`, then build the hip view.
    template <class T, class N>
    auto operator()(tensor_view<T> x, N ndim) const
    {
        return make_hip_view<ndim>(f(x));
    }

    // Shape overload: build the hip shape directly from `s`.
    template <class N>
    auto operator()(const shape& s, N ndim) const
    {
        return make_hip_shape<ndim>(s);
    }
};
// Factory for hip_convert_view that deduces F from the callable, so callers
// can pass a lambda without naming its type.
template <class F>
hip_convert_view<F> make_hip_convert_view(F f)
{
    return hip_convert_view<F>{f};
}
template
<
class
T
,
class
...
Ts
>
template
<
class
T
,
class
...
Ts
>
auto
hip_visit_all
(
T
&&
x
,
Ts
&&
...
xs
)
auto
hip_visit_all
(
T
&&
x
,
Ts
&&
...
xs
)
{
{
...
@@ -109,6 +148,15 @@ auto hip_pointer_visit_all(T&& x, Ts&&... xs)
...
@@ -109,6 +148,15 @@ auto hip_pointer_visit_all(T&& x, Ts&&... xs)
return
[
&
](
auto
f
)
{
visit_all
(
x
,
xs
...)([
&
](
auto
...
vs
)
{
f
(
device_cast
(
vs
.
data
())...);
});
};
return
[
&
](
auto
f
)
{
visit_all
(
x
,
xs
...)([
&
](
auto
...
vs
)
{
f
(
device_cast
(
vs
.
data
())...);
});
};
}
}
// Returns a callable that, given a visitor `f`, forwards `x` and `xs...` to
// hip_visit_views_impl using the shape of `x` as the reference shape.
//
// Tensor views among the arguments are converted with device_cast before being
// wrapped (see the converter built below).
//
// The returned closure captures `x` and `xs` by reference, so it has to be
// invoked while those arguments are still alive -- typically immediately, as
// in hip_visit_views(a, b)([&](auto av, auto bv) { ... }).
template <class T, class... Ts>
auto hip_visit_views(T&& x, Ts&&... xs)
{
    return [&](auto f) {
        // Per-view conversion applied before the hip view is constructed.
        auto to_device = [](auto v) { return device_cast(v); };
        hip_visit_views_impl(get_shape(x), make_hip_convert_view(to_device), f, x, xs...);
    };
}
}
// namespace device
}
// namespace device
}
// namespace gpu
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/targets/gpu/gather.cpp
View file @
ce3048d4
...
@@ -13,7 +13,7 @@ shape hip_gather::compute_shape(std::vector<shape> inputs) const
...
@@ -13,7 +13,7 @@ shape hip_gather::compute_shape(std::vector<shape> inputs) const
}
}
argument
hip_gather
::
compute
(
context
&
ctx
,
argument
hip_gather
::
compute
(
context
&
ctx
,
const
shape
&
output_shape
,
const
shape
&
,
const
std
::
vector
<
argument
>&
args
)
const
const
std
::
vector
<
argument
>&
args
)
const
{
{
return
device
::
gather
(
ctx
.
get_stream
().
get
(),
args
.
back
(),
args
[
0
],
args
[
1
],
op
.
axis
);
return
device
::
gather
(
ctx
.
get_stream
().
get
(),
args
.
back
(),
args
[
0
],
args
[
1
],
op
.
axis
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment