Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
8724242e
Commit
8724242e
authored
Jun 24, 2019
by
Shucai Xiao
Browse files
further refactoring of softmax and logsoftmax.
parent
6d1c23e9
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
27 additions
and
29 deletions
+27
-29
src/targets/gpu/device/logsoftmax.cpp
src/targets/gpu/device/logsoftmax.cpp
+7
-9
src/targets/gpu/device/softmax.cpp
src/targets/gpu/device/softmax.cpp
+7
-9
src/targets/gpu/include/migraphx/gpu/device/logsoftmax.hpp
src/targets/gpu/include/migraphx/gpu/device/logsoftmax.hpp
+3
-3
src/targets/gpu/include/migraphx/gpu/device/softmax.hpp
src/targets/gpu/include/migraphx/gpu/device/softmax.hpp
+4
-4
src/targets/gpu/logsoftmax.cpp
src/targets/gpu/logsoftmax.cpp
+3
-2
src/targets/gpu/softmax.cpp
src/targets/gpu/softmax.cpp
+3
-2
No files found.
src/targets/gpu/device/logsoftmax.cpp
View file @
8724242e
...
...
@@ -11,24 +11,24 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace
gpu
{
namespace
device
{
argument
logsoftmax
(
hipStream_t
stream
,
const
migraphx
::
shape
&
output_shape
,
std
::
vector
<
migraphx
::
argument
>
arg
s
,
void
logsoftmax
(
hipStream_t
stream
,
const
argument
&
result
,
const
argument
&
arg
,
int
axis
)
{
auto
lens
=
outpu
t_shape
.
lens
();
auto
lens
=
result
.
ge
t_shape
()
.
lens
();
auto
n_dims
=
lens
[
axis
];
auto
batch_lens
=
lens
;
batch_lens
[
axis
]
=
1
;
migraphx
::
shape
batch_shape
{
outpu
t_shape
.
type
(),
batch_lens
};
migraphx
::
shape
batch_shape
{
result
.
ge
t_shape
()
.
type
(),
batch_lens
};
visit_all
(
args
.
back
(),
args
.
front
()
)([
&
](
auto
output
,
auto
input
)
{
visit_all
(
result
,
arg
)([
&
](
auto
output
,
auto
input
)
{
const
auto
*
input_ptr
=
device_cast
(
input
.
data
());
auto
*
output_ptr
=
device_cast
(
output
.
data
());
visit_tensor_size
(
batch_shape
.
lens
().
size
(),
[
&
](
auto
n_dim
)
{
hip_tensor_descriptor
<
n_dim
>
desc_batch
(
batch_shape
);
hip_tensor_descriptor
<
n_dim
>
desc_data
(
outpu
t_shape
);
hip_tensor_descriptor
<
n_dim
>
desc_data
(
result
.
ge
t_shape
()
);
// use one block for items in one batch.
// opt 1, load all data to lds then use the same approach as
...
...
@@ -142,8 +142,6 @@ argument logsoftmax(hipStream_t stream,
});
});
});
return
args
.
back
();
}
}
// namespace device
...
...
src/targets/gpu/device/softmax.cpp
View file @
8724242e
...
...
@@ -12,23 +12,23 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace
gpu
{
namespace
device
{
argument
softmax
(
hipStream_t
stream
,
const
migraphx
::
shape
&
output_shape
,
std
::
vector
<
migraphx
::
argument
>
arg
s
,
void
softmax
(
hipStream_t
stream
,
const
argument
&
result
,
const
argument
&
arg
,
int
axis
)
{
auto
lens
=
outpu
t_shape
.
lens
();
auto
lens
=
result
.
ge
t_shape
()
.
lens
();
auto
batch_lens
=
lens
;
size_t
n_dims
=
lens
[
axis
];
batch_lens
[
axis
]
=
1
;
migraphx
::
shape
batch_shape
{
outpu
t_shape
.
type
(),
batch_lens
};
migraphx
::
shape
batch_shape
{
result
.
ge
t_shape
()
.
type
(),
batch_lens
};
visit_all
(
args
.
back
(),
args
.
front
()
)([
&
](
auto
output
,
auto
input
)
{
visit_all
(
result
,
arg
)([
&
](
auto
output
,
auto
input
)
{
const
auto
*
input_ptr
=
device_cast
(
input
.
data
());
auto
*
output_ptr
=
device_cast
(
output
.
data
());
visit_tensor_size
(
batch_shape
.
lens
().
size
(),
[
&
](
auto
n_dim
)
{
hip_tensor_descriptor
<
n_dim
>
desc_batch
(
batch_shape
);
hip_tensor_descriptor
<
n_dim
>
desc_data
(
outpu
t_shape
);
hip_tensor_descriptor
<
n_dim
>
desc_data
(
result
.
ge
t_shape
()
);
// use one block for items in one batch.
const
size_t
max_block_size
=
1024
;
...
...
@@ -139,8 +139,6 @@ argument softmax(hipStream_t stream,
});
});
});
return
args
.
back
();
}
}
// namespace device
...
...
src/targets/gpu/include/migraphx/gpu/device/logsoftmax.hpp
View file @
8724242e
...
...
@@ -10,9 +10,9 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace
gpu
{
namespace
device
{
argument
logsoftmax
(
hipStream_t
stream
,
const
migraphx
::
shape
&
output_shape
,
std
::
vector
<
migraphx
::
argument
>
arg
s
,
void
logsoftmax
(
hipStream_t
stream
,
const
argument
&
result
,
const
argument
&
arg
,
int
axis
);
}
// namespace device
...
...
src/targets/gpu/include/migraphx/gpu/device/softmax.hpp
View file @
8724242e
...
...
@@ -10,10 +10,10 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace
gpu
{
namespace
device
{
argument
softmax
(
hipStream_t
stream
,
const
migraphx
::
shape
&
output_shape
,
std
::
vector
<
migraphx
::
argument
>
arg
s
,
int
axis
);
void
softmax
(
hipStream_t
stream
,
const
argument
&
result
,
const
argument
&
arg
,
int
axis
);
}
// namespace device
}
// namespace gpu
...
...
src/targets/gpu/logsoftmax.cpp
View file @
8724242e
...
...
@@ -16,10 +16,11 @@ shape hip_logsoftmax::compute_shape(const std::vector<shape>& inputs) const
}
argument
hip_logsoftmax
::
compute
(
context
&
ctx
,
const
shape
&
output_shape
,
const
shape
&
,
const
std
::
vector
<
argument
>&
args
)
const
{
return
device
::
logsoftmax
(
ctx
.
get_stream
().
get
(),
output_shape
,
args
,
op
.
axis
);
device
::
logsoftmax
(
ctx
.
get_stream
().
get
(),
args
.
back
(),
args
.
front
(),
op
.
axis
);
return
args
.
back
();
}
}
// namespace gpu
...
...
src/targets/gpu/softmax.cpp
View file @
8724242e
...
...
@@ -38,10 +38,11 @@ shape hip_softmax::compute_shape(const std::vector<shape>& inputs) const
}
argument
hip_softmax
::
compute
(
context
&
ctx
,
const
shape
&
output_shape
,
const
shape
&
,
const
std
::
vector
<
argument
>&
args
)
const
{
return
device
::
softmax
(
ctx
.
get_stream
().
get
(),
output_shape
,
args
,
op
.
axis
);
device
::
softmax
(
ctx
.
get_stream
().
get
(),
args
.
back
(),
args
.
front
(),
op
.
axis
);
return
args
.
back
();
}
}
// namespace gpu
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment