MIGraphX · Commit 8e68f6ca
Authored Jul 24, 2019 by Paul

Add fused mul add

Parent: 16864eef
Showing 4 changed files with 154 additions and 15 deletions (+154, -15)
src/targets/gpu/CMakeLists.txt  (+1, -0)
src/targets/gpu/device/include/migraphx/gpu/device/nary.hpp  (+80, -9)
src/targets/gpu/fuse_ops.cpp  (+55, -6)
test/gpu/miopen.cpp  (+18, -0)
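Note: the kernel body for the new device op (device/mul_add.cpp) is not among the four files shown in this diff, so its exact source is not visible here. Based on the op name, the argument order in hip_mul_add::compute below, and the test_mul_add program added in test/gpu/miopen.cpp, the fused operation is assumed to compute result = x * a + b with a and b broadcast across the rows of x. A minimal host-side reference under that assumption (the check_mul_add helper and the sample values are illustrative, not part of the commit):

#include <cstddef>
#include <iostream>
#include <vector>

// Hypothetical host-side reference for the fused op: result = x * a + b,
// where a and b have shape {3} and are broadcast across the rows of x {2, 3}.
std::vector<float> check_mul_add(const std::vector<float>& x,
                                 const std::vector<float>& a,
                                 const std::vector<float>& b,
                                 std::size_t rows,
                                 std::size_t cols)
{
    std::vector<float> result(rows * cols);
    for(std::size_t i = 0; i < rows; i++)
        for(std::size_t j = 0; j < cols; j++)
            result[i * cols + j] = x[i * cols + j] * a[j] + b[j];
    return result;
}

int main()
{
    auto r = check_mul_add({1, 2, 3, 4, 5, 6}, {10, 20, 30}, {1, 1, 1}, 2, 3);
    for(auto v : r)
        std::cout << v << " "; // 11 41 91 41 101 181
    std::cout << "\n";
}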
src/targets/gpu/CMakeLists.txt
@@ -16,6 +16,7 @@ add_library(migraphx_device
     device/argmin.cpp
     device/max.cpp
     device/min.cpp
+    device/mul_add.cpp
     device/exp.cpp
     device/erf.cpp
     device/log.cpp
src/targets/gpu/device/include/migraphx/gpu/device/nary.hpp
@@ -118,10 +118,66 @@ void nary_broadcast_impl(hipStream_t stream, F f, argument result, argument barg
         });
 }
 
+template <class F, class... Arguments>
+void nary_double_broadcast_vec_impl(
+    hipStream_t stream, F f, argument result, argument barg1, argument barg2, Arguments... args)
+{
+    assert(barg1.get_shape() == barg2.get_shape());
+    const auto& output_shape = result.get_shape();
+    const auto& b_shape      = barg1.get_shape();
+    auto bdim =
+        std::distance(b_shape.strides().begin(),
+                      std::find_if(b_shape.strides().begin(), b_shape.strides().end(), [](auto x) {
+                          return x != 0;
+                      }));
+    auto bdim_len         = output_shape.lens()[bdim];
+    auto bdim_stride      = output_shape.strides()[bdim];
+    auto bdim_next_stride = bdim_stride * bdim_len;
+
+    const std::size_t vec_size     = 4;
+    const std::size_t nlocal       = 1024;
+    const std::size_t nglobal      = 256 * nlocal;
+    const std::size_t bdim_vec_len = bdim_len / vec_size;
+    hip_vec_visit_all<vec_size>(result, barg1, barg2, args...)(
+        [&](auto output, auto binput1, auto binput2, auto... inputs) {
+            using type                  = typename decltype(output)::value_type;
+            const std::size_t nelements = output.size() / vec_size;
+            launch(stream, nglobal, nlocal)([=](auto idx) __device__ {
+                MIGRAPHX_DEVICE_SHARED type buffer[2048 / vec_size];
+                // Load bias into LDS
+                for(size_t i = idx.local; i < bdim_vec_len; i += nlocal)
+                {
+                    buffer[i] = binput1.data()[i];
+                }
+                for(size_t i = idx.local; i < bdim_vec_len; i += nlocal)
+                {
+                    buffer[i + bdim_vec_len] = binput2.data()[i];
+                }
+                __syncthreads();
+                auto* bp = as_pointer(buffer);
+                // Process the data
+                for(size_t i = idx.global; i < nelements; i += nglobal)
+                {
+                    auto bidx = ((i * vec_size) % bdim_next_stride) / bdim_stride;
+                    auto b1   = bp[bidx];
+                    auto b2   = bp[bidx + bdim_vec_len];
+                    auto out  = output.data()[i];
+                    for(std::size_t j = 0; j < vec_size; j++)
+                    {
+                        out[j] = f(inputs.data()[i][j]..., b2, b1);
+                    }
+                    output.data()[i] = out;
+                }
+            });
+        });
+}
+
 template <class F, class... Arguments>
 void nary_double_broadcast_impl(
     hipStream_t stream, F f, argument result, argument barg1, argument barg2, Arguments... args)
 {
+    assert(barg1.get_shape() == barg2.get_shape());
     const auto& output_shape = result.get_shape();
     const auto& b_shape      = barg1.get_shape();
     auto bdim =
@@ -148,7 +204,7 @@ void nary_double_broadcast_impl(
                 }
                 for(size_t i = idx.local; i < bdim_len; i += nlocal)
                 {
-                    buffer[i + bdim_len] = binput2.data()[i + bdim_len];
+                    buffer[i + bdim_len] = binput2.data()[i];
                 }
                 __syncthreads();
                 // Process the data
@@ -157,7 +213,7 @@ void nary_double_broadcast_impl(
                     auto bidx = (i % bdim_next_stride) / bdim_stride;
                     auto b1   = buffer[bidx];
                     auto b2   = buffer[bidx + bdim_len];
-                    output.data()[i] = f(inputs.data()[i]..., b1, b2);
+                    output.data()[i] = f(inputs.data()[i]..., b2, b1);
                 }
             });
         });
@@ -222,7 +278,7 @@ auto nary_standard(hipStream_t stream, argument result, Arguments... args)
 }
 
 template <class... Arguments>
-bool broadcastable(bool& divisible_by_4, argument result, argument barg, Arguments... args)
+bool broadcastable(bool& divisible_by_4, std::size_t max_size, argument result, argument barg, Arguments... args)
 {
     divisible_by_4 = false;
     auto bshape    = barg.get_shape();
@@ -240,7 +296,7 @@ bool broadcastable(bool& divisible_by_4, argument result, argument barg, Argumen
         auto b_len    = result.get_shape().lens()[b_idx];
         auto b_stride = result.get_shape().strides()[b_idx];
         assert(bshape.lens()[b_idx] == b_len);
-        if(b_len <= 2048 and std::none_of(std::next(b_it), strides.end(), not_zero))
+        if(b_len <= max_size and std::none_of(std::next(b_it), strides.end(), not_zero))
         {
             divisible_by_4 = (b_len % 4 == 0) and (b_stride % 4 == 0) and
@@ -251,7 +307,7 @@ bool broadcastable(bool& divisible_by_4, argument result, argument barg, Argumen
     return false;
 }
 
-inline bool broadcastable(bool& divisible_by_4, argument, argument)
+inline bool broadcastable(bool& divisible_by_4, std::size_t, argument, argument)
 {
     divisible_by_4 = false;
     return false;
@@ -274,7 +330,7 @@ inline auto nary(hipStream_t stream, argument result, argument arg, argument bar
 {
     return [=](auto f) {
         bool divisible_by_4 = false;
-        if(broadcastable(divisible_by_4, result, barg, arg))
+        if(broadcastable(divisible_by_4, 2048, result, barg, arg))
         {
             if(divisible_by_4)
                 nary_broadcast_vec_impl(stream, f, result, barg, arg);
@@ -293,9 +349,24 @@ auto nary(hipStream_t stream, argument result, Arguments... args)
 {
     return [=](auto f) {
         auto barg1 = back_args(args...);
-        bool fallback = pop_back_args(args...)([&](auto&&... args2) {
+        bool fallback1 = pop_back_args(args...)([&](auto&&... args2) {
+            auto barg2     = back_args(args2...);
+            bool fallback2 = barg2.get_shape() == barg1.get_shape() and
+                             barg2.get_shape().broadcasted() and
+                             pop_back_args(args2...)([&](auto&&... args3) {
+                                 bool divisible_by_4 = false;
+                                 if(broadcastable(divisible_by_4, 1024, result, barg2, args3...))
+                                 {
+                                     if(divisible_by_4)
+                                         nary_double_broadcast_vec_impl(
+                                             stream, f, result, barg1, barg2, args3...);
+                                     else
+                                         nary_double_broadcast_impl(
+                                             stream, f, result, barg1, barg2, args3...);
+                                     return false;
+                                 }
+                                 return true;
+                             });
+            if(not fallback2)
+                return false;
             bool divisible_by_4 = false;
-            if(broadcastable(divisible_by_4, result, barg1, args2...))
+            if(broadcastable(divisible_by_4, 2048, result, barg1, args2...))
             {
                 if(divisible_by_4)
                     nary_broadcast_vec_impl(stream, f, result, barg1, args2...);
@@ -305,7 +376,7 @@ auto nary(hipStream_t stream, argument result, Arguments... args)
             }
             return true;
         });
-        if(fallback)
+        if(fallback1)
             nary_impl(stream, f, result, args...);
     };
 }
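Note: the double-broadcast kernels added above stage both broadcast operands in LDS and recover the per-element broadcast index as bidx = (i % bdim_next_stride) / bdim_stride, where bdim is the one axis of the broadcast shape with a non-zero stride. A minimal host-side sketch of that index math, assuming an output of shape {2, 3} and an operand broadcast along axis 1 (the shapes and strides below are illustrative, not taken from the diff):

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <iterator>
#include <vector>

int main()
{
    // Output shape {2, 3} with strides {3, 1}; the broadcast input has
    // shape {2, 3} with strides {0, 1}, i.e. it varies only along axis 1.
    std::vector<std::size_t> out_lens{2, 3};
    std::vector<std::size_t> out_strides{3, 1};
    std::vector<std::size_t> b_strides{0, 1};

    // bdim is the first axis of the broadcast shape with a non-zero stride.
    auto bdim = std::distance(
        b_strides.begin(),
        std::find_if(b_strides.begin(), b_strides.end(), [](auto x) { return x != 0; }));

    auto bdim_len         = out_lens[bdim];
    auto bdim_stride      = out_strides[bdim];
    auto bdim_next_stride = bdim_stride * bdim_len;

    // Same index recovery as nary_double_broadcast_impl: each flat output
    // index i maps to an index into the cached broadcast buffer.
    for(std::size_t i = 0; i < 6; i++)
    {
        auto bidx = (i % bdim_next_stride) / bdim_stride;
        std::cout << "i=" << i << " -> bidx=" << bidx << "\n"; // 0 1 2 0 1 2
    }
}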
src/targets/gpu/fuse_ops.cpp
@@ -2,6 +2,7 @@
 #include <migraphx/matcher.hpp>
 #include <migraphx/gpu/miopen.hpp>
 #include <migraphx/gpu/convolution.hpp>
+#include <migraphx/gpu/device/mul_add.hpp>
 #include <migraphx/gpu/device/add_relu.hpp>
 #include <migraphx/gpu/device/add.hpp>
 #include <migraphx/instruction.hpp>
@@ -200,21 +201,42 @@ struct hip_add_relu
     }
 };
 
+struct hip_mul_add
+{
+    std::string name() const { return "hip::mul_add"; }
+    shape compute_shape(const std::vector<shape>& inputs) const
+    {
+        check_shapes{inputs, *this}.has(4);
+        return inputs.front();
+    }
+
+    argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
+    {
+        device::mul_add(ctx.get_stream().get(), args.at(3), args.at(0), args.at(1), args.at(2));
+        return args.at(3);
+    }
+
+    std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
+    {
+        return shapes.size() - 1;
+    }
+};
+
 void move_broadcasted_back(std::vector<instruction_ref>& args)
 {
     // Ensure the last arguments is the broadcasted one
-    auto it = std::find_if(
-        args.begin(), args.end(), [](auto arg) { return arg->get_shape().broadcasted(); });
-    if(it != args.end())
-        std::swap(*it, *std::prev(args.end(), 2));
+    auto last = std::prev(args.end());
+    auto it   = std::find_if(
+        args.begin(), last, [](auto arg) { return arg->get_shape().broadcasted(); });
+    if(it != last)
+        std::swap(*it, *std::prev(last));
 }
 
 void move_standard_front(std::vector<instruction_ref>& args)
 {
     // Ensure the first arguments is the standard one
-    auto it = std::find_if(
-        args.begin(), args.end(), [](auto arg) { return arg->get_shape().standard(); });
-    if(it != args.end())
-        std::swap(*it, args.front());
+    auto last = std::prev(args.end());
+    auto it   = std::find_if(
+        args.begin(), last, [](auto arg) { return arg->get_shape().standard(); });
+    if(it != last)
+        std::swap(*it, args.front());
 }
@@ -278,6 +300,32 @@ struct find_triadd
     }
 };
 
+struct find_mul_add
+{
+    auto matcher() const
+    {
+        return match::name("gpu::add")(match::either_arg(0, 1)(
+            match::name("gpu::mul").bind("mul"), match::any().bind("b")));
+    }
+
+    void apply(program& p, match::matcher_result r) const
+    {
+        auto mul_ins = r.instructions["mul"];
+        auto b_ins   = r.instructions["b"];
+        auto ins     = r.result;
+        auto args    = mul_ins->inputs();
+        assert(mul_ins != b_ins);
+
+        move_standard_front(args);
+        move_broadcasted_back(args);
+
+        args.insert(std::prev(args.end()), b_ins);
+        args.back() = ins->inputs().back();
+        p.replace_instruction(ins, hip_mul_add{}, args);
+    }
+};
+
 struct miopen_conv_bias
 {
     op::convolution op;
@@ -433,7 +481,8 @@ void fuse_ops::apply(program& p) const
     match::find_matches(p,
                         find_conv_bias_relu{ctx},
                         find_conv_bias{ctx},
-                        find_add_relu{}
+                        find_add_relu{},
+                        find_mul_add{}
     );
     // clang-format on
 }
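Note: find_mul_add builds a four-argument list for hip::mul_add whose last entry is the output allocation, which is why output_alias returns shapes.size() - 1 and compute writes into args.at(3). A simplified stand-in for that argument shuffle, assuming gpu::mul carries inputs (x, a_broadcast, mul_alloc); plain strings replace the instruction_ref plumbing for illustration:

#include <iostream>
#include <iterator>
#include <string>
#include <vector>

int main()
{
    // Stand-ins for mul_ins->inputs(): x (standard), a (broadcast), and the
    // allocation gpu::mul would have written into.
    std::vector<std::string> args{"x", "a_broadcast", "mul_alloc"};

    // move_standard_front / move_broadcasted_back leave the allocation last
    // and the broadcast operand just before it; here the order already fits.

    // Insert the add's other operand before the allocation, then replace the
    // allocation with the one owned by the gpu::add being fused away.
    args.insert(std::prev(args.end()), "b");
    args.back() = "add_alloc";

    for(const auto& a : args)
        std::cout << a << " "; // x a_broadcast b add_alloc
    std::cout << "\n";
    // hip_mul_add::compute then runs device::mul_add(stream, args.at(3),
    // args.at(0), args.at(1), args.at(2)) and returns args.at(3).
}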
test/gpu/miopen.cpp
@@ -490,6 +490,24 @@ struct test_triadd2 : verify_program<test_triadd2>
     }
 };
 
+struct test_mul_add : verify_program<test_mul_add>
+{
+    migraphx::program create_program() const
+    {
+        migraphx::program p;
+        migraphx::shape s{migraphx::shape::float_type, {2, 3}};
+        migraphx::shape bs{migraphx::shape::float_type, {3}};
+        auto x   = p.add_parameter("x", s);
+        auto a   = p.add_parameter("a", bs);
+        auto b   = p.add_parameter("b", bs);
+        auto ab  = p.add_instruction(migraphx::op::broadcast{1, s.lens()}, a);
+        auto bb  = p.add_instruction(migraphx::op::broadcast{1, s.lens()}, b);
+        auto mul = p.add_instruction(migraphx::op::mul{}, x, ab);
+        p.add_instruction(migraphx::op::add{}, mul, bb);
+        return p;
+    }
+};
+
 struct test_add_broadcast : verify_program<test_add_broadcast>
 {
     migraphx::program create_program() const