Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
d1672f1d
Commit
d1672f1d
authored
Jun 25, 2019
by
Shucai Xiao
Browse files
code backup.
parent
3855c6af
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
149 additions
and
100 deletions
+149
-100
src/targets/gpu/device/argmax.cpp
src/targets/gpu/device/argmax.cpp
+39
-49
src/targets/gpu/device/argmin.cpp
src/targets/gpu/device/argmin.cpp
+39
-49
src/targets/gpu/include/migraphx/gpu/device/argmax.hpp
src/targets/gpu/include/migraphx/gpu/device/argmax.hpp
+1
-1
src/targets/gpu/include/migraphx/gpu/device/argmin.hpp
src/targets/gpu/include/migraphx/gpu/device/argmin.hpp
+1
-1
src/targets/gpu/include/migraphx/gpu/device/reduce_opers.hpp
src/targets/gpu/include/migraphx/gpu/device/reduce_opers.hpp
+69
-0
No files found.
src/targets/gpu/device/argmax.cpp
View file @
d1672f1d
...
...
@@ -5,6 +5,7 @@
#include <migraphx/gpu/device/tensor.hpp>
#include <migraphx/gpu/device/launch.hpp>
#include <migraphx/gpu/device/types.hpp>
#include <migraphx/gpu/device/reduce_opers.hpp>
#include <migraphx/gpu/hip.hpp>
namespace
migraphx
{
...
...
@@ -12,69 +13,58 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace
gpu
{
namespace
device
{
argument
argmax
(
hipStream_t
stream
,
const
argument
&
result
,
const
argument
&
arg
,
int
axis
)
void
argmax
(
hipStream_t
stream
,
const
argument
&
result
,
const
argument
&
arg
,
int
axis
)
{
auto
lens
=
arg
.
get_shape
().
lens
();
auto
batch_lens
=
lens
;
size_t
n_dims
=
lens
[
axis
];
size_t
batch_item_num
=
lens
[
axis
];
batch_lens
[
axis
]
=
1
;
migraphx
::
shape
batch_shape
{
shape
::
float_type
,
batch_lens
};
visit_all
(
result
,
arg
)([
&
](
auto
output
,
auto
input
)
{
const
auto
*
input_ptr
=
device_cast
(
input
.
data
());
auto
*
output_ptr
=
device_cast
(
output
.
data
());
visit_tensor_size
(
batch_shape
.
lens
().
size
(),
[
&
](
auto
n_dim
)
{
hip_tensor_descriptor
<
n_dim
>
desc_batch
(
batch_shape
);
hip_tensor_descriptor
<
n_dim
>
desc_data
(
arg
.
get_shape
());
hip_visit_all
(
result
,
arg
,
batch_shape
)([
&
](
auto
output
,
auto
input
,
auto
batch
)
{
// use one block for items in one batch.
const
size_t
max_block_size
=
1024
;
size_t
block_size
=
1
;
while
(
block_size
<
max_block_size
and
block_size
<
batch_item_num
)
{
block_size
*=
2
;
}
// each block is for one batch
const
size_t
block_size
=
1024
;
launch
(
stream
,
batch_shape
.
elements
()
*
block_size
,
block_size
)([
=
](
auto
idx
)
__device__
{
size_t
thr_idx
=
idx
.
local
;
size_t
blk_idx
=
idx
.
group
;
using
type
=
device_type
<
std
::
remove_cv_t
<
typename
decltype
(
output
)
::
value_type
>>
;
launch
(
stream
,
batch_shape
.
elements
()
*
block_size
,
block_size
)([
=
](
auto
idx
)
__device__
{
size_t
thr_idx
=
idx
.
local
;
size_t
blk_idx
=
idx
.
group
;
using
type
=
device_type
<
std
::
remove_cv_t
<
typename
decltype
(
output
)
::
value_type
>>
;
auto
batch_idx
=
desc_batch
.
multi
(
blk_idx
);
auto
data_idx
=
batch_idx
;
MIGRAPHX_DEVICE_SHARED
type
lds_data
[
block_size
];
MIGRAPHX_DEVICE_SHARED
int64_t
lds_index
[
block_size
];
// load data to lds_data
size_t
item_num
=
n_dims
;
for
(
size_t
i
=
thr_idx
;
i
<
n_dims
;
i
+=
block_size
)
auto
batch_idx
=
batch
.
multi
(
blk_idx
);
auto
data_idx
=
batch_idx
;
MIGRAPHX_DEVICE_SHARED
type
lds_data
[
max_block_size
+
1
];
MIGRAPHX_DEVICE_SHARED
int64_t
lds_index
[
max_block_size
+
1
];
// load data to lds_data
size_t
round_item_num
=
(
batch_item_num
+
block_size
-
1
)
/
block_size
*
block_size
;
size_t
remaining_item_num
=
batch_item_num
;
lds_data
[
max_block_size
]
=
input
[
0
];
lds_index
[
max_block_size
]
=
0
;
for
(
size_t
i
=
thr_idx
;
i
<
round_item_num
;
i
+=
block_size
)
{
if
(
i
<
batch_item_num
)
{
data_idx
[
axis
]
=
i
;
lds_index
[
thr_idx
]
=
i
;
lds_data
[
thr_idx
]
=
input_ptr
[
desc_data
.
linear
(
data_idx
)];
__syncthreads
();
auto
size
=
(
item_num
>
block_size
)
?
block_size
:
item_num
;
auto
stride
=
(
size
+
1
)
/
2
;
while
(
true
)
{
if
(
thr_idx
+
stride
<
size
and
lds_data
[
thr_idx
]
<
lds_data
[
thr_idx
+
stride
])
{
lds_data
[
thr_idx
]
=
lds_data
[
thr_idx
+
stride
];
lds_index
[
thr_idx
]
=
lds_index
[
thr_idx
+
stride
];
}
__syncthreads
();
size
=
stride
;
stride
=
(
stride
+
1
)
/
2
;
lds_data
[
thr_idx
]
=
input
[
data_idx
];
}
__syncthreads
();
if
(
size
==
1
)
break
;
}
auto
item_num
=
(
remaining_item_num
>
block_size
)
?
block_size
:
remaining_item_num
;
reduce_argmax
(
lds_data
,
lds_index
,
block_size
,
thr_idx
,
size
,
max_block_size
);
if
(
thr_idx
==
0
)
{
output_ptr
[
blk_idx
]
=
lds_index
[
0
];
}
remaining_item_num
-=
block_size
;
}
item_num
-=
block_size
;
}
});
if
(
thr_idx
==
0
)
{
output
[
batch_idx
]
=
lds_index
[
max_block_size
];
}
});
});
...
...
src/targets/gpu/device/argmin.cpp
View file @
d1672f1d
...
...
@@ -5,6 +5,7 @@
#include <migraphx/gpu/device/tensor.hpp>
#include <migraphx/gpu/device/launch.hpp>
#include <migraphx/gpu/device/types.hpp>
#include <migraphx/gpu/device/reduce_opers.hpp>
#include <migraphx/gpu/hip.hpp>
namespace
migraphx
{
...
...
@@ -12,69 +13,58 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace
gpu
{
namespace
device
{
argument
argm
ax
(
hipStream_t
stream
,
const
argument
&
result
,
const
argument
&
arg
,
int
axis
)
void
argm
in
(
hipStream_t
stream
,
const
argument
&
result
,
const
argument
&
arg
,
int
axis
)
{
auto
lens
=
arg
.
get_shape
().
lens
();
auto
batch_lens
=
lens
;
size_t
n_dims
=
lens
[
axis
];
size_t
batch_item_num
=
lens
[
axis
];
batch_lens
[
axis
]
=
1
;
migraphx
::
shape
batch_shape
{
shape
::
float_type
,
batch_lens
};
visit_all
(
result
,
arg
)([
&
](
auto
output
,
auto
input
)
{
const
auto
*
input_ptr
=
device_cast
(
input
.
data
());
auto
*
output_ptr
=
device_cast
(
output
.
data
());
visit_tensor_size
(
batch_shape
.
lens
().
size
(),
[
&
](
auto
n_dim
)
{
hip_tensor_descriptor
<
n_dim
>
desc_batch
(
batch_shape
);
hip_tensor_descriptor
<
n_dim
>
desc_data
(
arg
.
get_shape
());
hip_visit_all
(
result
,
arg
,
batch_shape
)([
&
](
auto
output
,
auto
input
,
auto
batch
)
{
// use one block for items in one batch.
const
size_t
max_block_size
=
1024
;
size_t
block_size
=
1
;
while
(
block_size
<
max_block_size
and
block_size
<
batch_item_num
)
{
block_size
*=
2
;
}
// each block is for one batch
const
size_t
block_size
=
1024
;
launch
(
stream
,
batch_shape
.
elements
()
*
block_size
,
block_size
)([
=
](
auto
idx
)
__device__
{
size_t
thr_idx
=
idx
.
local
;
size_t
blk_idx
=
idx
.
group
;
using
type
=
device_type
<
std
::
remove_cv_t
<
typename
decltype
(
output
)
::
value_type
>>
;
launch
(
stream
,
batch_shape
.
elements
()
*
block_size
,
block_size
)([
=
](
auto
idx
)
__device__
{
size_t
thr_idx
=
idx
.
local
;
size_t
blk_idx
=
idx
.
group
;
using
type
=
device_type
<
std
::
remove_cv_t
<
typename
decltype
(
output
)
::
value_type
>>
;
auto
batch_idx
=
desc_batch
.
multi
(
blk_idx
);
auto
data_idx
=
batch_idx
;
MIGRAPHX_DEVICE_SHARED
type
lds_data
[
block_size
];
MIGRAPHX_DEVICE_SHARED
int64_t
lds_index
[
block_size
];
// load data to lds_data
size_t
item_num
=
n_dims
;
for
(
size_t
i
=
thr_idx
;
i
<
n_dims
;
i
+=
block_size
)
auto
batch_idx
=
batch
.
multi
(
blk_idx
);
auto
data_idx
=
batch_idx
;
MIGRAPHX_DEVICE_SHARED
type
lds_data
[
max_block_size
+
1
];
MIGRAPHX_DEVICE_SHARED
int64_t
lds_index
[
max_block_size
+
1
];
// load data to lds_data
size_t
round_item_num
=
(
batch_item_num
+
block_size
-
1
)
/
block_size
*
block_size
;
size_t
remaining_item_num
=
batch_item_num
;
lds_data
[
max_block_size
]
=
input
[
0
];
lds_index
[
max_block_size
]
=
0
;
for
(
size_t
i
=
thr_idx
;
i
<
round_item_num
;
i
+=
block_size
)
{
if
(
i
<
batch_item_num
)
{
data_idx
[
axis
]
=
i
;
lds_index
[
thr_idx
]
=
i
;
lds_data
[
thr_idx
]
=
input_ptr
[
desc_data
.
linear
(
data_idx
)];
__syncthreads
();
auto
size
=
(
item_num
>
block_size
)
?
block_size
:
item_num
;
auto
stride
=
(
size
+
1
)
/
2
;
while
(
true
)
{
if
(
thr_idx
+
stride
<
size
and
lds_data
[
thr_idx
]
>
lds_data
[
thr_idx
+
stride
])
{
lds_data
[
thr_idx
]
=
lds_data
[
thr_idx
+
stride
];
lds_index
[
thr_idx
]
=
lds_index
[
thr_idx
+
stride
];
}
__syncthreads
();
size
=
stride
;
stride
=
(
stride
+
1
)
/
2
;
lds_data
[
thr_idx
]
=
input
[
data_idx
];
}
__syncthreads
();
if
(
size
==
1
)
break
;
}
auto
item_num
=
(
remaining_item_num
>
block_size
)
?
block_size
:
remaining_item_num
;
reduce_argmin
(
lds_data
,
lds_index
,
block_size
,
thr_idx
,
size
,
max_block_size
);
if
(
thr_idx
==
0
)
{
output_ptr
[
blk_idx
]
=
lds_index
[
0
];
}
remaining_item_num
-=
block_size
;
}
item_num
-=
block_size
;
}
});
if
(
thr_idx
==
0
)
{
output
[
batch_idx
]
=
lds_index
[
max_block_size
];
}
});
});
...
...
src/targets/gpu/include/migraphx/gpu/device/argmax.hpp
View file @
d1672f1d
...
...
@@ -10,7 +10,7 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace
gpu
{
namespace
device
{
argument
argmax
(
hipStream_t
stream
,
const
argument
&
result
,
const
argument
&
arg
,
int
axis
);
void
argmax
(
hipStream_t
stream
,
const
argument
&
result
,
const
argument
&
arg
,
int
axis
);
}
// namespace device
}
// namespace gpu
...
...
src/targets/gpu/include/migraphx/gpu/device/argmin.hpp
View file @
d1672f1d
...
...
@@ -10,7 +10,7 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace
gpu
{
namespace
device
{
argument
argmin
(
hipStream_t
stream
,
const
argument
&
result
,
const
argument
&
arg
,
int
axis
);
void
argmin
(
hipStream_t
stream
,
const
argument
&
result
,
const
argument
&
arg
,
int
axis
);
}
// namespace device
}
// namespace gpu
...
...
src/targets/gpu/include/migraphx/gpu/device/reduce_opers.hpp
View file @
d1672f1d
...
...
@@ -37,6 +37,75 @@ inline __device__ void reduce_max(T* data_ptr, size_t block_size, size_t thr_idx
__syncthreads
();
}
template
<
class
T
>
inline
__device__
void
reduce_argmax
(
T
*
data_ptr
,
int64_t
*
index_ptr
,
size_t
block_size
,
size_t
thr_idx
,
size_t
item_num
,
size_t
max_index
)
{
while
(
true
)
{
auto
stride
=
(
item_num
+
1
)
/
2
;
auto
size
=
item_num
/
2
;
for
(
size_t
i
=
thr_idx
;
i
<
size
;
i
+=
block_size
)
{
if
(
data_ptr
[
i
]
<
data_ptr
[
i
+
stride
])
{
data_ptr
[
i
]
=
data_ptr
[
i
+
stride
];
index_ptr
[
i
]
=
index_ptr
[
i
+
stride
];
}
}
__syncthreads
();
item_num
=
stride
;
if
(
item_num
==
1
)
break
;
}
if
(
thr_idx
==
0
)
{
if
(
data_ptr
[
max_index
]
<
data_ptr
[
0
])
{
data_ptr
[
max_index
]
=
data_ptr
[
0
];
index_ptr
[
max_index
]
=
index_ptr
[
0
];
}
}
__syncthreads
();
}
template
<
class
T
>
inline
__device__
void
reduce_argmin
(
T
*
data_ptr
,
int64_t
*
index_ptr
,
size_t
block_size
,
size_t
thr_idx
,
size_t
item_num
)
{
size_t
min_index
=
item_num
;
while
(
true
)
{
auto
stride
=
(
item_num
+
1
)
/
2
;
auto
size
=
item_num
/
2
;
for
(
size_t
i
=
thr_idx
;
i
<
size
;
i
+=
block_size
)
{
if
(
data_ptr
[
i
]
>
data_ptr
[
i
+
stride
])
{
data_ptr
[
i
]
=
data_ptr
[
i
+
stride
];
index_ptr
[
i
]
=
index_ptr
[
i
+
stride
];
}
}
__syncthreads
();
item_num
=
stride
;
if
(
item_num
==
1
)
break
;
}
if
(
thr_idx
==
0
)
{
if
(
data_ptr
[
min_index
]
>
data_ptr
[
0
])
{
data_ptr
[
min_index
]
=
data_ptr
[
0
];
index_ptr
[
min_index
]
=
index_ptr
[
0
];
}
}
__syncthreads
();
}
template
<
class
T
>
inline
__device__
void
reduce_sum
(
T
*
data_ptr
,
size_t
block_size
,
size_t
thr_idx
,
size_t
item_num
)
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment