Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
30c49503
Commit
30c49503
authored
Mar 23, 2023
by
Khalique Ahmed
Browse files
manual merge
parents
870a396b
09aaa63e
Changes
202
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
532 additions
and
143 deletions
+532
-143
src/targets/gpu/include/migraphx/gpu/name.hpp
src/targets/gpu/include/migraphx/gpu/name.hpp
+0
-1
src/targets/gpu/include/migraphx/gpu/prefuse_ops.hpp
src/targets/gpu/include/migraphx/gpu/prefuse_ops.hpp
+2
-2
src/targets/gpu/include/migraphx/gpu/target.hpp
src/targets/gpu/include/migraphx/gpu/target.hpp
+0
-1
src/targets/gpu/jit/reduce.cpp
src/targets/gpu/jit/reduce.cpp
+7
-7
src/targets/gpu/kernels/include/migraphx/kernels/array.hpp
src/targets/gpu/kernels/include/migraphx/kernels/array.hpp
+1
-1
src/targets/gpu/kernels/include/migraphx/kernels/debug.hpp
src/targets/gpu/kernels/include/migraphx/kernels/debug.hpp
+4
-0
src/targets/gpu/kernels/include/migraphx/kernels/dpp.hpp
src/targets/gpu/kernels/include/migraphx/kernels/dpp.hpp
+1
-1
src/targets/gpu/kernels/include/migraphx/kernels/gathernd.hpp
...targets/gpu/kernels/include/migraphx/kernels/gathernd.hpp
+13
-19
src/targets/gpu/kernels/include/migraphx/kernels/hip.hpp
src/targets/gpu/kernels/include/migraphx/kernels/hip.hpp
+8
-5
src/targets/gpu/kernels/include/migraphx/kernels/index.hpp
src/targets/gpu/kernels/include/migraphx/kernels/index.hpp
+71
-12
src/targets/gpu/kernels/include/migraphx/kernels/layernorm.hpp
...argets/gpu/kernels/include/migraphx/kernels/layernorm.hpp
+18
-11
src/targets/gpu/kernels/include/migraphx/kernels/math.hpp
src/targets/gpu/kernels/include/migraphx/kernels/math.hpp
+2
-3
src/targets/gpu/kernels/include/migraphx/kernels/ops.hpp
src/targets/gpu/kernels/include/migraphx/kernels/ops.hpp
+12
-3
src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp
src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp
+292
-51
src/targets/gpu/kernels/include/migraphx/kernels/shape.hpp
src/targets/gpu/kernels/include/migraphx/kernels/shape.hpp
+0
-8
src/targets/gpu/kernels/include/migraphx/kernels/softmax.hpp
src/targets/gpu/kernels/include/migraphx/kernels/softmax.hpp
+9
-7
src/targets/gpu/kernels/include/migraphx/kernels/type_traits.hpp
...gets/gpu/kernels/include/migraphx/kernels/type_traits.hpp
+24
-0
src/targets/gpu/kernels/include/migraphx/kernels/types.hpp
src/targets/gpu/kernels/include/migraphx/kernels/types.hpp
+39
-2
src/targets/gpu/lowering.cpp
src/targets/gpu/lowering.cpp
+16
-2
src/targets/gpu/prefuse_ops.cpp
src/targets/gpu/prefuse_ops.cpp
+13
-7
No files found.
src/targets/gpu/include/migraphx/gpu/name.hpp
View file @
30c49503
...
...
@@ -56,7 +56,6 @@ struct oper
return
name
.
substr
(
pos_ns
+
2
);
}
}
return
"unknown_operator_name"
;
}
};
...
...
src/targets/gpu/include/migraphx/gpu/prefuse_ops.hpp
View file @
30c49503
...
...
@@ -30,14 +30,14 @@
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
struct
module
;
struct
module
_pass_manager
;
namespace
gpu
{
struct
prefuse_ops
{
std
::
string
name
()
const
{
return
"gpu::prefuse_ops"
;
}
void
apply
(
module
&
m
)
const
;
void
apply
(
module
_pass_manager
&
mp
m
)
const
;
};
}
// namespace gpu
...
...
src/targets/gpu/include/migraphx/gpu/target.hpp
View file @
30c49503
...
...
@@ -37,7 +37,6 @@ struct target
std
::
string
name
()
const
;
std
::
vector
<
pass
>
get_passes
(
migraphx
::
context
&
gctx
,
const
compile_options
&
options
)
const
;
migraphx
::
context
get_context
()
const
;
argument
copy_to
(
const
argument
&
arg
)
const
;
argument
copy_from
(
const
argument
&
arg
)
const
;
argument
allocate
(
const
shape
&
s
)
const
;
...
...
src/targets/gpu/jit/reduce.cpp
View file @
30c49503
...
...
@@ -118,17 +118,17 @@ struct reduce_compiler : compiler<reduce_compiler>
options
.
virtual_inputs
=
reduce_dims
(
inputs
);
auto
faxis
=
find_fast_axis
({
options
.
virtual_inputs
.
front
()});
vectorize
vec
{};
// Vectorize if the axis is a reduction axis
if
(
options
.
virtual_inputs
.
back
().
lens
()[
faxis
]
==
1
)
{
vec
=
vectorize
::
elements
(
ctx
,
faxis
,
options
.
virtual_inputs
);
}
auto
relements
=
get_reduce_elements
(
options
.
virtual_inputs
)
/
vec
.
size
;
auto
nelements
=
options
.
virtual_inputs
.
back
().
elements
();
auto
algo
=
v
.
get
(
"algo"
,
get_reduce_algo
(
options
.
virtual_inputs
));
if
(
algo
==
"block"
)
{
// Vectorize if the axis is a reduction axis
if
(
options
.
virtual_inputs
.
back
().
lens
()[
faxis
]
==
1
)
vec
=
vectorize
::
elements
(
ctx
,
faxis
,
options
.
virtual_inputs
);
auto
relements
=
get_reduce_elements
(
options
.
virtual_inputs
)
/
vec
.
size
;
auto
block_size
=
compute_block_size
(
relements
,
256
);
if
(
relements
>=
block_size
*
256
)
algo
=
"block_large"
;
options
.
set_launch_params
(
v
,
compute_global_for
(
ctx
,
nelements
*
block_size
,
256
),
block_size
);
}
...
...
@@ -166,7 +166,7 @@ struct reduce_compiler : compiler<reduce_compiler>
auto
reduce_elements
=
get_reduce_elements
(
ins
->
inputs
());
auto
reduce_type
=
ins
->
inputs
().
front
()
->
get_shape
().
type
();
v
[
"reduction"
]
=
"op::sum{}"
;
std
::
string
mean
=
"op::mean
{
"
+
std
::
to_string
(
reduce_elements
)
+
"}"
;
std
::
string
mean
=
"op::mean
<
"
+
std
::
to_string
(
reduce_elements
)
+
"
>{
}"
;
// Use float accumulator when reduction size is too large for half
if
(
reduce_type
==
shape
::
half_type
and
reduce_elements
>
16384
)
v
[
"read"
]
=
"compose("
+
mean
+
", op::convert_to<float>{})"
;
...
...
src/targets/gpu/kernels/include/migraphx/kernels/array.hpp
View file @
30c49503
...
...
@@ -105,7 +105,7 @@ constexpr auto array_for_each(T& x, Ts&... xs)
}
else
{
using
vec_type
=
std
::
remove_reference_t
<
decltype
(
array2vec
(
x
))
>
;
using
vec_type
=
remove_reference_t
<
decltype
(
array2vec
(
x
))
>
;
f
(
array2vec
(
x
),
__builtin_convertvector
(
array2vec
(
xs
),
vec_type
)...);
}
}
...
...
src/targets/gpu/kernels/include/migraphx/kernels/debug.hpp
View file @
30c49503
...
...
@@ -178,5 +178,9 @@ MIGRAPHX_HIP_NORETURN inline __host__ __device__ void assert_fail(const source_l
#define MIGRAPHX_WARN(...)
#endif
#define MIGRAPHX_STATIC_ASSERT_FOR(...) \
static_assert(__VA_ARGS__); \
if constexpr(__VA_ARGS__)
}
// namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_DEBUG_HPP
src/targets/gpu/kernels/include/migraphx/kernels/dpp.hpp
View file @
30c49503
...
...
@@ -72,7 +72,7 @@ __device__ T dpp_mov(T& x)
}
return
output
.
data
;
}
#endif
#endif
// MIGRAPHX_HAS_DPP
}
// namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_DPP_HPP
src/targets/gpu/kernels/include/migraphx/kernels/gathernd.hpp
View file @
30c49503
...
...
@@ -26,7 +26,7 @@
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/algorithm.hpp>
#include <migraphx/kernels/ops.hpp>
namespace
migraphx
{
template
<
class
T
>
...
...
@@ -53,23 +53,17 @@ __device__ void gathernd(const T& data_t, const U& indices_t, const V& output_t,
auto
indices_shape_lens
=
indices_shape
.
lens
;
auto
data_shape_lens
=
data_shape
.
lens
;
auto
num_slice_dims
=
indices_shape_lens
.
back
();
std
::
size_t
num_slices
=
accumulate
(
indices_shape_lens
.
begin
(),
indices_shape_lens
.
end
()
-
1
,
1
,
std
::
multiplies
<
std
::
size_t
>
());
std
::
size_t
slice_size
=
accumulate
(
data_shape_lens
.
begin
()
+
num_slice_dims
+
batch_dims
,
std
::
size_t
num_slices
=
accumulate
(
indices_shape_lens
.
begin
(),
indices_shape_lens
.
end
()
-
1
,
1
,
op
::
product
{});
std
::
size_t
slice_size
=
accumulate
(
data_shape_lens
.
begin
()
+
num_slice_dims
+
batch_dims
,
data_shape_lens
.
end
(),
1
,
std
::
multiplies
<
std
::
size_t
>
());
const
std
::
size_t
num_batches
=
accumulate
(
data_shape_lens
.
begin
(),
data_shape_lens
.
begin
()
+
batch_dims
,
1
,
std
::
multiplies
<
std
::
size_t
>
());
const
std
::
size_t
data_batch_stride
=
accumulate
(
data_shape_lens
.
begin
()
+
batch_dims
,
data_shape_lens
.
end
(),
1
,
std
::
multiplies
<
std
::
size_t
>
());
const
auto
num_slices_per_batch
=
num_slices
/
num_batches
;
op
::
product
{});
const
std
::
size_t
num_batches
=
accumulate
(
data_shape_lens
.
begin
(),
data_shape_lens
.
begin
()
+
batch_dims
,
1
,
op
::
product
{});
const
std
::
size_t
data_batch_stride
=
accumulate
(
data_shape_lens
.
begin
()
+
batch_dims
,
data_shape_lens
.
end
(),
1
,
op
::
product
{});
const
auto
num_slices_per_batch
=
num_slices
/
num_batches
;
ind
.
global_stride
(
output_shape
.
elements
(),
[
&
](
auto
i
)
{
const
auto
*
indices_ptr
=
indices_t
.
data
();
...
...
@@ -83,15 +77,15 @@ __device__ void gathernd(const T& data_t, const U& indices_t, const V& output_t,
int64_t
index
=
slice_indices
[
idx
];
const
std
::
size_t
input_dim_idx
=
batch_dims
+
idx
;
const
auto
input_dim
=
data_shape_lens
[
input_dim_idx
];
assert
(
index
>=
-
static_cast
<
int64_t
>
(
input_dim
)
and
index
<
static_cast
<
int64_t
>
(
input_dim
));
MIGRAPHX_ASSERT
(
index
>=
-
static_cast
<
int64_t
>
(
input_dim
)
and
index
<
static_cast
<
int64_t
>
(
input_dim
));
if
(
index
<
0
)
index
+=
input_dim
;
std
::
size_t
size_from_slice_dims
=
accumulate
(
data_shape_lens
.
begin
()
+
batch_dims
+
idx
+
1
,
data_shape_lens
.
begin
()
+
batch_dims
+
num_slice_dims
,
slice_size
,
std
::
multiplies
<
std
::
size_t
>
()
);
op
::
product
{}
);
relative_slice_offset
+=
index
*
size_from_slice_dims
;
}
...
...
src/targets/gpu/kernels/include/migraphx/kernels/hip.hpp
View file @
30c49503
...
...
@@ -24,11 +24,14 @@
#ifndef MIGRAPHX_GUARD_KERNELS_HIP_HPP
#define MIGRAPHX_GUARD_KERNELS_HIP_HPP
// Workaround macro redefinition issue with clang tidy
#if defined(__HIP_PLATFORM_HCC__) && defined(MIGRAPHX_USE_CLANG_TIDY)
#undef __HIP_PLATFORM_HCC__ // NOLINT
#endif
#ifndef MIGRAPHX_USE_HIPRTC
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include <hip/math_functions.h>
#include <hip/hip_math_constants.h>
#elif defined(MIGRAPHX_ENABLE_HIPRTC_WORKAROUNDS)
#include <hip/hip_common.h>
#include <hip/hip_math_constants.h>
#endif
#endif // MIGRAPHX_GUARD_KERNELS_HIP_HPP
src/targets/gpu/kernels/include/migraphx/kernels/index.hpp
View file @
30c49503
...
...
@@ -29,6 +29,7 @@
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/type_traits.hpp>
#include <migraphx/kernels/debug.hpp>
#include <migraphx/kernels/functional.hpp>
namespace
migraphx
{
...
...
@@ -135,42 +136,100 @@ struct index
return
(
n
-
_c
<
1
>
)
/
stride
+
_c
<
1
>
;
}
template
<
class
N
>
constexpr
auto
max_global_stride_iterations
(
N
n
)
const
{
return
max_stride_iterations
(
n
,
nglobal
());
}
template
<
class
N
>
constexpr
auto
max_local_stride_iterations
(
N
n
)
const
{
return
max_stride_iterations
(
n
,
nlocal
());
}
template
<
class
F
,
class
I
,
class
D
>
static
constexpr
auto
invoke_loop
(
F
f
,
I
i
,
D
d
)
->
decltype
(
f
(
i
,
d
))
{
return
f
(
i
,
d
);
}
template
<
class
F
,
class
I
,
class
D
>
static
constexpr
auto
invoke_loop
(
F
f
,
I
i
,
D
)
->
decltype
(
f
(
i
))
{
return
f
(
i
);
}
template
<
class
F
,
class
N
,
class
Stride
>
static
constexpr
void
for_stride_loop_unroll
(
index_int
start
,
N
n
,
Stride
stride
,
F
f
)
{
sequence
(
max_stride_iterations
(
n
,
stride
),
[
&
](
auto
...
ks
)
{
fold
([
&
](
auto
d
,
auto
k
)
{
auto
i
=
start
+
stride
*
k
;
if
(
i
<
n
)
invoke_loop
(
f
,
i
,
d
);
return
d
+
_c
<
1
>
;
})(
_c
<
0
>
,
ks
...);
});
}
template
<
class
F
,
class
N
,
class
Stride
>
static
constexpr
void
for_stride_loop
(
index_int
start
,
N
n
,
Stride
stride
,
F
f
)
{
index_int
k
=
0
;
for
(
index_int
i
=
start
;
i
<
n
;
i
+=
stride
)
{
invoke_loop
(
f
,
i
,
k
);
k
++
;
}
}
template
<
bool
Unroll
,
class
F
,
class
N
,
class
Stride
>
static
constexpr
void
for_stride
(
index_int
start
,
N
n
,
Stride
stride
,
F
f
)
{
MIGRAPHX_ASSERT
(
start
<
stride
);
if
constexpr
(
not
is_integral
<
N
>
{}
and
not
is_integral
<
Stride
>
{}
and
max_stride_iterations
(
n
,
stride
)
==
1
)
if
constexpr
(
not
is_integral
<
N
>
{}
and
not
is_integral
<
Stride
>
{})
{
if
constexpr
(
stride
>
n
)
if
constexpr
(
max_stride_iterations
(
n
,
stride
)
==
1
)
{
if
constexpr
(
stride
>
n
)
{
if
(
start
<
n
)
invoke_loop
(
f
,
start
,
_c
<
0
>
);
}
else
{
invoke_loop
(
f
,
start
,
_c
<
0
>
);
}
}
else
if
constexpr
(
Unroll
)
{
if
(
start
<
n
)
f
(
start
);
MIGRAPHX_STATIC_ASSERT_FOR
(
max_stride_iterations
(
n
,
stride
)
<
256
)
{
for_stride_loop_unroll
(
start
,
n
,
stride
,
f
);
}
}
else
{
f
(
start
);
f
or_stride_loop
(
start
,
n
,
stride
,
f
);
}
}
else
{
for
(
index_int
i
=
start
;
i
<
n
;
i
+=
stride
)
{
f
(
i
);
}
for_stride_loop
(
start
,
n
,
stride
,
f
);
}
}
template
<
class
F
,
class
N
>
__device__
void
global_stride
(
N
n
,
F
f
)
const
{
for_stride
(
global
,
n
,
nglobal
(),
f
);
for_stride
<
false
>
(
global
,
n
,
nglobal
(),
f
);
}
template
<
class
F
,
class
N
>
__device__
void
local_stride
(
N
n
,
F
f
)
const
{
for_stride
(
local
,
n
,
nlocal
(),
f
);
for_stride
<
true
>
(
local
,
n
,
nlocal
(),
f
);
}
};
...
...
src/targets/gpu/kernels/include/migraphx/kernels/layernorm.hpp
View file @
30c49503
...
...
@@ -46,28 +46,35 @@ template <index_int Axis,
__device__
void
generic_binary_layernorm
(
F
compute
,
BinOp
op
,
float
eps
,
Output
output
,
Input1
input1
,
Input2
input2
,
Inputs
...
inputs
)
{
using
block
=
reduce
::
auto_block
<
reduce
::
reduce_elements_with_axis
<
Input1
,
Axis
>
()
>
;
using
reduce_output
=
reduce
::
with_axis
<
Input1
,
Axis
>
;
reduce
::
block
::
run
<
reduce_output
>
([
&
](
auto
,
auto
r
)
{
using
value_type
=
typename
Input1
::
type
;
constexpr
auto
relements
=
r
.
template
elements
<
Input1
>();
auto
means
=
r
.
reduce
(
op
::
sum
{},
make_array
<
vec_type
<
value_type
>>
(
0
,
0
),
[
&
](
auto
x1
,
auto
x2
)
{
auto
x
=
op
(
x1
,
x2
);
return
make_array
(
x
,
x
*
x
)
*
vec_type
<
value_type
>
{
1.0
/
relements
};
})(
input1
,
input2
);
block
::
template
run
<
reduce_output
>([
&
](
auto
,
auto
r
)
{
auto
input
=
r
.
inner
([
&
](
auto
x1
,
auto
x2
)
{
return
op
(
x1
,
x2
);
})(
input1
,
input2
);
using
value_type
=
typename
Input1
::
type
;
constexpr
auto
relements
=
r
.
template
elements
<
Input1
>();
constexpr
auto
relements_r
=
vec_type
<
value_type
>
{
1.0
/
relements
};
auto
relements_rsqrt
=
sqrt
(
relements_r
);
auto
means
=
r
.
reduce
(
op
::
sum
{},
make_array
<
vec_type
<
value_type
>>
(
0
,
0
),
[
&
](
auto
x
)
{
auto
x_out
=
x
*
relements_r
;
// dividing x by sqrt(relements) before squaring allows computing higher values
// before overflow in low precision
auto
x2_sqrt
=
x
*
relements_rsqrt
;
return
make_array
(
x_out
,
x2_sqrt
*
x2_sqrt
);
})(
input
);
auto
mean_x
=
means
[
0
];
auto
mean_x2
=
means
[
1
];
auto
variance
=
mean_x2
-
(
mean_x
*
mean_x
);
value_type
eps_val
=
eps
;
// implicit conversion for eps
r
.
inner
([
&
](
auto
&
y
,
auto
x1
,
auto
x2
,
auto
...
xs
)
{
auto
x
=
op
(
x1
,
x2
);
r
.
inner
([
&
](
auto
&
y
,
auto
x
,
auto
...
xs
)
{
auto
m
=
x
-
mean_x
;
// m * rsqrt(mean(m ^ 2) + epsilon)
y
=
compute
(
m
*
rsqrt
(
variance
+
eps_val
),
xs
...);
})(
output
,
input
1
,
input2
,
inputs
...);
})(
output
,
input
,
inputs
...);
});
}
...
...
src/targets/gpu/kernels/include/migraphx/kernels/math.hpp
View file @
30c49503
...
...
@@ -28,8 +28,7 @@
#include <migraphx/kernels/vec.hpp>
#include <migraphx/kernels/functional.hpp>
#include <migraphx/kernels/type_traits.hpp>
#include <hip/hip_fp16.h>
#include <hip/math_functions.h>
#include <migraphx/kernels/hip.hpp>
namespace
migraphx
{
...
...
@@ -222,7 +221,7 @@ constexpr auto min(const T& a, const U& b)
template
<
class
T
,
MIGRAPHX_REQUIRES
(
is_same
<
vec_type
<
T
>,
half
>
{})
>
constexpr
T
sin
(
T
x
)
{
constexpr
const
T
shift
=
M_PI_2
;
constexpr
const
T
shift
=
HIP_PIO2_F
;
return
migraphx
::
cos
(
shift
-
x
);
}
...
...
src/targets/gpu/kernels/include/migraphx/kernels/ops.hpp
View file @
30c49503
...
...
@@ -66,13 +66,22 @@ struct convert_to
}
};
template
<
index_int
N
>
struct
mean
{
index_int
item_num
=
1
;
template
<
class
T
>
MIGRAPHX_DEVICE_CONSTEXPR
auto
operator
()(
T
x
)
const
MIGRAPHX_DEVICE_CONSTEXPR
T
operator
()(
T
x
)
const
{
return
x
/
static_cast
<
T
>
(
item_num
);
using
type
=
vec_type
<
T
>
;
if
constexpr
(
is_floating_point
<
type
>
{})
{
constexpr
type
d
=
1.0
/
N
;
return
x
*
d
;
}
else
{
return
x
/
static_cast
<
type
>
(
N
);
}
}
};
...
...
src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp
View file @
30c49503
...
...
@@ -103,10 +103,10 @@ __device__ auto block_reduce(index idx, Op op, T init, Index n, F f)
#else
constexpr
index_int
lanes_per_thread
=
64
;
#endif
using
type
=
decltype
(
f
(
0
));
using
type
=
decltype
(
index
::
invoke_loop
(
f
,
0
,
_c
<
0
>
));
__shared__
type
buffer
[
idx
.
max_nlocal
()
/
lanes_per_thread
];
type
x
=
init
;
idx
.
local_stride
(
n
,
[
&
](
auto
i
)
{
x
=
op
(
x
,
f
(
i
));
});
idx
.
local_stride
(
n
,
[
&
](
auto
i
,
auto
d
)
{
x
=
op
(
x
,
index
::
invoke_loop
(
f
,
i
,
d
));
});
dpp_reduce
(
x
,
op
);
const
auto
ldsidx
=
idx
.
local
/
lanes_per_thread
;
...
...
@@ -128,10 +128,10 @@ template <class Op, class T, class Index, class F>
__device__
auto
block_reduce
(
index
idx
,
Op
op
,
T
init
,
Index
n
,
F
f
)
{
MIGRAPHX_ASSERT
(
idx
.
max_nlocal
()
==
idx
.
nlocal
());
using
type
=
decltype
(
f
(
0
));
using
type
=
decltype
(
index
::
invoke_loop
(
f
,
0
,
_c
<
0
>
));
__shared__
type
buffer
[
idx
.
max_nlocal
()];
type
x
=
init
;
idx
.
local_stride
(
n
,
[
&
](
auto
i
)
{
x
=
op
(
x
,
f
(
i
));
});
idx
.
local_stride
(
n
,
[
&
](
auto
i
,
auto
d
)
{
x
=
op
(
x
,
index
::
invoke_loop
(
f
,
i
,
d
));
});
buffer
[
idx
.
local
]
=
x
;
__syncthreads
();
...
...
@@ -167,6 +167,25 @@ constexpr auto reduce_slice(Input input, T i)
namespace
reduce
{
struct
inner_storage_tag
{
};
template
<
class
T
>
using
is_inner_storage
=
is_base_of
<
inner_storage_tag
,
remove_cv_t
<
remove_reference_t
<
T
>>>
;
template
<
class
R
,
class
F
>
struct
storage_access
:
F
{
using
type
=
R
;
};
template
<
class
R
,
class
F
>
constexpr
storage_access
<
R
,
F
>
make_storage_access
(
F
f
)
{
return
{{
f
}};
}
template
<
class
Slicer
,
class
F
>
constexpr
auto
sliced
(
Slicer
slicer
,
F
f
)
{
...
...
@@ -191,20 +210,140 @@ constexpr auto compute_reduce_axis()
template
<
class
Input
,
index_int
Axis
>
using
with_axis
=
decltype
(
compute_reduce_axis
<
Input
,
Axis
>
());
template
<
class
Derived
>
struct
reducer_base
{
template
<
class
T
>
__device__
auto
make_inner_slice
(
T
x
)
const
{
if
constexpr
(
is_inner_storage
<
T
>
{})
{
return
x
;
}
else
{
auto
&&
derived
=
static_cast
<
const
Derived
&>
(
*
this
);
auto
t
=
derived
.
slice
(
x
);
return
make_storage_access
<
typename
decltype
(
t
)
::
type
>
([
=
](
auto
i
,
auto
...)
->
auto
&
{
return
t
[
i
];
});
}
}
template
<
class
T
,
class
...
Ts
>
constexpr
auto
get_size
(
T
&&
x
,
[[
maybe_unused
]]
Ts
&&
...
xs
)
const
{
MIGRAPHX_ASSERT
(
get_size
(
x
)
==
get_size
(
xs
...));
return
get_size
(
x
);
}
template
<
class
T
,
class
...
Ts
>
constexpr
auto
get_size
(
T
&&
x
)
const
{
if
constexpr
(
is_inner_storage
<
T
>
{})
{
return
x
.
rsize
();
}
else
{
auto
&&
derived
=
static_cast
<
const
Derived
&>
(
*
this
);
auto
t
=
derived
.
slice
(
x
);
return
t
.
size
();
}
}
template
<
class
F
>
__device__
auto
inner_sliced
(
F
f
)
const
{
return
[
=
](
auto
&&
...
xs
)
{
return
f
(
get_size
(
xs
...),
make_inner_slice
(
xs
)...);
};
}
template
<
class
T
>
static
__device__
typename
T
::
type
&
decl_inner_storage
(
const
T
&
);
template
<
class
F
>
__device__
auto
inner
(
F
f
)
const
{
return
this
->
inner_sliced
([
=
](
auto
n
,
auto
&&
...
xs
)
{
using
result_type
=
decltype
(
f
(
decl_inner_storage
(
xs
)...));
auto
&&
derived
=
static_cast
<
const
Derived
&>
(
*
this
);
if
constexpr
(
is_void
<
result_type
>
{})
{
derived
.
inner_void_impl
(
f
,
n
,
xs
...);
}
else
{
return
derived
.
template
inner_impl
<
result_type
>(
f
,
n
,
xs
...);
}
});
}
template
<
class
Op
,
class
T
,
class
Read
>
__device__
auto
reduce
(
Op
op
,
T
init
,
Read
read
)
const
{
return
this
->
inner_sliced
([
=
](
auto
n
,
auto
&&
...
xs
)
{
auto
&&
derived
=
static_cast
<
const
Derived
&>
(
*
this
);
return
derived
.
reduce_impl
(
op
,
init
,
read
,
n
,
xs
...);
});
}
template
<
class
Op
,
class
T
>
__device__
auto
reduce
(
Op
op
,
T
init
)
const
{
return
this
->
reduce
(
op
,
init
,
op
::
id
{});
}
template
<
class
F
>
__device__
void
outer
(
F
f
)
const
{
f
();
}
template
<
class
Input
>
constexpr
auto
elements
()
const
{
auto
&&
derived
=
static_cast
<
const
Derived
&>
(
*
this
);
using
reduce_type
=
decltype
(
derived
.
slice
(
Input
{}));
using
value_type
=
typename
Input
::
type
;
constexpr
auto
relements
=
get_shape_c
<
reduce_type
>
{}.
elements
();
if
constexpr
(
vec_size
<
value_type
>
()
>
1
)
return
relements
*
vec_size
<
value_type
>
();
else
return
relements
;
}
};
struct
block
{
template
<
class
Slicer
>
struct
reducer
struct
reducer
:
reducer_base
<
reducer
<
Slicer
>>
{
index
idx
;
Slicer
slice
;
template
<
class
Op
,
class
T
,
class
Read
>
__device__
auto
reduce
(
Op
op
,
T
init
,
Read
read
)
const
template
<
class
T
,
index_int
N
,
class
Size
>
struct
inner_storage
:
inner_storage_tag
{
using
type
=
T
;
array
<
T
,
N
>
arr
;
constexpr
Size
rsize
()
const
{
return
{};
}
template
<
class
U
,
class
V
>
constexpr
auto
&
operator
()(
U
,
V
d
)
const
{
return
arr
[
d
];
}
template
<
class
U
,
class
V
>
constexpr
auto
&
operator
()(
U
,
V
d
)
{
return
arr
[
d
];
}
};
template
<
class
Op
,
class
T
,
class
Read
,
class
N
,
class
...
Ts
>
__device__
auto
reduce_impl
(
Op
op
,
T
init
,
Read
read
,
N
n
,
Ts
&&
...
xs
)
const
{
return
sliced
(
slice
,
[
=
](
auto
x
,
auto
...
xs
)
{
return
block_reduce
(
idx
,
op
,
init
,
x
.
get_shape
().
elements
(),
[
&
](
auto
j
)
{
return
vec_reduce
(
read
(
x
[
j
],
xs
[
j
]...),
op
);
});
return
block_reduce
(
idx
,
op
,
init
,
n
,
[
&
](
auto
j
,
auto
d
)
{
return
vec_reduce
(
read
(
xs
(
j
,
d
)...),
op
);
});
}
...
...
@@ -215,31 +354,99 @@ struct block
f
();
}
template
<
class
F
>
__device__
auto
inner
(
F
f
)
const
template
<
class
F
,
class
N
,
class
...
Ts
>
__device__
void
inner_void_impl
(
F
f
,
N
n
,
Ts
&&
...
xs
)
const
{
idx
.
local_stride
(
n
,
[
&
](
auto
j
,
auto
d
)
{
f
(
xs
(
j
,
d
)...);
});
}
template
<
class
R
,
class
F
,
class
N
,
class
...
Ts
>
__device__
auto
inner_impl
(
F
f
,
N
n
,
Ts
&&
...
xs
)
const
{
return
sliced
(
slice
,
[
=
](
auto
x
,
auto
...
xs
)
{
idx
.
local_stride
(
x
.
get_shape
().
elements
(),
[
&
](
auto
j
)
{
f
(
x
[
j
],
xs
[
j
]...);
});
using
max_iterations
=
decltype
(
idx
.
max_local_stride_iterations
(
n
));
inner_storage
<
R
,
max_iterations
{},
N
>
storage
;
idx
.
local_stride
(
n
,
[
&
](
auto
j
,
auto
d
)
{
storage
(
j
,
d
)
=
f
(
xs
(
j
,
d
)...);
});
return
storage
;
}
};
template
<
class
Slicer
>
static
__device__
auto
make
(
index
idx
,
Slicer
slicer
)
{
return
reducer
<
Slicer
>
{{},
idx
,
slicer
};
}
template
<
class
Output
,
class
F
>
static
__device__
void
run
(
F
f
)
{
auto
idx
=
make_index
();
constexpr
auto
nelements
=
get_shape_c
<
Output
>
{}.
elements
();
idx
.
global_stride
(
nelements
*
idx
.
nlocal
(),
[
&
](
auto
i
)
{
const
auto
out_idx
=
get_shape_c
<
Output
>
{}.
multi
(
i
/
idx
.
nlocal
());
f
(
out_idx
,
make
(
idx
,
[
&
](
auto
input
)
{
return
reduce_slice
<
Output
>
(
input
,
out_idx
);
}));
});
}
};
struct
block_large
{
template
<
class
Slicer
>
struct
reducer
:
reducer_base
<
reducer
<
Slicer
>>
{
index
idx
;
Slicer
slice
;
template
<
class
Size
,
class
F
>
struct
inner_storage
:
inner_storage_tag
{
using
type
=
remove_reference_t
<
decltype
(
declval
<
F
>
()(
0
,
_c
<
0
>
))
>
;
F
f
;
constexpr
Size
rsize
()
const
{
return
{};
}
template
<
class
U
,
class
V
>
constexpr
auto
operator
()(
U
j
,
V
d
)
const
{
return
f
(
j
,
d
);
}
};
template
<
class
Size
,
class
F
>
static
constexpr
inner_storage
<
Size
,
F
>
make_inner_storage
(
Size
,
F
f
)
{
return
{{},
{
f
}};
}
template
<
class
Op
,
class
T
,
class
Read
,
class
N
,
class
...
Ts
>
__device__
auto
reduce_impl
(
Op
op
,
T
init
,
Read
read
,
N
n
,
Ts
&&
...
xs
)
const
{
return
block_reduce
(
idx
,
op
,
init
,
index_int
{
n
},
[
&
](
auto
j
,
auto
d
)
{
return
vec_reduce
(
read
(
xs
(
j
,
d
)...),
op
);
});
}
template
<
class
Input
>
constexpr
auto
elements
(
)
const
template
<
class
F
>
__device__
void
outer
(
F
f
)
const
{
using
reduce_type
=
decltype
(
slice
(
Input
{}));
using
value_type
=
typename
Input
::
type
;
constexpr
auto
relements
=
get_shape_c
<
reduce_type
>
{}.
elements
();
if
constexpr
(
vec_size
<
value_type
>
()
>
1
)
return
relements
*
vec_size
<
value_type
>
();
else
return
relements
;
if
(
idx
.
local
==
0
)
f
();
}
template
<
class
F
,
class
N
,
class
...
Ts
>
__device__
void
inner_void_impl
(
F
f
,
N
n
,
Ts
&&
...
xs
)
const
{
idx
.
local_stride
(
index_int
{
n
},
[
&
](
auto
j
,
auto
d
)
{
f
(
xs
(
j
,
d
)...);
});
}
template
<
class
R
,
class
F
,
class
N
,
class
...
Ts
>
__device__
auto
inner_impl
(
F
f
,
N
n
,
Ts
&&
...
xs
)
const
{
return
make_inner_storage
(
n
,
[
=
](
auto
j
,
auto
d
)
{
return
f
(
xs
(
j
,
d
)...);
});
}
};
template
<
class
Slicer
>
static
__device__
auto
make
(
index
idx
,
Slicer
slicer
)
{
return
reducer
<
Slicer
>
{
idx
,
slicer
};
return
reducer
<
Slicer
>
{
{},
idx
,
slicer
};
}
template
<
class
Output
,
class
F
>
...
...
@@ -257,22 +464,40 @@ struct block
struct
lane
{
template
<
class
Slicer
>
struct
reducer
struct
reducer
:
reducer_base
<
reducer
<
Slicer
>>
{
index
idx
;
Slicer
slice
;
template
<
class
Op
,
class
T
,
class
Read
>
__device__
auto
reduce
(
Op
op
,
T
init
,
Read
read
)
const
template
<
class
Size
,
class
F
>
struct
inner_storage
:
inner_storage_tag
{
return
sliced
(
slice
,
[
=
](
auto
x
,
auto
...
xs
)
{
using
type
=
typename
decltype
(
x
)
::
type
;
type
r
=
init
;
for
(
index_int
j
=
0
;
j
<
x
.
get_shape
().
elements
();
j
++
)
{
r
=
op
(
r
,
read
(
x
[
j
],
xs
[
j
]...));
}
return
r
;
});
using
type
=
remove_reference_t
<
decltype
(
declval
<
F
>
()(
0
,
_c
<
0
>
))
>
;
F
f
;
constexpr
Size
rsize
()
const
{
return
{};
}
template
<
class
U
,
class
V
>
constexpr
auto
operator
()(
U
j
,
V
d
)
const
{
return
f
(
j
,
d
);
}
};
template
<
class
Size
,
class
F
>
static
constexpr
inner_storage
<
Size
,
F
>
make_inner_storage
(
Size
,
F
f
)
{
return
{{},
{
f
}};
}
template
<
class
Op
,
class
T
,
class
Read
,
class
N
,
class
U
,
class
...
Us
>
__device__
auto
reduce_impl
(
Op
op
,
T
init
,
Read
read
,
N
n
,
U
&&
x
,
Us
&&
...
xs
)
const
{
using
type
=
remove_reference_t
<
decltype
(
x
(
0
,
_c
<
0
>
))
>
;
type
r
=
init
;
for
(
index_int
j
=
0
;
j
<
n
;
j
++
)
{
r
=
op
(
r
,
read
(
x
(
j
,
_c
<
0
>
),
xs
(
j
,
_c
<
0
>
)...));
}
return
r
;
}
template
<
class
F
>
...
...
@@ -281,29 +506,25 @@ struct lane
f
();
}
template
<
class
F
>
__device__
auto
inner
(
F
f
)
const
template
<
class
F
,
class
N
,
class
...
Ts
>
__device__
void
inner
_void_impl
(
F
f
,
N
n
,
Ts
&&
...
xs
)
const
{
return
sliced
(
slice
,
[
=
](
auto
x
,
auto
...
xs
)
{
for
(
index_int
j
=
0
;
j
<
x
.
get_shape
().
elements
();
j
++
)
{
f
(
x
[
j
],
xs
[
j
]...);
}
});
for
(
index_int
j
=
0
;
j
<
n
;
j
++
)
{
f
(
xs
(
j
,
_c
<
0
>
)...);
}
}
template
<
class
Input
>
constexpr
auto
elements
(
)
const
template
<
class
R
,
class
F
,
class
N
,
class
...
Ts
>
__device__
auto
inner_impl
(
F
f
,
N
n
,
Ts
&&
...
xs
)
const
{
using
reduce_type
=
decltype
(
slice
(
Input
{}));
return
get_shape_c
<
reduce_type
>
{}.
elements
();
return
make_inner_storage
(
n
,
[
=
](
auto
j
,
auto
d
)
{
return
f
(
xs
(
j
,
d
)...);
});
}
};
template
<
class
Slicer
>
static
__device__
auto
make
(
index
idx
,
Slicer
slicer
)
{
return
reducer
<
Slicer
>
{
idx
,
slicer
};
return
reducer
<
Slicer
>
{
{},
idx
,
slicer
};
}
template
<
class
Output
,
class
F
>
...
...
@@ -318,6 +539,26 @@ struct lane
}
};
// TODO: Remove these in the future when they can be selected in the compiler class
template
<
index_int
RElements
>
constexpr
auto
pick_block
()
{
using
nlocal
=
decltype
(
index
{}.
max_nlocal
());
if
constexpr
(
RElements
<
nlocal
{}
*
256
)
return
block
{};
else
return
block_large
{};
}
template
<
index_int
RElements
>
using
auto_block
=
decltype
(
pick_block
<
RElements
>
());
template
<
class
Input
,
index_int
Axis
>
constexpr
auto
reduce_elements_with_axis
()
{
constexpr
auto
s
=
get_shape_c
<
Input
>
{};
return
s
.
lens
[
Axis
];
}
}
// namespace reduce
template
<
class
Algo
,
...
...
src/targets/gpu/kernels/include/migraphx/kernels/shape.hpp
View file @
30c49503
...
...
@@ -76,14 +76,6 @@ struct shape
constexpr
index_int
index
(
index_array
x
)
const
{
return
x
.
dot
(
strides
);
}
constexpr
index_int
index
(
std
::
initializer_list
<
index_int
>
x
)
const
{
index_int
idx
=
0
;
for
(
index_int
i
=
0
;
i
<
x
.
size
();
i
++
)
idx
+=
*
(
x
.
begin
()
+
i
)
*
strides
[
i
];
return
idx
;
}
constexpr
index_int
index
(
index_int
i
)
const
{
if
(
this
->
standard
())
...
...
src/targets/gpu/kernels/include/migraphx/kernels/softmax.hpp
View file @
30c49503
...
...
@@ -30,18 +30,20 @@
namespace
migraphx
{
template
<
index_int
Axis
,
class
Input
,
class
Output
>
__device__
void
softmax
(
Input
input
,
Output
output
)
__device__
void
softmax
(
Input
input
1
,
Output
output
)
{
reduce
::
block
::
run
<
reduce
::
with_axis
<
Input
,
Axis
>>
([
&
](
auto
,
auto
r
)
{
using
block
=
reduce
::
auto_block
<
reduce
::
reduce_elements_with_axis
<
Input
,
Axis
>
()
>
;
block
::
template
run
<
reduce
::
with_axis
<
Input
,
Axis
>
>
([
&
](
auto
,
auto
r
)
{
auto
input
=
r
.
inner
(
op
::
id
{})(
input1
);
#ifdef MIGRAPHX_USE_FAST_SOFTMAX
const
auto
c
=
vec_at
(
r
.
slice
(
input
)[
0
],
0
);
const
auto
c
=
vec_at
(
r
.
slice
(
input
1
)[
0
],
0
);
#else
const
auto
c
=
r
.
reduce
(
op
::
max
{},
lowest
{},
op
::
id
{})(
input
);
#endif
auto
batch_sum
=
r
.
reduce
(
op
::
sum
{},
0
,
[
&
](
auto
x
)
{
return
migraphx
::
convert
<
float
>
(
migraphx
::
exp
(
x
-
c
));
})(
input
);
r
.
inner
([
&
](
auto
&
y
,
auto
x
)
{
y
=
migraphx
::
exp
(
x
-
c
)
/
batch_sum
;
})(
output
,
input
);
auto
exp_in
=
r
.
inner
([
&
](
auto
x
)
{
return
migraphx
::
exp
(
x
-
c
);
})(
input
);
auto
batch_sum
=
r
.
reduce
(
op
::
sum
{},
0
,
[](
auto
x
)
{
return
migraphx
::
convert
<
float
>
(
x
);
})(
exp_in
);
r
.
inner
([
&
](
auto
&
y
,
auto
x
)
{
y
=
x
/
batch_sum
;
})(
output
,
exp_in
);
});
}
...
...
src/targets/gpu/kernels/include/migraphx/kernels/type_traits.hpp
View file @
30c49503
...
...
@@ -141,6 +141,25 @@ MIGRAPHX_BUILTIN_TYPE_TRAITN(is_constructible);
MIGRAPHX_BUILTIN_TYPE_TRAITN
(
is_nothrow_constructible
);
MIGRAPHX_BUILTIN_TYPE_TRAITN
(
is_trivially_constructible
);
template
<
class
T
>
struct
remove_cv
{
using
type
=
T
;
};
template
<
class
T
>
struct
remove_cv
<
const
T
>
:
remove_cv
<
T
>
{
};
template
<
class
T
>
struct
remove_cv
<
volatile
T
>
:
remove_cv
<
T
>
{
};
template
<
class
T
>
using
remove_cv_t
=
typename
remove_cv
<
T
>::
type
;
template
<
class
T
>
struct
remove_reference
{
...
...
@@ -168,6 +187,11 @@ struct add_pointer : type_identity<typename remove_reference<T>::type*>
template
<
class
T
>
using
add_pointer_t
=
typename
add_pointer
<
T
>::
type
;
template
<
class
T
>
struct
is_void
:
is_same
<
void
,
remove_cv_t
<
T
>>
{
};
template
<
class
...
Ts
>
struct
common_type
;
...
...
src/targets/gpu/kernels/include/migraphx/kernels/types.hpp
View file @
30c49503
...
...
@@ -28,8 +28,45 @@
namespace
migraphx
{
using
index_int
=
std
::
uint32_t
;
using
diff_int
=
std
::
int32_t
;
#if defined(MIGRAPHX_ENABLE_HIPRTC_WORKAROUNDS) and defined(MIGRAPHX_USE_HIPRTC)
using
int8_t
=
signed
char
;
using
uint8_t
=
unsigned
char
;
using
int16_t
=
signed
short
;
using
uint16_t
=
unsigned
short
;
using
int32_t
=
signed
int
;
using
uint32_t
=
unsigned
int
;
using
int64_t
=
signed
long
long
;
using
uint64_t
=
unsigned
long
long
;
#elif defined(MIGRAPHX_USE_HIPRTC)
using
int8_t
=
__hip_int8_t
;
using
uint8_t
=
__hip_uint8_t
;
using
int16_t
=
__hip_int16_t
;
using
uint16_t
=
__hip_uint16_t
;
using
int32_t
=
__hip_int32_t
;
using
uint32_t
=
__hip_uint32_t
;
using
int64_t
=
__hip_int64_t
;
using
uint64_t
=
__hip_uint64_t
;
#else
using
int8_t
=
std
::
int8_t
;
using
uint8_t
=
std
::
uint8_t
;
using
int16_t
=
std
::
int16_t
;
using
uint16_t
=
std
::
uint16_t
;
using
int32_t
=
std
::
int32_t
;
using
uint32_t
=
std
::
uint32_t
;
using
int64_t
=
std
::
int64_t
;
using
uint64_t
=
std
::
uint64_t
;
#endif // MIGRAPHX_USE_HIPRTC
using
index_int
=
uint32_t
;
using
diff_int
=
int32_t
;
static_assert
(
sizeof
(
int8_t
)
==
1
,
"int8_t must be 1 bytes"
);
static_assert
(
sizeof
(
uint8_t
)
==
1
,
"uint8_t must be 1 bytes"
);
static_assert
(
sizeof
(
int16_t
)
==
2
,
"int16_t must be 2 bytes"
);
static_assert
(
sizeof
(
uint16_t
)
==
2
,
"uint16_t must be 2 bytes"
);
static_assert
(
sizeof
(
int32_t
)
==
4
,
"int32_t must be 4 bytes"
);
static_assert
(
sizeof
(
uint32_t
)
==
4
,
"uint32_t must be 4 bytes"
);
static_assert
(
sizeof
(
int64_t
)
==
8
,
"int64_t must be 8 bytes"
);
static_assert
(
sizeof
(
uint64_t
)
==
8
,
"uint64_t must be 8 bytes"
);
#define MIGRAPHX_DEVICE_CONSTEXPR constexpr __device__ __host__ // NOLINT
...
...
src/targets/gpu/lowering.cpp
View file @
30c49503
...
...
@@ -83,8 +83,7 @@ struct miopen_apply
auto
&
ctx
=
get_context
();
int8_x4_format
=
get_int8_x4_format
(
ctx
);
compute_fp32
=
get_compute_fp32_flag
();
offload_copy
=
(
mod
->
name
()
==
"main"
)
?
pass
->
offload_copy
:
false
;
offload_copy
=
(
mod
->
name
()
==
"main"
)
?
pass
->
offload_copy
:
false
;
add_generic_op
(
"contiguous"
);
...
...
@@ -112,6 +111,7 @@ struct miopen_apply
add_loop_op
();
add_neg_op
();
add_nms_op
();
add_select_module_op
();
}
void
copy_params
()
const
...
...
@@ -359,6 +359,20 @@ struct miopen_apply
return
mod
->
replace_instruction
(
ins
,
gpu_out
);
});
}
/**
* Adds dynamic allocation for submodule output parameter.
*/
void
add_select_module_op
()
{
apply_map
.
emplace
(
"select_module"
,
[
=
](
instruction_ref
ins
)
{
auto
s
=
ins
->
get_shape
();
auto
output
=
insert_allocation
(
ins
,
s
);
std
::
vector
<
instruction_ref
>
inputs
=
ins
->
inputs
();
inputs
.
push_back
(
output
);
return
mod
->
replace_instruction
(
ins
,
ins
->
get_operator
(),
inputs
,
ins
->
module_inputs
());
});
}
};
void
lowering
::
apply
(
module
&
m
)
const
{
miopen_apply
{
&
m
,
this
}.
apply
();
}
...
...
src/targets/gpu/prefuse_ops.cpp
View file @
30c49503
...
...
@@ -26,6 +26,8 @@
#include <migraphx/check_shapes.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/register_op.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/dead_code_elimination.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
...
@@ -90,7 +92,9 @@ struct find_layernorm
{
auto
ins
=
r
.
result
;
auto
x_ins
=
r
.
instructions
[
"x"
];
auto
eps
=
r
.
instructions
[
"eps"
]
->
eval
().
at
<
float
>
();
float
eps
=
0
;
if
(
contains
(
r
.
instructions
,
"eps"
))
eps
=
r
.
instructions
[
"eps"
]
->
eval
().
at
<
float
>
();
m
.
replace_instruction
(
ins
,
layernorm
{
eps
},
x_ins
);
}
...
...
@@ -100,24 +104,26 @@ struct find_add_layernorm
{
auto
matcher
()
const
{
return
match
::
layernorm
(
)(
match
::
v
ar
(
"x"
)
(
match
::
name
(
"add"
)(
match
::
used_once
()).
bind
(
"add"
)));
return
match
::
name
(
"gpu::pre
layernorm
"
)(
match
::
ar
gs
(
match
::
name
(
"add"
)(
match
::
used_once
()).
bind
(
"add"
)));
}
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
auto
ins
=
r
.
result
;
auto
add_ins
=
r
.
instructions
[
"add"
];
auto
eps
=
r
.
instructions
[
"eps"
]
->
eval
().
at
<
float
>
(
);
auto
op
=
any_cast
<
layernorm
>
(
ins
->
get_operator
()
);
m
.
replace_instruction
(
ins
,
add_layernorm
{
eps
},
add_ins
->
inputs
());
m
.
replace_instruction
(
ins
,
add_layernorm
{
op
.
epsilon
},
add_ins
->
inputs
());
}
};
}
// namespace
void
prefuse_ops
::
apply
(
module
&
m
)
const
void
prefuse_ops
::
apply
(
module
_pass_manager
&
mp
m
)
const
{
match
::
find_matches
(
m
,
find_add_layernorm
{},
find_layernorm
{});
match
::
find_matches
(
mpm
.
get_module
(),
find_layernorm
{});
mpm
.
run_pass
(
dead_code_elimination
{});
match
::
find_matches
(
mpm
.
get_module
(),
find_add_layernorm
{});
}
}
// namespace gpu
...
...
Prev
1
2
3
4
5
6
7
8
9
10
11
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment