gaoqiong / MIGraphX · Commit 4a39a0f7

Commit 4a39a0f7, authored Oct 11, 2021 by Shucai Xiao

Merge branch 'develop' of github.com:ROCmSoftwarePlatform/AMDMIGraphX into add-conv_bn_add-test

Parents: 5564172e, bb827865

Changes: 542 · Showing 20 changed files with 921 additions and 82 deletions (+921 −82) on this page.
Files shown on this page:

src/targets/gpu/device/include/migraphx/gpu/device/reduce.hpp          +16  −78
src/targets/gpu/device/include/migraphx/gpu/device/reduce_ops.hpp      +88  −0
src/targets/gpu/device/include/migraphx/gpu/device/scan.hpp            +59  −0
src/targets/gpu/device/include/migraphx/gpu/device/types.hpp           +15  −0
src/targets/gpu/device/layernorm.cpp                                   +15  −4
src/targets/gpu/device/multinomial.cpp                                 +66  −0
src/targets/gpu/device/nonzero.cpp                                     +53  −0
src/targets/gpu/device/prefix_scan_sum.cpp                             +40  −0
src/targets/gpu/device/reverse.cpp                                     +43  −0
src/targets/gpu/device/scatter.cpp                                     +42  −0
src/targets/gpu/device/topk.cpp                                        +216 −0
src/targets/gpu/device/where.cpp                                       +39  −0
src/targets/gpu/device_name.cpp                                        +43  −0
src/targets/gpu/driver/CMakeLists.txt                                  +11  −0
src/targets/gpu/driver/action.cpp                                      +27  −0
src/targets/gpu/driver/compile_pointwise.cpp                           +26  −0
src/targets/gpu/driver/include/migraphx/gpu/driver/action.hpp          +37  −0
src/targets/gpu/driver/include/migraphx/gpu/driver/parser.hpp          +45  −0
src/targets/gpu/driver/include/migraphx/gpu/driver/perf.hpp            +19  −0
src/targets/gpu/driver/main.cpp                                        +21  −0
src/targets/gpu/device/include/migraphx/gpu/device/reduce.hpp (mode 100755 → 100644)

@@ -5,85 +5,13 @@
#include <migraphx/gpu/device/launch.hpp>
#include <migraphx/gpu/device/visit.hpp>
#include <migraphx/gpu/device/multi_index.hpp>
#include <migraphx/gpu/device/reduce_ops.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {

struct sum
{
    template <class T, class U>
    MIGRAPHX_DEVICE_CONSTEXPR auto operator()(T x, U y) const
    {
        return x + y;
    }
};

struct product
{
    template <class T, class U>
    MIGRAPHX_DEVICE_CONSTEXPR auto operator()(T x, U y) const
    {
        return x * y;
    }
};

struct id
{
    template <class T>
    MIGRAPHX_DEVICE_CONSTEXPR auto operator()(T x) const
    {
        return x;
    }
};

struct mean
{
    size_t item_num = 1;

    template <class T>
    MIGRAPHX_DEVICE_CONSTEXPR auto operator()(T x) const
    {
        return x / static_cast<T>(item_num);
    }
};

struct max
{
    template <class T, class U>
    MIGRAPHX_DEVICE_CONSTEXPR auto operator()(T x, U y) const
    {
        return (x > y) ? x : y;
    }
};

struct min
{
    template <class T, class U>
    MIGRAPHX_DEVICE_CONSTEXPR auto operator()(T x, U y) const
    {
        return (x < y) ? x : y;
    }
};

struct lowest
{
    template <class T>
    __device__ __host__ operator T() const
    {
        return device_cast(std::numeric_limits<host_type<T>>::lowest());
    }
};

struct highest
{
    template <class T>
    __device__ __host__ operator T() const
    {
        return device_cast(std::numeric_limits<host_type<T>>::max());
    }
};

#ifdef MIGRAPHX_NO_DPP
template <index_int N,
          class Op,
@@ -166,10 +94,12 @@ __device__ void dpp_reduce(T& in, Op op)
    in  = op(in, out);
    out = dpp_mov<dpp_row_shr(8), 0xf, 0xc>(in);
    in  = op(in, out);
#if __AMDGCN_WAVEFRONT_SIZE == 64
    out = dpp_mov<dpp_row_bcast(15), 0xa>(in);
    in  = op(in, out);
    out = dpp_mov<dpp_row_bcast(31), 0xc>(in);
    in  = op(in, out);
#endif
}

__device__ inline void dpp_reduce(float& x, sum)
@@ -186,9 +116,11 @@ __device__ inline void dpp_reduce(float& x, sum)
                 "s_nop 1\n"
                 "v_add_f32 %0 %0 %0 row_shr:8 bank_mask:0xc\n"
                 "s_nop 1\n"
#if __AMDGCN_WAVEFRONT_SIZE == 64
                 "v_add_f32 %0 %0 %0 row_bcast:15 row_mask:0xa\n"
                 "s_nop 1\n"
                 "v_add_f32 %0 %0 %0 row_bcast:31 row_mask:0xc\n"
#endif
                 "s_nop 1\n"
                 : "=v"(x)
                 : "0"(x));
@@ -203,21 +135,27 @@ template <index_int N,
          MIGRAPHX_REQUIRES(not std::is_integral<ForStride>{})>
__device__ auto block_reduce(index idx, Op op, T init, ForStride fs, F f)
{
    using type = decltype(f(deduce_for_stride(fs)));
    MIGRAPHX_DEVICE_SHARED type buffer[N / 64];
#if __AMDGCN_WAVEFRONT_SIZE == 32
    constexpr index_int nthreads = 16;
#else
    constexpr index_int nthreads = 64;
#endif
    using type = decltype(f(deduce_for_stride(fs)));
    MIGRAPHX_DEVICE_SHARED type buffer[N / nthreads];
    type x = init;
    fs([&](auto i) { x = op(x, f(i)); });
    dpp_reduce(x, op);

    const auto ldsidx = idx.local / 64;
    if((idx.local % 64) == 63)
    const auto ldsidx = idx.local / nthreads;
    if((idx.local % nthreads) == nthreads - 1)
    {
        buffer[ldsidx] = x;
    }
    __syncthreads();

    type y = init;
    for(index_int i = 0; i < idx.nlocal() / 64; i++)
    for(index_int i = 0; i < idx.nlocal() / nthreads; i++)
    {
        y = op(y, buffer[i]);
    }
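The hunks above move the reduce functors out to reduce_ops.hpp and parameterize the per-wavefront width in block_reduce: the old lines hard-code 64, while the new lines use an nthreads constant chosen from __AMDGCN_WAVEFRONT_SIZE so the LDS buffer and the final combine loop also work on wave32 hardware (both versions appear in the hunk). As a rough illustration of the two-stage pattern (reduce within a wavefront, stash one partial per wavefront, then combine the partials), here is a plain host-side C++ sketch; block_reduce_host, wavefront_size and partials are illustrative names, not MIGraphX API.

// Host-side illustration of the two-stage block reduction used above.
// Stage 1: each "wavefront" reduces its own slice (stands in for dpp_reduce).
// Stage 2: one partial per wavefront is combined into the final value.
#include <algorithm>
#include <cstddef>
#include <iostream>
#include <numeric>
#include <vector>

template <class Op, class T>
T block_reduce_host(const std::vector<T>& data, Op op, T init, std::size_t wavefront_size)
{
    std::vector<T> partials; // plays the role of the LDS buffer
    for(std::size_t base = 0; base < data.size(); base += wavefront_size)
    {
        T x = init;
        for(std::size_t i = base; i < std::min(base + wavefront_size, data.size()); i++)
            x = op(x, data[i]);
        partials.push_back(x);
    }
    T y = init;
    for(const auto& p : partials) // the final loop over buffer[i] in block_reduce
        y = op(y, p);
    return y;
}

int main()
{
    std::vector<int> v(100);
    std::iota(v.begin(), v.end(), 1); // 1..100
    std::cout << block_reduce_host(v, [](int a, int b) { return a + b; }, 0, 64) << "\n"; // 5050
}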
src/targets/gpu/device/include/migraphx/gpu/device/reduce_ops.hpp (new file, 0 → 100755)

#ifndef MIGRAPHX_GUARD_DEVICE_REDUCE_OPS_HPP
#define MIGRAPHX_GUARD_DEVICE_REDUCE_OPS_HPP

#include <migraphx/gpu/device/types.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {

struct sum
{
    template <class T, class U>
    MIGRAPHX_DEVICE_CONSTEXPR auto operator()(T x, U y) const
    {
        return x + y;
    }
};

struct product
{
    template <class T, class U>
    MIGRAPHX_DEVICE_CONSTEXPR auto operator()(T x, U y) const
    {
        return x * y;
    }
};

struct id
{
    template <class T>
    MIGRAPHX_DEVICE_CONSTEXPR auto operator()(T x) const
    {
        return x;
    }
};

struct mean
{
    size_t item_num = 1;

    template <class T>
    MIGRAPHX_DEVICE_CONSTEXPR auto operator()(T x) const
    {
        return x / static_cast<T>(item_num);
    }
};

struct max
{
    template <class T, class U>
    MIGRAPHX_DEVICE_CONSTEXPR auto operator()(T x, U y) const
    {
        return (x > y) ? x : y;
    }
};

struct min
{
    template <class T, class U>
    MIGRAPHX_DEVICE_CONSTEXPR auto operator()(T x, U y) const
    {
        return (x < y) ? x : y;
    }
};

struct lowest
{
    template <class T>
    __device__ __host__ operator T() const
    {
        return device_cast(std::numeric_limits<host_type<T>>::lowest());
    }
};

struct highest
{
    template <class T>
    __device__ __host__ operator T() const
    {
        return device_cast(std::numeric_limits<host_type<T>>::max());
    }
};

} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_DEVICE_REDUCE_OPS_HPP
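reduce_ops.hpp collects the binary reduce functors (sum, product, max, min), the unary post-ops (id, mean) and the identity-element helpers (lowest, highest) that reduce.hpp previously defined inline. A minimal host-only sketch of how such functors pair with their identity elements (max with the lowest representable value) is below; sum_op, max_op and lowest_value are stand-ins written against the standard library, not the MIGraphX types.

// Host illustration of the reduce-functor + identity-element pairing.
#include <iostream>
#include <limits>
#include <numeric>
#include <vector>

struct sum_op
{
    template <class T, class U>
    constexpr auto operator()(T x, U y) const { return x + y; }
};

struct max_op
{
    template <class T, class U>
    constexpr auto operator()(T x, U y) const { return (x > y) ? x : y; }
};

// Converts implicitly to the smallest value of whatever type asks for it,
// in the same spirit as the device-side `lowest` helper.
struct lowest_value
{
    template <class T>
    constexpr operator T() const { return std::numeric_limits<T>::lowest(); }
};

int main()
{
    std::vector<float> v = {1.5f, -2.0f, 7.25f, 3.0f};
    float total   = std::accumulate(v.begin(), v.end(), 0.0f, sum_op{});
    float init    = lowest_value{};
    float largest = std::accumulate(v.begin(), v.end(), init, max_op{});
    std::cout << total << " " << largest << "\n"; // 9.75 7.25
}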
src/targets/gpu/device/include/migraphx/gpu/device/scan.hpp (new file, 0 → 100644)

#ifndef MIGRAPHX_GUARD_DEVICE_SCAN_HPP
#define MIGRAPHX_GUARD_DEVICE_SCAN_HPP

#include <migraphx/gpu/device/launch.hpp>
#include <migraphx/gpu/device/visit.hpp>
#include <migraphx/gpu/device/multi_index.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {

template <index_int N,
          class Op,
          class T,
          class ForStride,
          class Input,
          class Output,
          MIGRAPHX_REQUIRES(not std::is_integral<ForStride>{})>
__device__ void block_scan(index idx, Op op, T init, ForStride fs, Input input, Output output)
{
    using type = decltype(input(deduce_for_stride(fs)));
    MIGRAPHX_DEVICE_SHARED type buffer[N];
    type x = init;
    fs([&](auto i) {
        if(idx.local == 0)
            buffer[idx.local] = op(input(i), x);
        else
            buffer[idx.local] = input(i);
        __syncthreads();

        for(index_int s = 1; s < idx.nlocal(); s *= 2)
        {
            if(idx.local + s < idx.nlocal())
            {
                buffer[idx.local + s] = op(buffer[idx.local], buffer[idx.local + s]);
            }
            __syncthreads();
        }

        x = buffer[idx.nlocal() - 1];
        output(i, buffer[idx.local]);
    });
}

template <index_int N, class Op, class T, class Input, class Output>
__device__ void block_scan(index idx, Op op, T init, index_int n, Input input, Output output)
{
    block_scan<N>(idx,
                  op,
                  init,
                  [&](auto f) -> decltype(f(index_int{})) { return idx.local_stride(n, f); },
                  input,
                  output);
}

} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_DEVICE_SCAN_HPP
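block_scan performs an inclusive prefix scan across a block: each thread loads one element per tile, the doubling-stride loop combines each slot with the slot s positions before it, and the running total x of the previous tile is folded into element 0 of the next tile. Below is a host-side C++ sketch of one tile of that loop, using a temporary copy where the device code relies on __syncthreads(); inclusive_scan_host is an illustrative name, not MIGraphX API.

// Host-side illustration of the inclusive block scan above.
#include <cstddef>
#include <functional>
#include <iostream>
#include <vector>

template <class Op, class T>
std::vector<T> inclusive_scan_host(std::vector<T> data, Op op, T init)
{
    // One "tile"; the device code repeats this per local_stride chunk,
    // carrying the last element forward as the init of the next chunk.
    T carry = init;
    if(!data.empty())
        data[0] = op(data[0], carry);
    for(std::size_t s = 1; s < data.size(); s *= 2)
    {
        std::vector<T> next = data;
        for(std::size_t i = 0; i + s < data.size(); ++i)
            next[i + s] = op(data[i], data[i + s]); // same update as buffer[idx.local + s]
        data = next; // the device version synchronizes threads instead of copying
    }
    return data;
}

int main()
{
    std::vector<int> v = {1, 2, 3, 4, 5};
    for(int x : inclusive_scan_host(v, std::plus<int>{}, 0))
        std::cout << x << " "; // 1 3 6 10 15
    std::cout << "\n";
}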
src/targets/gpu/device/include/migraphx/gpu/device/types.hpp

@@ -129,6 +129,21 @@ __device__ __host__ T to_hip_type(T x)
// Hip doesn't support __fp16
inline __device__ __host__ float to_hip_type(gpu_half x) { return x; }

#define MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(trait, T) \
    template <class X>                             \
    struct trait : std::trait<X>                   \
    {                                              \
    };                                             \
                                                   \
    template <>                                    \
    struct trait<T> : std::true_type               \
    {                                              \
    };

MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_floating_point, __fp16)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_signed, __fp16)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_arithmetic, __fp16)

} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
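MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR wraps a std:: type trait so that, inside the device namespace, it also reports true for __fp16, which the standard traits do not know about. The standalone sketch below demonstrates the same pattern with an ordinary struct (my_half) standing in for __fp16, since __fp16 is a compiler extension; the macro name EXTEND_TRAIT_FOR is likewise a stand-in.

// Runnable illustration of the trait-extension pattern behind
// MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR: defer to the std:: answer for every type
// except one extension type, which is forced to true.
#include <iostream>
#include <type_traits>

struct my_half
{
    unsigned short bits;
};

#define EXTEND_TRAIT_FOR(trait, T)   \
    template <class X>               \
    struct trait : std::trait<X>     \
    {                                \
    };                               \
                                     \
    template <>                      \
    struct trait<T> : std::true_type \
    {                                \
    };

EXTEND_TRAIT_FOR(is_floating_point, my_half)

int main()
{
    std::cout << std::boolalpha;
    std::cout << is_floating_point<float>{} << "\n";   // true  (from std::)
    std::cout << is_floating_point<int>{} << "\n";     // false (from std::)
    std::cout << is_floating_point<my_half>{} << "\n"; // true  (specialization)
}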
src/targets/gpu/device/layernorm.cpp (mode 100755 → 100644)

@@ -8,6 +8,14 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {

#ifndef MIGRAPHX_WORKAROUND_NAVI_DPP_SYNC
#if __AMDGCN_WAVEFRONT_SIZE == 32
#define MIGRAPHX_WORKAROUND_NAVI_DPP_SYNC 1
#else
#define MIGRAPHX_WORKAROUND_NAVI_DPP_SYNC 0
#endif
#endif

template <class T>
struct vector_type
{

@@ -86,10 +94,13 @@ __device__ void layernorm(index_int i,
    const bool in_range = idx.local < relements_v;

    auto mean = [&](auto z) {
        return auto_block_reduce<MaxBlockSize>(
                   idx, sum{}, value_type(0), relements_v, [=](auto) { return z; }) /
               value_type(relements);
        auto m = auto_block_reduce<MaxBlockSize>(
                     idx, sum{}, value_type(0), relements_v, [=](auto) { return z; }) /
                 value_type(relements);
#if MIGRAPHX_WORKAROUND_NAVI_DPP_SYNC
        __builtin_amdgcn_s_barrier();
#endif
        return m;
    };

    // m = x - mean(x)
src/targets/gpu/device/multinomial.cpp (new file, 0 → 100644)

#include <migraphx/shape.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/gpu/device/multinomial.hpp>
#include <migraphx/gpu/device/tensor.hpp>
#include <migraphx/gpu/device/launch.hpp>
#include <migraphx/gpu/device/types.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {

template <class Iterator, class T>
constexpr Iterator upper_bound(Iterator first, Iterator last, const T& value)
{
    Iterator it;
    typename std::iterator_traits<Iterator>::difference_type count;
    typename std::iterator_traits<Iterator>::difference_type step;
    count = std::distance(first, last);

    while(count > 0)
    {
        it   = first;
        step = count / 2;
        std::advance(it, step);
        if(!(value < *it))
        {
            first = ++it;
            count -= step + 1;
        }
        else
            count = step;
    }
    return first;
}

void multinomial(hipStream_t stream,
                 const argument& result,
                 const argument& arg0,
                 const argument& arg1)
{
    size_t batch_size  = arg0.get_shape().lens().front();
    size_t class_size  = arg0.get_shape().lens().back();
    size_t sample_size = result.get_shape().lens().back();

    hip_visit_all(arg0, arg1)([&](auto cdf, auto dist) {
        result.visit([&](auto out) {
            hip_visit_views(out)([&](auto output) {
                gs_launch(stream, batch_size * sample_size)([=](auto i) __device__ {
                    auto idx       = output.get_shape().multi(i);
                    auto cdf_begin = cdf.begin() + (idx.front() * class_size);
                    auto cdf_end   = cdf_begin + class_size;
                    auto sample_iter =
                        upper_bound(cdf_begin, cdf_end, dist[i] * *(std::prev(cdf_end)));
                    output[i] = std::distance(cdf_begin, sample_iter);
                });
            });
        });
    });
}

} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
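The kernel draws one sample per (batch, sample) pair: it scales the uniform value dist[i] by the last entry of that batch's unnormalized running-sum CDF and returns the index found by upper_bound. A host-only sketch of that sampling rule follows; sample_class is an illustrative name.

// Host illustration of the CDF sampling rule used in multinomial().
#include <algorithm>
#include <cstddef>
#include <iostream>
#include <vector>

std::size_t sample_class(const std::vector<float>& cdf, float u)
{
    float scaled = u * cdf.back(); // cdf.back() is the total (unnormalized) weight
    auto it      = std::upper_bound(cdf.begin(), cdf.end(), scaled);
    return static_cast<std::size_t>(std::distance(cdf.begin(), it));
}

int main()
{
    // Running sum of weights {1, 3, 1} -> classes 0,1,2 with probabilities 0.2, 0.6, 0.2
    std::vector<float> cdf = {1.0f, 4.0f, 5.0f};
    std::cout << sample_class(cdf, 0.10f) << " "   // 0
              << sample_class(cdf, 0.50f) << " "   // 1
              << sample_class(cdf, 0.95f) << "\n"; // 2
}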
src/targets/gpu/device/nonzero.cpp (new file, 0 → 100644)

#include <migraphx/gpu/device/nonzero.hpp>
#include <migraphx/gpu/device/float_equal.hpp>
#include <migraphx/gpu/device/scan.hpp>
#include <migraphx/gpu/device/reduce_ops.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {

argument nonzero(hipStream_t stream, const argument& result, const argument& arg_data)
{
    auto s            = arg_data.get_shape();
    auto elem_num     = s.elements();
    auto out_elem_num = result.get_shape().elements();

    // Do a prefix sum over the nonzero mask to compute each element's index
    // in the output. Only one block can be used since there is a single
    // prefix sum.
    const index_int block_size = 256;
    hip_visit_all(arg_data, s)([&](auto input, auto si) {
        const auto* in_ptr = device_cast(input.data());
        auto* ptr          = result.cast<int64_t>();
        gs_launch(stream, block_size, block_size)([=](auto, auto idx) __device__ {
            // fill all output to 0 first
            idx.local_stride(out_elem_num, [&](auto j) { ptr[j] = 0; });
            block_scan<block_size>(
                idx,
                sum{},
                0,
                elem_num,
                [&](auto j) { return (float_equal(in_ptr[j], 0)) ? 0 : 1; },
                [&](auto j, auto x) {
                    auto out_loc = x - 1;
                    if(float_equal(in_ptr[j], 0))
                        return;
                    auto index = si.multi(j);
                    for(size_t k = 0; k < index.size(); ++k)
                    {
                        ptr[k * elem_num + out_loc] = index[k];
                    }
                });
        });
    });

    return result;
}

} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
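nonzero runs a single-block prefix scan over a 0/1 "is nonzero" mask of the input: the inclusive running count gives each nonzero element its slot in the packed output, and its multi-index is then written column by column at slot − 1. A host-only sketch of that indexing trick is below.

// Host illustration of the prefix-sum indexing used in nonzero().
#include <cstddef>
#include <iostream>
#include <vector>

int main()
{
    std::vector<float> data = {0.0f, 3.0f, 0.0f, 5.0f, 7.0f};

    // Inclusive prefix sum of the mask {0,1,0,1,1} -> {0,1,1,2,3}
    std::vector<int> slot(data.size());
    int running = 0;
    for(std::size_t j = 0; j < data.size(); ++j)
    {
        running += (data[j] != 0.0f) ? 1 : 0;
        slot[j] = running;
    }

    // Each nonzero element j lands at output position slot[j] - 1.
    std::vector<std::size_t> positions(static_cast<std::size_t>(running));
    for(std::size_t j = 0; j < data.size(); ++j)
        if(data[j] != 0.0f)
            positions[slot[j] - 1] = j;

    for(auto p : positions)
        std::cout << p << " "; // 1 3 4
    std::cout << "\n";
}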
src/targets/gpu/device/prefix_scan_sum.cpp (new file, 0 → 100644)

#include <migraphx/gpu/device/prefix_scan_sum.hpp>
#include <migraphx/gpu/device/scan.hpp>
#include <migraphx/gpu/device/reduce_ops.hpp>
#include <migraphx/gpu/device/types.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {

void prefix_scan_sum(hipStream_t stream, const argument& result, const argument& arg, int32_t axis)
{
    const index_int block_size = 256;
    const index_int n          = arg.get_shape().lens()[axis];
    auto rlens                 = result.get_shape().lens();
    rlens[axis]                = 1;
    hip_visit_all(result, arg, result.get_shape().with_lens(rlens))(
        [=](auto output, auto input, auto rshape) {
            gs_launch(stream, rshape.elements() * block_size, block_size)(
                [=](auto i, auto idx) __device__ {
                    const auto ridx  = rshape.multi(i / block_size);
                    auto compute_idx = [&](auto j) {
                        auto k  = ridx;
                        k[axis] = j;
                        return k;
                    };
                    block_scan<block_size>(idx,
                                           sum{},
                                           0,
                                           n,
                                           [&](auto j) { return input[compute_idx(j)]; },
                                           [&](auto j, auto x) { output[compute_idx(j)] = x; });
                });
        });
}

} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
src/targets/gpu/device/reverse.cpp (new file, 0 → 100644)

#include "migraphx/gpu/device/visit.hpp"
#include <migraphx/shape.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/gpu/device/reverse.hpp>
#include <migraphx/gpu/device/tensor.hpp>
#include <migraphx/gpu/device/launch.hpp>
#include <migraphx/gpu/device/types.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {

argument
reverse(hipStream_t stream, argument result, argument arg1, const std::vector<int64_t>& axes)
{
    auto s = arg1.get_shape();
    // auto lens = s.lens();
    std::vector<std::size_t> axis_len(axes.begin(), axes.end());
    shape sa{shape::float_type, axis_len};
    std::size_t nelements = s.elements();

    visit_all(result, arg1)([&](auto output1, auto input1) {
        hip_visit_views(output1, input1, s)([&](auto output, auto input, auto hs) {
            hip_visit_views(sa)([&](auto daxes) {
                auto lens = hs.lens;
                gs_launch(stream, nelements)([=](auto i) __device__ {
                    auto idx    = hs.multi(i);
                    auto in_idx = idx;
                    for(auto axis : daxes.lens)
                        in_idx[axis] = lens[axis] - 1 - idx[axis];
                    output[idx] = input[in_idx];
                });
            });
        });
    });

    return result;
}

} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
src/targets/gpu/device/scatter.cpp (new file, 0 → 100644)

#include <migraphx/shape.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/gpu/device/scatter.hpp>
#include <migraphx/gpu/device/tensor.hpp>
#include <migraphx/gpu/device/launch.hpp>
#include <migraphx/gpu/device/types.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {

argument scatter(
    hipStream_t stream, argument result, argument arg0, argument arg1, argument arg2, int64_t axis)
{
    auto ds            = arg0.get_shape();
    auto inds          = arg1.get_shape();
    auto axis_dim_size = ds.lens()[axis];
    hip_visit_all(result, arg0, inds)([&](auto output, auto data, auto s1) {
        auto* output_ptr     = device_cast(output.data());
        const auto* data_ptr = device_cast(data.data());
        gs_launch(stream, ds.elements())([=](auto i) __device__ { output_ptr[i] = data_ptr[i]; });

        hip_visit_all(arg1, arg2)([&](auto indices, auto update) {
            const auto* upd_ptr     = device_cast(update.data());
            const auto* indices_ptr = device_cast(indices.data());
            gs_launch(stream, inds.elements())([=](auto i) __device__ {
                auto out_idx    = s1.multi(i);
                auto index      = indices_ptr[i];
                index           = index < 0 ? index + axis_dim_size : index;
                out_idx[axis]   = index;
                output[out_idx] = upd_ptr[i];
            });
        });
    });

    return result;
}

} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
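scatter first copies the data tensor into the output and then overwrites the positions selected by the index tensor along the chosen axis, normalizing negative indices by the axis length. A host-only sketch of the 1-D, axis-0 case follows.

// Host illustration of the scatter rule above for a 1-D tensor and axis 0.
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

int main()
{
    std::vector<float> data      = {0.f, 0.f, 0.f, 0.f, 0.f};
    std::vector<int64_t> indices = {1, 3, -1}; // -1 wraps to 4
    std::vector<float> updates   = {10.f, 30.f, 40.f};

    std::vector<float> output = data; // first kernel: copy data into the output
    const int64_t axis_dim    = static_cast<int64_t>(data.size());
    for(std::size_t i = 0; i < indices.size(); ++i)
    {
        int64_t index = indices[i];
        index         = index < 0 ? index + axis_dim : index; // same normalization as the kernel
        output[static_cast<std::size_t>(index)] = updates[i];
    }

    for(float v : output)
        std::cout << v << " "; // 0 10 0 30 40
    std::cout << "\n";
}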
src/targets/gpu/device/topk.cpp (new file, 0 → 100644)

#include <migraphx/shape.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/gpu/device/topk.hpp>
#include <migraphx/gpu/device/tensor.hpp>
#include <migraphx/gpu/device/launch.hpp>
#include <migraphx/gpu/device/types.hpp>
#include <migraphx/gpu/device/visit.hpp>
#include <migraphx/ranges.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {

template <class T, class Index, class Compare>
struct hip_heap_vector
{
    MIGRAPHX_DEVICE_CONSTEXPR hip_heap_vector(T* val, index_int n, Index v_idx, Compare comp)
        : data(val), size(n), data_index(v_idx), compare(comp)
    {
        make_heap(size);
    }

    MIGRAPHX_DEVICE_CONSTEXPR void try_push(const T val)
    {
        if(compare(val, data[data_index(0)]))
            return;

        pop_heap(size - 1);
        data[data_index(size - 1)] = val;
        push_heap(size - 1);
    }

    MIGRAPHX_DEVICE_CONSTEXPR void sort() { sort_heap(size); }

    private:
    MIGRAPHX_DEVICE_CONSTEXPR inline static void swap(T& v1, T& v2)
    {
        T v = v1;
        v1  = v2;
        v2  = v;
    }

    MIGRAPHX_DEVICE_CONSTEXPR inline void heapify_down(index_int n, index_int index)
    {
        while(index < n)
        {
            auto pre_index = index;
            index_int l    = 2 * index + 1;
            index_int r    = 2 * index + 2;

            if(l < n && compare(data[data_index(l)], data[data_index(index)]))
            {
                index = l;
            }

            if(r < n && compare(data[data_index(r)], data[data_index(index)]))
            {
                index = r;
                if(compare(data[data_index(l)], data[data_index(r)]))
                {
                    index = l;
                }
            }

            if(index == pre_index)
            {
                break;
            }

            swap(data[data_index(index)], data[data_index(pre_index)]);
        }
    }

    MIGRAPHX_DEVICE_CONSTEXPR inline void heapify_up(index_int index)
    {
        while(index > 0)
        {
            auto parent_idx = (index - 1) / 2;

            if(not compare(data[data_index(index)], data[data_index(parent_idx)]))
            {
                break;
            }

            swap(data[data_index(index)], data[data_index(parent_idx)]);
            index = parent_idx;
        }
    }

    MIGRAPHX_DEVICE_CONSTEXPR inline void make_heap(index_int n)
    {
        for(int j = n / 2 - 1; j >= 0; --j)
        {
            heapify_down(n, j);
        }
    }

    MIGRAPHX_DEVICE_CONSTEXPR inline void push_heap(index_int loc) { heapify_up(loc); }

    MIGRAPHX_DEVICE_CONSTEXPR inline void pop_heap(index_int loc)
    {
        swap(data[data_index(0)], data[data_index(loc)]);
        heapify_down(loc, 0);
    }

    MIGRAPHX_DEVICE_CONSTEXPR inline void sort_heap(index_int n)
    {
        for(int j = n - 1; j > 0; --j)
        {
            swap(data[data_index(0)], data[data_index(j)]);
            heapify_down(j, 0);
        }
    }

    T* data = nullptr;
    index_int size;
    Index data_index;
    Compare compare;
};
template <class T, class Index, class Compare>
__device__ hip_heap_vector<T, Index, Compare>
make_heap(T* data, index_int n, Index idx, Compare compare)
{
    return {data, n, idx, compare};
}

template <class Compare>
std::vector<argument> topk(hipStream_t stream,
                           const argument& val_res,
                           const argument& ind_res,
                           const argument& arg,
                           int64_t k,
                           int64_t axis,
                           Compare compare)
{
    auto in_s       = arg.get_shape();
    auto in_lens    = in_s.lens();
    auto out_s      = val_res.get_shape();
    auto axis_dim   = in_s.lens()[axis];
    auto comp_lens  = in_lens;
    comp_lens[axis] = 1;
    shape comp_s{in_s.type(), comp_lens};
    std::size_t elem_num = comp_s.elements();

    hip_visit_all(val_res, arg, out_s, in_s, comp_s)(
        [&](auto out_val, auto input, auto oss, auto iss, auto css) {
            auto* data      = device_cast(input.data());
            auto* out       = device_cast(out_val.data());
            auto* const ind = ind_res.cast<int64_t>();
            gs_launch(stream, elem_num)([=](auto i) __device__ {
                auto idx = css.multi(i);

                auto in_idx = [&](int ii) {
                    auto iidx  = idx;
                    iidx[axis] = ii;
                    return iss.index(iidx);
                };

                auto out_idx = [&](int ii) {
                    auto iidx  = idx;
                    iidx[axis] = ii;
                    return oss.index(iidx);
                };

                auto data_compare = [=](auto ii, auto jj) {
                    return compare(data[in_idx(ii)], data[in_idx(jj)]);
                };

                for(int j = 0; j < k; ++j)
                {
                    ind[out_idx(j)] = j;
                }

                auto hp = make_heap(ind, k, out_idx, data_compare);
                for(int j = k; j < axis_dim; ++j)
                {
                    hp.try_push(j);
                }
                hp.sort();

                for(int j = 0; j < k; ++j)
                {
                    out[out_idx(j)] = data[in_idx(ind[out_idx(j)])];
                }
            });
        });

    return {val_res, ind_res};
}
argument topk_largest(hipStream_t stream,
                      const argument& val_res,
                      const argument& ind_res,
                      const argument& arg,
                      int64_t k,
                      int64_t axis)
{
    return {topk(stream, val_res, ind_res, arg, k, axis, std::less<>{})};
}

argument topk_smallest(hipStream_t stream,
                       const argument& val_res,
                       const argument& ind_res,
                       const argument& arg,
                       int64_t k,
                       int64_t axis)
{
    return {topk(stream, val_res, ind_res, arg, k, axis, std::greater<>{})};
}

} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
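topk keeps, per output slice, a k-element heap of candidate indices: the heap is seeded with the first k positions, try_push replaces the root whenever a later element beats it, and sort_heap orders the survivors before values and indices are written out. A host-only sketch of the same bounded-heap pattern, using std::priority_queue in place of hip_heap_vector, follows; topk_largest_host is an illustrative name.

// Host illustration of the bounded-heap top-k pattern used above.
#include <algorithm>
#include <cstddef>
#include <functional>
#include <iostream>
#include <queue>
#include <vector>

std::vector<float> topk_largest_host(const std::vector<float>& data, std::size_t k)
{
    // Min-heap over the k largest values: the root is the weakest kept value.
    std::priority_queue<float, std::vector<float>, std::greater<float>> heap;
    for(float v : data)
    {
        if(heap.size() < k)
            heap.push(v);
        else if(v > heap.top()) // same role as hip_heap_vector::try_push
        {
            heap.pop();
            heap.push(v);
        }
    }
    std::vector<float> out;
    while(!heap.empty())
    {
        out.push_back(heap.top());
        heap.pop();
    }
    std::reverse(out.begin(), out.end()); // largest first
    return out;
}

int main()
{
    for(float v : topk_largest_host({3.0f, 9.0f, 1.0f, 7.0f, 5.0f}, 3))
        std::cout << v << " "; // 9 7 5
    std::cout << "\n";
}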
src/targets/gpu/device/where.cpp (new file, 0 → 100644)

#include <migraphx/gpu/device/where.hpp>
#include <migraphx/gpu/device/tensor.hpp>
#include <migraphx/gpu/device/types.hpp>
#include <migraphx/gpu/device/launch.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {

template <class Shape>
constexpr auto get_rank(const Shape&)
{
    return decltype(typename Shape::hip_index{}.size()){};
}

void where(hipStream_t stream,
           const argument& result,
           const argument& arg0,
           const argument& arg1,
           const argument& arg2)
{
    hip_visit_all(result, arg1, arg2)([&](auto output, auto x, auto y) {
        hip_visit_all(arg0)([&](auto cond) {
            if constexpr(get_rank(cond.get_shape()) == get_rank(output.get_shape()))
            {
                gs_launch(stream, arg1.get_shape().elements())([=](auto idx) __device__ {
                    auto i    = output.get_shape().multi(idx);
                    output[i] = cond[i] ? x[i] : y[i];
                });
            }
        });
    });
}

} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
src/targets/gpu/device_name.cpp (new file, 0 → 100755)

#include <migraphx/gpu/device_name.hpp>
#include <migraphx/errors.hpp>
#include <migraphx/rank.hpp>
#include <migraphx/stringutils.hpp>
#include <hip/hip_runtime_api.h>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {

template <class HipDeviceProp>
std::string get_arch_name(rank<0>, const HipDeviceProp& props)
{
    return "gfx" + std::to_string(props.gcnArch);
}

template <class HipDeviceProp>
auto get_arch_name(rank<1>, const HipDeviceProp& props) -> decltype(std::string(props.gcnArchName))
{
    return std::string(props.gcnArchName);
}

int get_device_id()
{
    int device;
    auto status = hipGetDevice(&device);
    if(status != hipSuccess)
        MIGRAPHX_THROW("No device");
    return device;
}

std::string get_device_name()
{
    hipDeviceProp_t props{};
    auto status = hipGetDeviceProperties(&props, get_device_id());
    if(status != hipSuccess)
        MIGRAPHX_THROW("Failed to get device properties");
    return get_arch_name(rank<1>{}, props);
}

} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
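get_arch_name uses rank-based overload selection: the rank<1> overload is only viable when the HIP property struct has a gcnArchName member (checked through its trailing decltype), otherwise resolution falls back to the rank<0> overload that formats gcnArch. A standalone sketch of the idiom follows; old_props and new_props are illustrative stand-ins for the HIP structs.

// Standalone sketch of the rank-dispatch idiom used in get_arch_name().
#include <iostream>
#include <string>

template <int N>
struct rank : rank<N - 1>
{
};

template <>
struct rank<0>
{
};

struct old_props
{
    int gcnArch = 906;
};

struct new_props
{
    const char* gcnArchName = "gfx90a:sramecc+:xnack-";
};

template <class Props>
std::string arch_name(rank<0>, const Props& p)
{
    return "gfx" + std::to_string(p.gcnArch);
}

// Only participates in overload resolution when Props has gcnArchName.
template <class Props>
auto arch_name(rank<1>, const Props& p) -> decltype(std::string(p.gcnArchName))
{
    return std::string(p.gcnArchName);
}

int main()
{
    std::cout << arch_name(rank<1>{}, old_props{}) << "\n"; // falls back: gfx906
    std::cout << arch_name(rank<1>{}, new_props{}) << "\n"; // preferred: gfx90a:sramecc+:xnack-
}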
src/targets/gpu/driver/CMakeLists.txt (new file, 0 → 100755)

add_executable(gpu-driver
    action.cpp
    compile_pointwise.cpp
    main.cpp
    parser.cpp
    perf.cpp
    run_op.cpp
)
target_include_directories(gpu-driver PRIVATE include)
target_link_libraries(gpu-driver PRIVATE migraphx_gpu)
src/targets/gpu/driver/action.cpp (new file, 0 → 100644)

#include <migraphx/gpu/driver/action.hpp>
#include <migraphx/errors.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace driver {

auto& action_map()
{
    static std::unordered_map<std::string, action_function> m;
    return m;
}

action_function get_action(const std::string& name)
{
    if(action_map().count(name) == 0)
        MIGRAPHX_THROW("Missing action: " + name);
    return action_map().at(name);
}

void register_action(const std::string& name, const action_function& a)
{
    action_map()[name] = a;
}

} // namespace driver
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
src/targets/gpu/driver/compile_pointwise.cpp (new file, 0 → 100755)

#include <migraphx/gpu/driver/action.hpp>
#include <migraphx/gpu/driver/perf.hpp>
#include <migraphx/gpu/compile_pointwise.hpp>
#include <migraphx/gpu/context.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace driver {

struct compile_pointwise : action<compile_pointwise>
{
    static void apply(const parser& p, const value& v)
    {
        context ctx;
        auto inputs = p.parse_shapes(v.at("inputs"));
        auto op     = gpu::compile_pointwise(ctx, inputs, v.at("lambda").to<std::string>());
        double t    = time_op(ctx, op, inputs, p.get(v, "iterations", 100));
        std::cout << op << ": " << t << "ms" << std::endl;
    }
};

} // namespace driver
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
src/targets/gpu/driver/include/migraphx/gpu/driver/action.hpp (new file, 0 → 100644)

#ifndef MIGRAPHX_GUARD_GPU_DRIVER_ACTION_HPP
#define MIGRAPHX_GUARD_GPU_DRIVER_ACTION_HPP

#include <migraphx/config.hpp>
#include <migraphx/auto_register.hpp>
#include <migraphx/type_name.hpp>
#include <migraphx/gpu/driver/parser.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace driver {

using action_function = std::function<void(const parser&, const value&)>;

action_function get_action(const std::string& name);

void register_action(const std::string& name, const action_function& a);

struct auto_register_action
{
    template <class T>
    static void apply()
    {
        auto name = get_type_name<T>();
        register_action(name.substr(name.rfind("::") + 2),
                        [](auto&&... xs) { T::apply(std::forward<decltype(xs)>(xs)...); });
    }
};

template <class T>
using action = auto_register<auto_register_action, T>;

} // namespace driver
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_GPU_DRIVER_ACTION_HPP
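Deriving from action<T> registers T::apply in a name-to-function map at static-initialization time (via auto_register), which is how gpu-driver dispatches the actions named in its input. Below is a simplified standalone sketch of that registry pattern; it registers through a namespace-scope initializer rather than auto_register, and all names (registry, echo_action) are illustrative.

// Standalone sketch of a string -> action registry with self-registration.
#include <functional>
#include <iostream>
#include <stdexcept>
#include <string>
#include <unordered_map>

using action_function = std::function<void(const std::string&)>;

std::unordered_map<std::string, action_function>& registry()
{
    static std::unordered_map<std::string, action_function> m;
    return m;
}

void register_action(const std::string& name, const action_function& a) { registry()[name] = a; }

action_function get_action(const std::string& name)
{
    if(registry().count(name) == 0)
        throw std::runtime_error("Missing action: " + name);
    return registry().at(name);
}

struct echo_action
{
    static void apply(const std::string& arg) { std::cout << "echo: " << arg << "\n"; }
};

// Runs at static-initialization time, standing in for auto_register_action.
const bool echo_registered = [] {
    register_action("echo_action", [](const std::string& arg) { echo_action::apply(arg); });
    return true;
}();

int main()
{
    (void)echo_registered;
    get_action("echo_action")("hello"); // dispatch by name, like get_action() above
}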
src/targets/gpu/driver/include/migraphx/gpu/driver/parser.hpp (new file, 0 → 100644)

#ifndef MIGRAPHX_GUARD_GPU_DRIVER_PARSER_HPP
#define MIGRAPHX_GUARD_GPU_DRIVER_PARSER_HPP

#include <migraphx/value.hpp>
#include <migraphx/shape.hpp>
#include <unordered_map>
#include <functional>
#include <vector>
#include <string>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace driver {

[[noreturn]] void error(const std::string& msg);

struct parser
{
    parser() = default;

    template <class T>
    T get(const value& v, const std::string& key, const T& default_value) const
    {
        return v.get(key, settings.get(key, default_value));
    }

    shape parse_shape(const value& v) const;
    std::vector<shape> parse_shapes(const value& v) const;
    void load_settings(const value& v);
    static void process(const value& v);

    private:
    value settings = value::object{};
};

} // namespace driver
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_GPU_DRIVER_PARSER_HPP
src/targets/gpu/driver/include/migraphx/gpu/driver/perf.hpp (new file, 0 → 100755)

#ifndef MIGRAPHX_GUARD_GPU_DRIVER_PERF_HPP
#define MIGRAPHX_GUARD_GPU_DRIVER_PERF_HPP

#include <migraphx/config.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/operation.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace driver {

double time_op(context& ctx, operation op, const std::vector<shape>& inputs, int n = 100);

} // namespace driver
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_GPU_DRIVER_PERF_HPP
src/targets/gpu/driver/main.cpp (new file, 0 → 100644)

#include <migraphx/gpu/driver/parser.hpp>
#include <migraphx/json.hpp>
#include <migraphx/convert_to_json.hpp>
#include <migraphx/file_buffer.hpp>
#include <iostream>

using namespace migraphx;              // NOLINT
using namespace migraphx::gpu;         // NOLINT
using namespace migraphx::gpu::driver; // NOLINT

int main(int argc, char const* argv[])
{
    std::vector<std::string> args(argv, argv + argc);
    if(args.size() < 2)
    {
        std::cout << "Usage: gpu-driver <input-file>" << std::endl;
        std::abort();
    }
    auto v = from_json_string(convert_to_json(read_string(args[1])));
    parser::process(v);
}