Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
8724242e
Commit
8724242e
authored
Jun 24, 2019
by
Shucai Xiao
Browse files
further refactoring of softmax and logsoftmax.
parent
6d1c23e9
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
27 additions
and
29 deletions
+27
-29
src/targets/gpu/device/logsoftmax.cpp
src/targets/gpu/device/logsoftmax.cpp
+7
-9
src/targets/gpu/device/softmax.cpp
src/targets/gpu/device/softmax.cpp
+7
-9
src/targets/gpu/include/migraphx/gpu/device/logsoftmax.hpp
src/targets/gpu/include/migraphx/gpu/device/logsoftmax.hpp
+3
-3
src/targets/gpu/include/migraphx/gpu/device/softmax.hpp
src/targets/gpu/include/migraphx/gpu/device/softmax.hpp
+4
-4
src/targets/gpu/logsoftmax.cpp
src/targets/gpu/logsoftmax.cpp
+3
-2
src/targets/gpu/softmax.cpp
src/targets/gpu/softmax.cpp
+3
-2
No files found.
src/targets/gpu/device/logsoftmax.cpp
View file @
8724242e
...
...
@@ -11,24 +11,24 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace
gpu
{
namespace
device
{
argument
logsoftmax
(
hipStream_t
stream
,
const
migraphx
::
shape
&
output_shape
,
std
::
vector
<
migraphx
::
argument
>
arg
s
,
void
logsoftmax
(
hipStream_t
stream
,
const
argument
&
result
,
const
argument
&
arg
,
int
axis
)
{
auto
lens
=
outpu
t_shape
.
lens
();
auto
lens
=
result
.
ge
t_shape
()
.
lens
();
auto
n_dims
=
lens
[
axis
];
auto
batch_lens
=
lens
;
batch_lens
[
axis
]
=
1
;
migraphx
::
shape
batch_shape
{
outpu
t_shape
.
type
(),
batch_lens
};
migraphx
::
shape
batch_shape
{
result
.
ge
t_shape
()
.
type
(),
batch_lens
};
visit_all
(
args
.
back
(),
args
.
front
()
)([
&
](
auto
output
,
auto
input
)
{
visit_all
(
result
,
arg
)([
&
](
auto
output
,
auto
input
)
{
const
auto
*
input_ptr
=
device_cast
(
input
.
data
());
auto
*
output_ptr
=
device_cast
(
output
.
data
());
visit_tensor_size
(
batch_shape
.
lens
().
size
(),
[
&
](
auto
n_dim
)
{
hip_tensor_descriptor
<
n_dim
>
desc_batch
(
batch_shape
);
hip_tensor_descriptor
<
n_dim
>
desc_data
(
outpu
t_shape
);
hip_tensor_descriptor
<
n_dim
>
desc_data
(
result
.
ge
t_shape
()
);
// use one block for items in one batch.
// opt 1, load all data to lds then use the same approach as
...
...
@@ -142,8 +142,6 @@ argument logsoftmax(hipStream_t stream,
});
});
});
return
args
.
back
();
}
}
// namespace device
...
...
src/targets/gpu/device/softmax.cpp
View file @
8724242e
...
...
@@ -12,23 +12,23 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace
gpu
{
namespace
device
{
argument
softmax
(
hipStream_t
stream
,
const
migraphx
::
shape
&
output_shape
,
std
::
vector
<
migraphx
::
argument
>
arg
s
,
void
softmax
(
hipStream_t
stream
,
const
argument
&
result
,
const
argument
&
arg
,
int
axis
)
{
auto
lens
=
outpu
t_shape
.
lens
();
auto
lens
=
result
.
ge
t_shape
()
.
lens
();
auto
batch_lens
=
lens
;
size_t
n_dims
=
lens
[
axis
];
batch_lens
[
axis
]
=
1
;
migraphx
::
shape
batch_shape
{
outpu
t_shape
.
type
(),
batch_lens
};
migraphx
::
shape
batch_shape
{
result
.
ge
t_shape
()
.
type
(),
batch_lens
};
visit_all
(
args
.
back
(),
args
.
front
()
)([
&
](
auto
output
,
auto
input
)
{
visit_all
(
result
,
arg
)([
&
](
auto
output
,
auto
input
)
{
const
auto
*
input_ptr
=
device_cast
(
input
.
data
());
auto
*
output_ptr
=
device_cast
(
output
.
data
());
visit_tensor_size
(
batch_shape
.
lens
().
size
(),
[
&
](
auto
n_dim
)
{
hip_tensor_descriptor
<
n_dim
>
desc_batch
(
batch_shape
);
hip_tensor_descriptor
<
n_dim
>
desc_data
(
outpu
t_shape
);
hip_tensor_descriptor
<
n_dim
>
desc_data
(
result
.
ge
t_shape
()
);
// use one block for items in one batch.
const
size_t
max_block_size
=
1024
;
...
...
@@ -139,8 +139,6 @@ argument softmax(hipStream_t stream,
});
});
});
return
args
.
back
();
}
}
// namespace device
...
...
src/targets/gpu/include/migraphx/gpu/device/logsoftmax.hpp
View file @
8724242e
...
...
@@ -10,9 +10,9 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace
gpu
{
namespace
device
{
argument
logsoftmax
(
hipStream_t
stream
,
const
migraphx
::
shape
&
output_shape
,
std
::
vector
<
migraphx
::
argument
>
arg
s
,
void
logsoftmax
(
hipStream_t
stream
,
const
argument
&
result
,
const
argument
&
arg
,
int
axis
);
}
// namespace device
...
...
src/targets/gpu/include/migraphx/gpu/device/softmax.hpp
View file @
8724242e
...
...
@@ -10,10 +10,10 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace
gpu
{
namespace
device
{
argument
softmax
(
hipStream_t
stream
,
const
migraphx
::
shape
&
output_shape
,
std
::
vector
<
migraphx
::
argument
>
arg
s
,
int
axis
);
void
softmax
(
hipStream_t
stream
,
const
argument
&
result
,
const
argument
&
arg
,
int
axis
);
}
// namespace device
}
// namespace gpu
...
...
src/targets/gpu/logsoftmax.cpp
View file @
8724242e
...
...
@@ -16,10 +16,11 @@ shape hip_logsoftmax::compute_shape(const std::vector<shape>& inputs) const
}
argument
hip_logsoftmax
::
compute
(
context
&
ctx
,
const
shape
&
output_shape
,
const
shape
&
,
const
std
::
vector
<
argument
>&
args
)
const
{
return
device
::
logsoftmax
(
ctx
.
get_stream
().
get
(),
output_shape
,
args
,
op
.
axis
);
device
::
logsoftmax
(
ctx
.
get_stream
().
get
(),
args
.
back
(),
args
.
front
(),
op
.
axis
);
return
args
.
back
();
}
}
// namespace gpu
...
...
src/targets/gpu/softmax.cpp
View file @
8724242e
...
...
@@ -38,10 +38,11 @@ shape hip_softmax::compute_shape(const std::vector<shape>& inputs) const
}
argument
hip_softmax
::
compute
(
context
&
ctx
,
const
shape
&
output_shape
,
const
shape
&
,
const
std
::
vector
<
argument
>&
args
)
const
{
return
device
::
softmax
(
ctx
.
get_stream
().
get
(),
output_shape
,
args
,
op
.
axis
);
device
::
softmax
(
ctx
.
get_stream
().
get
(),
args
.
back
(),
args
.
front
(),
op
.
axis
);
return
args
.
back
();
}
}
// namespace gpu
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment