Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
31b53e5a
Commit
31b53e5a
authored
Jun 27, 2019
by
Shucai Xiao
Browse files
merge changes from branch softmax/logsoftmax optimization
parents
2efbf57f
6d52e887
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
3 additions
and
84 deletions
+3
-84
src/targets/gpu/device/logsoftmax.cpp
src/targets/gpu/device/logsoftmax.cpp
+0
-2
src/targets/gpu/include/migraphx/gpu/device/reduce_opers.hpp
src/targets/gpu/include/migraphx/gpu/device/reduce_opers.hpp
+3
-82
No files found.
src/targets/gpu/device/logsoftmax.cpp
View file @
31b53e5a
...
@@ -53,7 +53,6 @@ void logsoftmax(hipStream_t stream, const argument& result, const argument& arg,
...
@@ -53,7 +53,6 @@ void logsoftmax(hipStream_t stream, const argument& result, const argument& arg,
__syncthreads
();
__syncthreads
();
auto
item_num
=
(
remaining_item_num
>
block_size
)
?
block_size
:
remaining_item_num
;
auto
item_num
=
(
remaining_item_num
>
block_size
)
?
block_size
:
remaining_item_num
;
// reduce_max(lds_data, block_size, thr_idx, item_num, max_block_size);
block_reduce
<
type
,
max_op
<
type
>>
(
block_reduce
<
type
,
max_op
<
type
>>
(
lds_data
,
max_op
<
type
>
{},
block_size
,
thr_idx
,
item_num
,
max_block_size
);
lds_data
,
max_op
<
type
>
{},
block_size
,
thr_idx
,
item_num
,
max_block_size
);
...
@@ -77,7 +76,6 @@ void logsoftmax(hipStream_t stream, const argument& result, const argument& arg,
...
@@ -77,7 +76,6 @@ void logsoftmax(hipStream_t stream, const argument& result, const argument& arg,
__syncthreads
();
__syncthreads
();
auto
item_num
=
(
remaining_item_num
>
block_size
)
?
block_size
:
remaining_item_num
;
auto
item_num
=
(
remaining_item_num
>
block_size
)
?
block_size
:
remaining_item_num
;
// reduce_sum(lds_data, block_size, thr_idx, item_num, max_block_size);
block_reduce
<
type
,
sum_op
<
type
>>
(
block_reduce
<
type
,
sum_op
<
type
>>
(
lds_data
,
sum_op
<
type
>
{},
block_size
,
thr_idx
,
item_num
,
max_block_size
);
lds_data
,
sum_op
<
type
>
{},
block_size
,
thr_idx
,
item_num
,
max_block_size
);
...
...
src/targets/gpu/include/migraphx/gpu/device/reduce_opers.hpp
View file @
31b53e5a
...
@@ -13,19 +13,19 @@ namespace device {
...
@@ -13,19 +13,19 @@ namespace device {
template
<
class
T
>
template
<
class
T
>
struct
max_op
struct
max_op
{
{
T
operator
()(
T
x
,
T
y
)
{
return
(
x
>
y
)
?
x
:
y
;
}
T
operator
()(
T
x
,
T
y
)
const
{
return
(
x
>
y
)
?
x
:
y
;
}
};
};
template
<
class
T
>
template
<
class
T
>
struct
min_op
struct
min_op
{
{
T
operator
()(
T
x
,
T
y
)
{
return
(
x
<
y
)
?
x
:
y
;
}
T
operator
()(
T
x
,
T
y
)
const
{
return
(
x
<
y
)
?
x
:
y
;
}
};
};
template
<
class
T
>
template
<
class
T
>
struct
sum_op
struct
sum_op
{
{
T
operator
()(
T
x
,
T
y
)
{
return
x
+
y
;
}
T
operator
()(
T
x
,
T
y
)
const
{
return
x
+
y
;
}
};
};
template
<
class
T
,
class
Op
>
template
<
class
T
,
class
Op
>
...
@@ -42,7 +42,6 @@ inline __device__ void block_reduce(T* data_ptr,
...
@@ -42,7 +42,6 @@ inline __device__ void block_reduce(T* data_ptr,
auto
size
=
item_num
/
2
;
auto
size
=
item_num
/
2
;
for
(
std
::
size_t
i
=
thr_idx
;
i
<
size
;
i
+=
block_size
)
for
(
std
::
size_t
i
=
thr_idx
;
i
<
size
;
i
+=
block_size
)
{
{
// data_ptr[i] = ::max(to_hip_type(data_ptr[i]), to_hip_type(data_ptr[i + stride]));
data_ptr
[
i
]
=
op
(
data_ptr
[
i
],
data_ptr
[
i
+
stride
]);
data_ptr
[
i
]
=
op
(
data_ptr
[
i
],
data_ptr
[
i
+
stride
]);
}
}
__syncthreads
();
__syncthreads
();
...
@@ -60,84 +59,6 @@ inline __device__ void block_reduce(T* data_ptr,
...
@@ -60,84 +59,6 @@ inline __device__ void block_reduce(T* data_ptr,
__syncthreads
();
__syncthreads
();
}
}
template
<
class
T
>
inline
__device__
void
reduce_argmax
(
T
*
data_ptr
,
int64_t
*
index_ptr
,
std
::
size_t
block_size
,
std
::
size_t
thr_idx
,
std
::
size_t
item_num
,
std
::
size_t
max_index
)
{
while
(
true
)
{
auto
stride
=
(
item_num
+
1
)
/
2
;
auto
size
=
item_num
/
2
;
for
(
std
::
size_t
i
=
thr_idx
;
i
<
size
;
i
+=
block_size
)
{
if
(
data_ptr
[
i
]
<
data_ptr
[
i
+
stride
])
{
data_ptr
[
i
]
=
data_ptr
[
i
+
stride
];
index_ptr
[
i
]
=
index_ptr
[
i
+
stride
];
}
}
__syncthreads
();
item_num
=
stride
;
if
(
item_num
==
1
)
break
;
}
if
(
thr_idx
==
0
)
{
if
(
data_ptr
[
max_index
]
<
data_ptr
[
0
])
{
data_ptr
[
max_index
]
=
data_ptr
[
0
];
index_ptr
[
max_index
]
=
index_ptr
[
0
];
}
}
__syncthreads
();
}
template
<
class
T
>
inline
__device__
void
reduce_argmin
(
T
*
data_ptr
,
int64_t
*
index_ptr
,
std
::
size_t
block_size
,
std
::
size_t
thr_idx
,
std
::
size_t
item_num
,
std
::
size_t
min_index
)
{
while
(
true
)
{
auto
stride
=
(
item_num
+
1
)
/
2
;
auto
size
=
item_num
/
2
;
for
(
std
::
size_t
i
=
thr_idx
;
i
<
size
;
i
+=
block_size
)
{
if
(
data_ptr
[
i
]
>
data_ptr
[
i
+
stride
])
{
data_ptr
[
i
]
=
data_ptr
[
i
+
stride
];
index_ptr
[
i
]
=
index_ptr
[
i
+
stride
];
}
}
__syncthreads
();
item_num
=
stride
;
if
(
item_num
==
1
)
break
;
}
if
(
thr_idx
==
0
)
{
if
(
data_ptr
[
min_index
]
>
data_ptr
[
0
])
{
data_ptr
[
min_index
]
=
data_ptr
[
0
];
index_ptr
[
min_index
]
=
index_ptr
[
0
];
}
}
__syncthreads
();
}
}
// namespace device
}
// namespace device
}
// namespace gpu
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment