gaoqiong / MIGraphX — Commit 11e155c2

Merge commit, authored Jun 13, 2022 by Paul
Parents: 8a9c5bce, aa7ff911
Changes: 397

Showing 20 changed files with 549 additions and 187 deletions (+549 −187):
- src/targets/gpu/compiler.cpp (+39 −0)
- src/targets/gpu/device/include/migraphx/gpu/device/multi_index.hpp (+4 −3)
- src/targets/gpu/device/include/migraphx/gpu/device/scan.hpp (+13 −6)
- src/targets/gpu/device/include/migraphx/gpu/device/visit.hpp (+12 −11)
- src/targets/gpu/device/nonzero.cpp (+16 −15)
- src/targets/gpu/device/prefix_scan_sum.cpp (+100 −20)
- src/targets/gpu/driver/CMakeLists.txt (+2 −6)
- src/targets/gpu/driver/compile_op.cpp (+3 −3)
- src/targets/gpu/driver/run_op.cpp (+4 −2)
- src/targets/gpu/eliminate_workspace.cpp (+5 −5)
- src/targets/gpu/fuse_ops.cpp (+116 −66)
- src/targets/gpu/gemm_impl.cpp (+68 −17)
- src/targets/gpu/hip.cpp (+19 −6)
- src/targets/gpu/include/migraphx/gpu/analyze_streams.hpp (+1 −1)
- src/targets/gpu/include/migraphx/gpu/code_object_op.hpp (+4 −0)
- src/targets/gpu/include/migraphx/gpu/compile_gen.hpp (+46 −0)
- src/targets/gpu/include/migraphx/gpu/compile_hip.hpp (+0 −2)
- src/targets/gpu/include/migraphx/gpu/compile_hip_code_object.hpp (+27 −0)
- src/targets/gpu/include/migraphx/gpu/compile_pointwise.hpp (+0 −24)
- src/targets/gpu/include/migraphx/gpu/compiler.hpp (+70 −0)
src/targets/gpu/compiler.cpp  (new file, 0 → 100644, +39 −0)

#include <migraphx/gpu/compiler.hpp>
#include <utility>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {

auto& compiler_map()
{
    static std::unordered_map<std::string, compiler_compile> m; // NOLINT
    return m;
}

auto& compiler_op_map()
{
    static std::unordered_map<std::string, compiler_compile_op> m; // NOLINT
    return m;
}

void register_compiler(const std::string& name, compiler_compile c, compiler_compile_op cop)
{
    compiler_map()[name]    = std::move(c);
    compiler_op_map()[name] = std::move(cop);
}

bool has_compiler_for(const std::string& name) { return compiler_map().count(name) > 0; }

compiler_replace compile(context& ctx, instruction_ref ins, const operation& op)
{
    return compiler_map().at(op.name())(ctx, ins, op);
}

operation compile_op(const std::string& name,
                     context& ctx,
                     const std::vector<shape>& inputs,
                     const value& v)
{
    return compiler_op_map().at(name)(ctx, inputs, v);
}

} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
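
The two function-local static maps form a simple name-keyed compiler registry. A minimal sketch of registering through this API directly (the "gpu::example" name and trivial callbacks are hypothetical; in-tree code normally registers via the compiler<Derived> helper in compiler.hpp, shown later in this diff):

    // Hedged sketch: direct registration against the new registry.
    register_compiler(
        "gpu::example",
        [](context& ctx, instruction_ref ins, operation op) -> compiler_replace {
            // Return a callback that rewrites the instruction once compiled
            return [=](module& m, instruction_ref i) { m.replace_instruction(i, op, i->inputs()); };
        },
        [](context&, const std::vector<shape>&, const value&) -> operation { return {}; });

    assert(has_compiler_for("gpu::example")); // lookup mirrors registration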
src/targets/gpu/device/include/migraphx/gpu/device/multi_index.hpp  (+4 −3)

@@ -57,9 +57,10 @@ inline auto mi_nglobal(const hip_shape<N>& s, index_int nlocal)
 {
     assert(s.standard);
     assert(s.elements() > 0);
     index_int n      = s.elements();
     index_int groups = (n + nlocal - 1) / nlocal;
-    index_int nglobal = std::min<index_int>(128, groups) * nlocal;
+    // max possible number of blocks is set to 1B (1,073,741,824)
+    index_int nglobal = std::min<index_int>(1073741824, groups) * nlocal;
     assert(groups > 0);
     assert(nglobal > 0);
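
For scale: with nlocal = 256 and n = 1,000,000 elements, groups = ceil(1,000,000 / 256) = 3907. The old cap gave nglobal = min(128, 3907) * 256 = 32,768 workitems, so each workitem strided over roughly 31 elements; with the new 1,073,741,824 cap, nglobal = 3907 * 256 = 1,000,192 and every element gets its own workitem.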
src/targets/gpu/device/include/migraphx/gpu/device/scan.hpp  (+13 −6)

@@ -44,12 +44,19 @@ __device__ void block_scan(index idx, Op op, T init, ForStride fs, Input input,
 template <index_int N, class Op, class T, class Input, class Output>
 __device__ void block_scan(index idx, Op op, T init, index_int n, Input input, Output output)
 {
     block_scan<N>(idx,
                   op,
                   init,
                   [&](auto f) -> decltype(f(index_int{})) { return idx.local_stride(n, f); },
                   input,
                   output);
 }
+
+template <class F>
+constexpr auto reverse_scan(index_int n, F f)
+{
+    return [=](auto i, auto&&... xs) { return f(n - i - 1, xs...); };
+}
 } // namespace device
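
reverse_scan simply flips the index a callback sees, so a forward block scan runs over the sequence back to front. A host-side sketch of that index flip (plain C++ here as an assumption; in-tree it wraps __device__ callbacks):

    #include <iostream>

    int main()
    {
        int n            = 4;
        auto reverse_idx = [=](auto f) { return [=](int i) { return f(n - i - 1); }; };
        auto g           = reverse_idx([](int j) { return j; });
        for(int i = 0; i < n; i++)
            std::cout << g(i) << " "; // prints: 3 2 1 0
    }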
src/targets/gpu/device/include/migraphx/gpu/device/visit.hpp  (+12 −11)

@@ -14,28 +14,23 @@ constexpr void visit_tensor_size(index_int n, F f)
 {
     switch(n)
     {
     case 1: {
         f(std::integral_constant<index_int, 1>{});
         break;
     }
     case 2: {
         f(std::integral_constant<index_int, 2>{});
         break;
     }
     case 3: {
         f(std::integral_constant<index_int, 3>{});
         break;
     }
     case 4: {
         f(std::integral_constant<index_int, 4>{});
         break;
     }
     case 5: {
         f(std::integral_constant<index_int, 5>{});
         break;
     }

@@ -181,7 +176,13 @@ template <index_int N, class T, class... Ts>
 auto hip_vec_visit_all(T&& x, Ts&&... xs)
 {
     return [&](auto f) {
-        hip_visit_all_impl(get_shape(x),
+        auto sx   = get_shape(x);
+        auto lens = sx.lens();
+        assert(lens.back() % N == 0);
+        assert(sx.strides().back() == 1);
+        lens.back() /= N;
+        shape vec_sx{sx.type(), lens};
+        hip_visit_all_impl(vec_sx,
                            make_hip_convert([](auto* p) { return as_vec<N>(device_cast(p)); }),
                            f,
                            x,
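
Concretely: with N = 4 and an input of shape {2, 8} (strides {8, 1}), the new code asserts that 8 % 4 == 0 and that the last axis is contiguous, then visits through a {2, 2} shape whose elements are the 4-wide vectors produced by as_vec<N>. Previously the unmodified shape was passed through, so the element count seen by the visitor did not reflect the vectorization.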
src/targets/gpu/device/nonzero.cpp  (+16 −15)

@@ -25,22 +25,23 @@ argument nonzero(hipStream_t stream, const argument& result, const argument& arg
     // fill all output to 0 first
     idx.local_stride(out_elem_num, [&](auto j) { ptr[j] = 0; });
     block_scan<block_size>(idx,
                            sum{},
                            0,
                            elem_num,
                            [&](auto j) { return (float_equal(in_ptr[j], 0)) ? 0 : 1; },
                            [&](auto j, auto x) {
                                auto out_loc = x - 1;
                                if(float_equal(in_ptr[j], 0))
                                    return;
                                auto index = si.multi(j);
                                for(size_t k = 0; k < index.size(); ++k)
                                {
                                    ptr[k * elem_num + out_loc] = index[k];
                                }
                            });
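
Worked example of this scan-based compaction: for input [0, 3, 0, 5], the input callback yields flags [0, 1, 0, 1] and the inclusive scan produces x = [0, 1, 1, 2]. Zero elements return early; for the nonzero elements at positions 1 and 3, out_loc = x - 1 gives output slots 0 and 1, so their multi-indices are written densely at the front of each output row.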
src/targets/gpu/device/prefix_scan_sum.cpp  (+100 −20)

 #include <migraphx/gpu/device/prefix_scan_sum.hpp>
 #include <migraphx/gpu/device/scan.hpp>
 #include <migraphx/gpu/device/reduce_ops.hpp>
+#include <migraphx/gpu/device/reduce.hpp>
 #include <migraphx/gpu/device/types.hpp>

 namespace migraphx {

@@ -8,29 +9,108 @@ inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {
 namespace device {

-void prefix_scan_sum(hipStream_t stream, const argument& result, const argument& arg, int32_t axis)
+void prefix_scan_sum(hipStream_t stream,
+                     const argument& result,
+                     const argument& arg,
+                     int32_t axis,
+                     bool exclusive,
+                     bool reverse)
 {
-    const index_int block_size = 256;
+    const index_int max_block_size = 256;
     const index_int n = arg.get_shape().lens()[axis];
     auto rlens = result.get_shape().lens();
     rlens[axis] = 1;
     hip_visit_all(result, arg, result.get_shape().with_lens(rlens))(
         [=](auto output, auto input, auto rshape) {
-            gs_launch(stream, rshape.elements() * block_size, block_size)(
-                [=](auto i, auto idx) __device__ {
-                    const auto ridx  = rshape.multi(i / block_size);
-                    auto compute_idx = [&](auto j) {
-                        auto k  = ridx;
-                        k[axis] = j;
-                        return k;
-                    };
-                    block_scan<block_size>(idx,
-                                           sum{},
-                                           0,
-                                           n,
-                                           [&](auto j) { return input[compute_idx(j)]; },
-                                           [&](auto j, auto x) { output[compute_idx(j)] = x; });
-                });
+            const index_int block_size = compute_block_size(rshape.elements(), max_block_size);
+            if(reverse and exclusive)
+            {
+                gs_launch(stream, rshape.elements() * block_size, block_size)(
+                    [=](auto i, auto idx) __device__ {
+                        const auto ridx  = rshape.multi(i / block_size);
+                        auto compute_idx = [&](auto j) {
+                            auto k  = ridx;
+                            k[axis] = j;
+                            return k;
+                        };
+                        block_scan<max_block_size>(
+                            idx,
+                            sum{},
+                            0,
+                            n,
+                            reverse_scan(n, [&](auto j) { return input[compute_idx(j)]; }),
+                            reverse_scan(n, [&](auto j, auto x) {
+                                if(j == n - 1)
+                                    output[compute_idx(j)] = 0;
+                                if(j > 0)
+                                    output[compute_idx(j - 1)] = x;
+                            }));
+                    });
+            }
+            else if(reverse)
+            {
+                gs_launch(stream, rshape.elements() * block_size, block_size)(
+                    [=](auto i, auto idx) __device__ {
+                        const auto ridx  = rshape.multi(i / block_size);
+                        auto compute_idx = [&](auto j) {
+                            auto k  = ridx;
+                            k[axis] = j;
+                            return k;
+                        };
+                        block_scan<max_block_size>(
+                            idx,
+                            sum{},
+                            0,
+                            n,
+                            reverse_scan(n, [&](auto j) { return input[compute_idx(j)]; }),
+                            reverse_scan(n, [&](auto j, auto x) { output[compute_idx(j)] = x; }));
+                    });
+            }
+            else if(exclusive)
+            {
+                gs_launch(stream, rshape.elements() * block_size, block_size)(
+                    [=](auto i, auto idx) __device__ {
+                        const auto ridx  = rshape.multi(i / block_size);
+                        auto compute_idx = [&](auto j) {
+                            auto k  = ridx;
+                            k[axis] = j;
+                            return k;
+                        };
+                        block_scan<max_block_size>(idx,
+                                                   sum{},
+                                                   0,
+                                                   n,
+                                                   [&](auto j) { return input[compute_idx(j)]; },
+                                                   [&](auto j, auto x) {
+                                                       auto k = j + 1;
+                                                       if(j == 0)
+                                                           output[compute_idx(0)] = 0;
+                                                       if(k < n)
+                                                           output[compute_idx(k)] = x;
+                                                   });
+                    });
+            }
+            else
+            {
+                gs_launch(stream, rshape.elements() * block_size, block_size)(
+                    [=](auto i, auto idx) __device__ {
+                        const auto ridx  = rshape.multi(i / block_size);
+                        auto compute_idx = [&](auto j) {
+                            auto k  = ridx;
+                            k[axis] = j;
+                            return k;
+                        };
+                        block_scan<max_block_size>(idx,
+                                                   sum{},
+                                                   0,
+                                                   n,
+                                                   [&](auto j) { return input[compute_idx(j)]; },
+                                                   [&](auto j, auto x) { output[compute_idx(j)] = x; });
+                    });
+            }
         });
 }
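
For intuition, the four branches compute, for input [1, 2, 3, 4] along the scan axis:

- default (inclusive): [1, 3, 6, 10]
- exclusive: [0, 1, 3, 6]
- reverse: [10, 9, 7, 4]
- reverse + exclusive: [9, 7, 4, 0]

The reverse variants reuse the forward block_scan by wrapping both the input and output callbacks in reverse_scan; the exclusive variants shift the inclusive result by one slot and seed the boundary element with 0.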
src/targets/gpu/driver/CMakeLists.txt  (+2 −6)

+file(GLOB GPU_DRIVER_SRCS ${CONFIGURE_DEPENDS} ${CMAKE_CURRENT_SOURCE_DIR}/*.cpp)
+
 add_executable(gpu-driver
-    action.cpp
-    compile_pointwise.cpp
-    main.cpp
-    parser.cpp
-    perf.cpp
-    run_op.cpp
+    ${GPU_DRIVER_SRCS}
 )
 target_include_directories(gpu-driver PRIVATE include)
 target_link_libraries(gpu-driver PRIVATE migraphx_gpu)
src/targets/gpu/driver/compile_pointwise.cpp → src/targets/gpu/driver/compile_op.cpp  (renamed, mode 100755 → 100644, +3 −3)

 #include <migraphx/gpu/driver/action.hpp>
 #include <migraphx/gpu/driver/perf.hpp>
-#include <migraphx/gpu/compile_pointwise.hpp>
+#include <migraphx/gpu/compiler.hpp>
 #include <migraphx/gpu/context.hpp>

@@ -8,13 +8,13 @@ inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {
 namespace driver {

-struct compile_pointwise : action<compile_pointwise>
+struct compile_op : action<compile_op>
 {
     static void apply(const parser& p, const value& v)
     {
         context ctx;
         auto inputs = p.parse_shapes(v.at("inputs"));
-        auto op = gpu::compile_pointwise(ctx, inputs, v.at("lambda").to<std::string>());
+        auto op = gpu::compile_op(v.at("name").to<std::string>(), ctx, inputs, v);
         double t = time_op(ctx, op, inputs, p.get(v, "iterations", 100));
         std::cout << op << ": " << t << "ms" << std::endl;
     }
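
After the rename, the action is driven entirely by the value v rather than a hard-coded lambda string: it reads v.at("name") to pick the registered compiler, forwards the whole v so compiler-specific fields pass through, and p.get(v, "iterations", 100) keeps the iteration count optional. A driver input therefore needs at least "name" and "inputs" entries under this action; the exact file syntax belongs to the driver's parser and is not shown in this diff.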
src/targets/gpu/driver/run_op.cpp  (+4 −2)

@@ -17,8 +17,10 @@ struct run_op : action<run_op>
         auto name = v.at("name").to<std::string>();
         if(not contains(name, "::"))
             name = "gpu::" + name;
         auto op = make_op(name);
-        double t = time_op(ctx, op, inputs);
+        if(v.contains("fields"))
+            op.from_value(v.at("fields"));
+        double t = time_op(ctx, op, inputs, p.get(v, "iterations", 100));
         std::cout << op << ": " << t << "ms" << std::endl;
     }
 };
src/targets/gpu/eliminate_workspace.cpp  (+5 −5)

@@ -11,11 +11,11 @@ namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {

-void eliminate_workspace::apply(module& p) const
+void eliminate_workspace::apply(module& m) const
 {
     std::size_t n = 0;
     std::vector<instruction_ref> allocs;
-    for(auto ins : iterator_for(p))
+    for(auto ins : iterator_for(m))
     {
         if(ins->outputs().size() != 1)
             continue;

@@ -30,11 +30,11 @@ void eliminate_workspace::apply(module& p) const
     }
     if(n > 0)
     {
-        auto ws = p.add_parameter("workspace", shape{shape::int8_type, {n}});
+        auto ws = m.add_parameter("workspace", shape{shape::int8_type, {n}});
         for(auto&& a : allocs)
         {
-            p.replace_instruction(a, ws);
-            p.remove_instruction(a);
+            m.replace_instruction(a, ws);
+            m.remove_instruction(a);
         }
     }
 }
src/targets/gpu/fuse_ops.cpp  (+116 −66)

@@ -316,7 +316,7 @@ struct find_layernorm
     auto matcher() const { return match::layernorm(&gpu_name); }

-    void apply(module& p, match::matcher_result r) const
+    void apply(module& m, const match::matcher_result& r) const
     {
         auto ins   = r.result;
         auto x_ins = r.instructions["x"];

@@ -331,7 +331,7 @@ struct find_layernorm
         if(relements > 1024 or (relements % 4 != 0 and relements > 256))
             return;
-        p.replace_instruction(ins, hip_layernorm{}, x_ins, args.back());
+        m.replace_instruction(ins, hip_layernorm{}, x_ins, args.back());
     }
 };

@@ -343,11 +343,11 @@ struct find_triadd_layernorm
             match::used_once(), match::all_of[match::inputs()](match::standard_shape()))));
     }
-    void apply(module& p, const match::matcher_result& r) const
+    void apply(module& m, const match::matcher_result& r) const
     {
         auto ins    = r.result;
         auto triadd = ins->inputs().front();
-        p.replace_instruction(ins, hip_triadd_layernorm{}, triadd->inputs());
+        m.replace_instruction(ins, hip_triadd_layernorm{}, triadd->inputs());
     }
 };

@@ -355,13 +355,13 @@ struct find_gelu
     auto matcher() const { return match::gelu_erf(&gpu_name); }

-    void apply(module& p, match::matcher_result r) const
+    void apply(module& m, const match::matcher_result& r) const
     {
         auto ins   = r.result;
         auto x_ins = r.instructions["x"];
         auto args  = ins->inputs();
-        p.replace_instruction(ins, hip_gelu{}, x_ins, args.back());
+        m.replace_instruction(ins, hip_gelu{}, x_ins, args.back());
     }
 };

@@ -372,7 +372,7 @@ struct find_add_gelu
         return match::name("gpu::gelu")(match::arg(0)(match::name("gpu::add").bind("add")));
     }
-    void apply(module& p, match::matcher_result r) const
+    void apply(module& m, const match::matcher_result& r) const
     {
         auto add_ins = r.instructions["add"];
         auto ins     = r.result;

@@ -381,7 +381,7 @@ struct find_add_gelu
         move_broadcasted_back(args);
         args.back() = ins->inputs().back();
-        p.replace_instruction(ins, hip_add_gelu{}, args);
+        m.replace_instruction(ins, hip_add_gelu{}, args);
     }
 };

@@ -391,16 +391,16 @@ struct find_gelu_new
     auto matcher() const { return match::gelu_tanh(&gpu_name); }

-    void apply(module& p, match::matcher_result r) const
+    void apply(module& m, const match::matcher_result& r) const
     {
         auto ins   = r.result;
         auto x_ins = r.instructions["x"];
         auto args  = ins->inputs();
         if(fast_math)
-            p.replace_instruction(ins, hip_gelu{}, x_ins, args.back());
+            m.replace_instruction(ins, hip_gelu{}, x_ins, args.back());
         else
-            p.replace_instruction(ins, hip_gelu_new{}, x_ins, args.back());
+            m.replace_instruction(ins, hip_gelu_new{}, x_ins, args.back());
     }
 };

@@ -411,7 +411,7 @@ struct find_add_gelu_new
         return match::name("gpu::gelu_new")(match::arg(0)(match::name("gpu::add").bind("add")));
     }
-    void apply(module& p, match::matcher_result r) const
+    void apply(module& m, const match::matcher_result& r) const
     {
         auto add_ins = r.instructions["add"];
         auto ins     = r.result;

@@ -420,7 +420,7 @@ struct find_add_gelu_new
         move_broadcasted_back(args);
         args.back() = ins->inputs().back();
-        p.replace_instruction(ins, hip_add_gelu_new{}, args);
+        m.replace_instruction(ins, hip_add_gelu_new{}, args);
     }
 };

@@ -435,7 +435,7 @@ struct find_add_clip
             .bind("add")));
     }
-    void apply(module& p, match::matcher_result r) const
+    void apply(module& m, const match::matcher_result& r) const
     {
         auto add_ins = r.instructions["add"];
         auto ins     = r.result;

@@ -448,9 +448,9 @@ struct find_add_clip
         add_args.pop_back();
         add_args.insert(add_args.end(), std::next(ins_args.begin()), ins_args.end());
         if(add_ins->name() == "gpu::add")
-            p.replace_instruction(ins, hip_add_clip{}, add_args);
+            m.replace_instruction(ins, hip_add_clip{}, add_args);
         else if(add_ins->name() == "gpu::triadd")
-            p.replace_instruction(ins, hip_triadd_clip{}, add_args);
+            m.replace_instruction(ins, hip_triadd_clip{}, add_args);
     }
 };

@@ -470,7 +470,7 @@ struct find_add_unary
             .bind("add")));
     }
-    void apply(module& p, match::matcher_result r) const
+    void apply(module& m, const match::matcher_result& r) const
     {
         auto add_ins = r.instructions["add"];
         auto ins     = r.result;

@@ -481,9 +481,9 @@ struct find_add_unary
         // Use the allocation from the relu operator
         args.back() = ins->inputs().back();
         if(add_ins->name() == "gpu::add")
-            p.replace_instruction(ins, binary_add_op, args);
+            m.replace_instruction(ins, binary_add_op, args);
         else if(add_ins->name() == "gpu::triadd")
-            p.replace_instruction(ins, ternary_add_op, args);
+            m.replace_instruction(ins, ternary_add_op, args);
     }
 };

@@ -498,7 +498,7 @@ struct find_triadd
             .bind("input")));
     }
-    void apply(module& p, match::matcher_result r) const
+    void apply(module& m, const match::matcher_result& r) const
     {
         auto add_ins   = r.instructions["add"];
         auto input_ins = r.instructions["input"];

@@ -513,7 +513,7 @@ struct find_triadd
         move_broadcasted_back(args);
         args.back() = ins->inputs().back();
-        p.replace_instruction(ins, hip_triadd{}, args);
+        m.replace_instruction(ins, hip_triadd{}, args);
     }
 };

@@ -525,7 +525,7 @@ struct find_mul_add
             match::name("gpu::mul")(match::used_once()).bind("mul"), match::any().bind("b")));
     }
-    void apply(module& p, match::matcher_result r) const
+    void apply(module& m, const match::matcher_result& r) const
     {
         auto mul_ins = r.instructions["mul"];
         auto b_ins   = r.instructions["b"];

@@ -538,7 +538,7 @@ struct find_mul_add
         args.insert(std::prev(args.end()), b_ins);
         args.back() = ins->inputs().back();
-        p.replace_instruction(ins, hip_mul_add{}, args);
+        m.replace_instruction(ins, hip_mul_add{}, args);
     }
 };

@@ -550,7 +550,7 @@ struct find_mul_add_relu
             match::arg(0)(match::name("gpu::mul_add")(match::used_once()).bind("mul_add")));
     }
-    void apply(module& p, match::matcher_result r) const
+    void apply(module& m, const match::matcher_result& r) const
     {
         auto mul_add_ins = r.instructions["mul_add"];
         auto ins         = r.result;

@@ -558,7 +558,7 @@ struct find_mul_add_relu
         // Use the allocation from the relu operator
         args.back() = ins->inputs().back();
-        p.replace_instruction(ins, hip_mul_add_relu{}, args);
+        m.replace_instruction(ins, hip_mul_add_relu{}, args);
     }
 };

@@ -587,6 +587,11 @@ struct miopen_fusion
         return pack(f(self.ops, "ops"));
     }
+    std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
+    {
+        return shapes.size() - 1;
+    }
     value compile(context& ctx, const shape&, std::vector<shape> inputs)
     {
         // Compensate for allocation

@@ -676,7 +681,7 @@ struct miopen_fusion
 struct miopen_conv_bias
 {
     op::convolution op;
-    fusion f          = {};
+    fusion fp         = {};
     fusion::op_t conv = {};
     fusion::op_t bias = {};

@@ -700,19 +705,19 @@ struct miopen_conv_bias
         float beta  = 0;
         miopenSetOpArgsConvForward(fargs.get(), conv, &alpha, &beta, args[1].implicit());
         miopenSetOpArgsBiasForward(fargs.get(), bias, &alpha, &beta, args[3].implicit());
-        return f.execute(ctx, fargs, args[0], args[4]);
+        return fp.execute(ctx, fargs, args[0], args[4]);
     }
     void finalize(context& ctx, const shape&, const std::vector<shape>& inputs)
     {
-        f    = fusion(inputs[0]);
-        conv = f.create_conv(op, inputs[1]);
-        bias = f.create_bias(inputs[3]);
-        if(not f.compile(ctx))
+        fp   = fusion(inputs[0]);
+        conv = fp.create_conv(op, inputs[1]);
+        bias = fp.create_bias(inputs[3]);
+        if(not fp.compile(ctx))
             MIGRAPHX_THROW("Failed to compile fusion plan");
     }
-    shape get_workspace(context& ctx) { return f.get_workspace(ctx); }
+    shape get_workspace(context& ctx) { return fp.get_workspace(ctx); }
     std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
     {
         return shapes.size() - 1;

@@ -723,7 +728,7 @@ MIGRAPHX_REGISTER_OP(miopen_conv_bias)
 struct miopen_conv_bias_relu
 {
     op::convolution op;
-    fusion f          = {};
+    fusion fp         = {};
     fusion::op_t conv = {};
     fusion::op_t bias = {};
     fusion::op_t relu = {};

@@ -749,18 +754,18 @@ struct miopen_conv_bias_relu
         miopenSetOpArgsConvForward(fargs.get(), conv, &alpha, &beta, args[1].implicit());
         miopenSetOpArgsBiasForward(fargs.get(), bias, &alpha, &beta, args[3].implicit());
         miopenSetOpArgsActivForward(fargs.get(), relu, &alpha, &beta, 0, 0, 0);
-        return f.execute(ctx, fargs, args[0], args[4]);
+        return fp.execute(ctx, fargs, args[0], args[4]);
     }
     void finalize(context& ctx, const shape&, const std::vector<shape>& inputs)
     {
-        f    = fusion(inputs[0]);
-        conv = f.create_conv(op, inputs[1]);
-        bias = f.create_bias(inputs[3]);
-        relu = f.create_relu();
-        f.compile(ctx);
+        fp   = fusion(inputs[0]);
+        conv = fp.create_conv(op, inputs[1]);
+        bias = fp.create_bias(inputs[3]);
+        relu = fp.create_relu();
+        fp.compile(ctx);
     }
-    shape get_workspace(context& ctx) { return f.get_workspace(ctx); }
+    shape get_workspace(context& ctx) { return fp.get_workspace(ctx); }
     std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
     {
         return shapes.size() - 1;

@@ -778,7 +783,7 @@ auto conv_bias(Ms... ms)
 }

 template <class Op>
-void apply_conv_bias(context& ctx, module& p, match::matcher_result r)
+void apply_conv_bias(context& ctx, module& m, const match::matcher_result& r)
 {
     auto conv_ins = r.instructions["conv"];
     auto bias_ins = r.instructions["bias"];

@@ -793,7 +798,7 @@ void apply_conv_bias(context& ctx, module& p, match::matcher_result r)
     // TODO: Insert ws allocation
     auto ws = cb.get_workspace(ctx);
     (void)ws;
-    p.replace_instruction(ins, cb, input_ins, weights_ins, old_ws_ins, bias_ins, alloc_ins);
+    m.replace_instruction(ins, cb, input_ins, weights_ins, old_ws_ins, bias_ins, alloc_ins);
 }

 inline auto precompile_name(std::string s) // NOLINT

@@ -824,9 +829,9 @@ struct find_conv_bias
             match::output(match::name(std::unordered_set<std::string>{"gpu::relu"}))));
     }
-    void apply(module& p, match::matcher_result r) const
+    void apply(module& m, const match::matcher_result& r) const
     {
-        apply_conv_bias<miopen_conv_bias>(*ctx, p, std::move(r));
+        apply_conv_bias<miopen_conv_bias>(*ctx, m, r);
     }
 };

@@ -835,9 +840,9 @@ struct find_conv_bias_relu
     context* ctx = nullptr;
     auto matcher() const { return match::name("gpu::relu")(match::arg(0)(conv_bias())); }
-    void apply(module& p, match::matcher_result r) const
+    void apply(module& m, const match::matcher_result& r) const
     {
-        apply_conv_bias<miopen_conv_bias_relu>(*ctx, p, std::move(r));
+        apply_conv_bias<miopen_conv_bias_relu>(*ctx, m, r);
     }
 };

@@ -852,7 +857,7 @@ struct find_conv_pointwise
             fusable_conv(match::used_once()).bind("conv")));
     }
-    void apply(module& m, match::matcher_result r) const
+    void apply(module& m, const match::matcher_result& r) const
     {
         auto conv_ins = r.instructions["conv"];
         auto bias_ins = r.instructions["bias"];

@@ -870,7 +875,6 @@ struct find_conv_pointwise
     {
         if(i.name()[0] == '@')
             continue;
-        auto inputs = to_shapes(i.inputs());
         op.ops.push_back({{i.get_operator()}});
     }
     std::vector<instruction_ref> inputs = {input_ins, weights_ins, bias_ins, alloc_ins};

@@ -891,7 +895,7 @@ struct find_gemm_add
             match::name("gpu::gemm")(match::nargs(3)).bind("gemm")));
     }
-    void apply(module& p, match::matcher_result r) const
+    void apply(module& m, const match::matcher_result& r) const
     {
         auto ins      = r.result;
         auto gemm_ins = r.instructions["gemm"];

@@ -903,26 +907,68 @@ struct find_gemm_add
         if(not float_equal(gemm.beta, 0))
             return;
+        if(std::any_of(ins->inputs().begin(), ins->inputs().end(), [](auto i) {
+               return not i->get_shape().standard();
+           }))
+            return;
         auto inputs = gemm_ins->inputs();
         inputs.pop_back();
         auto copy_ins = c_ins;
         // Insert copy
-        if(ins == p.end() or c_ins->outputs().size() > 1 or c_ins->inputs().empty())
+        if(ins == m.end() or c_ins->outputs().size() > 1 or c_ins->inputs().empty())
         {
-            copy_ins = p.insert_instruction(ins, hip_copy{}, c_ins, ins->inputs().back());
+            copy_ins = m.insert_instruction(ins, hip_copy{}, c_ins, ins->inputs().back());
         }
         inputs.push_back(copy_ins);
         inputs.push_back(copy_ins);
         gemm.beta = 1;
-        p.replace_instruction(ins, gemm, inputs);
+        m.replace_instruction(ins, gemm, inputs);
     }
 };
+
+auto pointwise_name(const std::string& s)
+{
+    return precompile_name("pointwise")(match::make_basic_pred_matcher([=](auto ins) {
+        module_ref pm = ins->module_inputs().front();
+        auto n = std::count_if(pm->begin(), pm->end(), [&](auto& i) { return i.name() == s; });
+        if(n != 1)
+            return false;
+        return std::all_of(pm->begin(), pm->end(), [&](auto& i) {
+            return starts_with(i.name(), "@") or i.name() == s;
+        });
+    }));
+}
+
+struct find_gemm_pointwise
+{
+    auto matcher() const
+    {
+        return pointwise_name("add")(
+            match::nargs(3),
+            match::all_of[match::inputs()](match::standard_shape()),
+            match::either_arg(0, 1)(match::used_once().bind("c"),
+                                    match::name("gpu::gemm")(match::nargs(3)).bind("gemm")));
+    }
+    void apply(module& m, const match::matcher_result& r) const
+    {
+        auto ins      = r.result;
+        auto gemm_ins = r.instructions["gemm"];
+        auto c_ins    = r.instructions["c"];
+        auto gemm     = any_cast<rocblas_gemm<op::dot>>(gemm_ins->get_operator());
+        // Already fused gemm
+        if(not float_equal(gemm.beta, 0))
+            return;
+        auto inputs = gemm_ins->inputs();
+        inputs.pop_back();
+        inputs.push_back(c_ins);
+        inputs.push_back(ins->inputs().back());
+        gemm.beta = 1;
+        m.replace_instruction(ins, gemm, inputs);
+    }
+};

@@ -933,22 +979,22 @@ struct find_commutative_broadcast
         return match::name("gpu::add", "gpu::mul")(match::arg(1)(match::broadcast_shape()));
     }
-    void apply(module& p, const match::matcher_result& r) const
+    void apply(module& m, const match::matcher_result& r) const
     {
         auto ins  = r.result;
         auto args = ins->inputs();
         move_broadcasted_back(args);
-        p.replace_instruction(ins, ins->get_operator(), args);
+        m.replace_instruction(ins, ins->get_operator(), args);
     }
 };

-void fuse_ops::apply(module& p) const
+void fuse_ops::apply(module& m) const
 {
-    match::find_matches(p, find_gelu{}, find_gelu_new{fast_math});
-    run_passes(p, {dead_code_elimination{}});
-    match::find_matches(p, find_triadd{});
-    match::find_matches(p,
+    match::find_matches(m, find_gelu{}, find_gelu_new{fast_math});
+    run_passes(m, {dead_code_elimination{}});
+    match::find_matches(m, find_triadd{});
+    match::find_matches(m,
                         find_layernorm{},
                         find_conv_pointwise{ctx},
                         find_conv_bias_relu{ctx},

@@ -961,8 +1007,12 @@ void fuse_ops::apply(module& p) const
                         find_add_unary{"gpu::sigmoid", hip_add_sigmoid{}, hip_triadd_sigmoid{}},
                         find_add_unary{"gpu::tanh", hip_add_tanh{}, hip_triadd_tanh{}},
                         find_add_clip{});
-    run_passes(p, {dead_code_elimination{}});
-    match::find_matches(p, find_triadd_layernorm{}, find_gemm_add{}, find_commutative_broadcast{});
+    run_passes(m, {dead_code_elimination{}});
+    match::find_matches(
+        m, find_triadd_layernorm{}, find_gemm_add{}, find_gemm_pointwise{}, find_commutative_broadcast{});
 }

 } // namespace gpu
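
The new pointwise_name("add") predicate accepts a precompiled pointwise instruction only if its sub-module contains exactly one "add" and otherwise nothing but builtin instructions (those whose names start with "@"), which is what makes it safe to fold into the GEMM's beta. A standalone host-side sketch of that predicate over plain instruction names (the name lists are hypothetical):

    #include <algorithm>
    #include <iostream>
    #include <string>
    #include <vector>

    static bool matches_single_op(const std::vector<std::string>& names, const std::string& s)
    {
        // Exactly one instruction named `s`...
        if(std::count(names.begin(), names.end(), s) != 1)
            return false;
        // ...and everything else is a builtin (@param, @literal, @return).
        return std::all_of(names.begin(), names.end(), [&](const std::string& i) {
            return i.rfind("@", 0) == 0 or i == s;
        });
    }

    int main()
    {
        std::cout << matches_single_op({"@param", "@param", "add", "@return"}, "add") << "\n"; // 1
        std::cout << matches_single_op({"@param", "add", "mul", "@return"}, "add") << "\n";    // 0
    }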
src/targets/gpu/gemm_impl.cpp  (mode 100755 → 100644, +68 −17)

 #include <rocblas.h>
 #include <migraphx/gpu/gemm_impl.hpp>
+#include <migraphx/reduce_dims.hpp>

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {

@@ -27,6 +28,22 @@ rocblas_datatype get_type(shape::type_t type)
     MIGRAPHX_THROW("ROCBLAS_GEMM: data type not supported!");
 }

+void blas_shape(const shape& s)
+{
+    if(s.lens().size() < 2)
+        return;
+    if(std::none_of(s.strides().end() - 2, s.strides().end(), [&](auto i) { return i == 1; }))
+        MIGRAPHX_THROW("GPU_GEMM: needs to have one matrix stride as 1");
+    if(s.lens().size() < 3)
+        return;
+    shape batch_shape{s.type(),
+                      {s.lens().begin(), s.lens().end() - 2},
+                      {s.strides().begin(), s.strides().end() - 2}};
+    auto batch_shapes = reduce_dims({batch_shape});
+    if(batch_shapes.front().lens().size() != 1)
+        MIGRAPHX_THROW("GPU_GEMM: Batch dimension is not collapsible");
+}
+
 template <class R, class... Ts, class... Us>
 R rocblas_invoke(R (*f)(Ts...), Us... xs)
 {

@@ -36,16 +53,29 @@ R rocblas_invoke(R (*f)(Ts...), Us... xs)
         return f(xs..., nullptr, nullptr);
 }

+static bool is_transposed(const shape& s)
+{
+    if(not s.transposed())
+        return false;
+    return s.strides().back() != 1;
+}
+
+static rocblas_int get_batch_stride(const argument& a)
+{
+    return a.get_shape().strides()[a.get_shape().strides().size() - 3];
+}
+
 template <class T>
 void gemm_impl(context& ctx,
                const shape& output_shape,
                const std::vector<argument>& args,
                T alpha,
                T beta,
-               bool int8_x4_format)
+               bool int8_x4_format,
+               bool compute_fp32)
 {
-    bool transa = args[0].get_shape().transposed();
-    bool transb = args[1].get_shape().transposed();
+    bool transa = is_transposed(args[0].get_shape());
+    bool transb = is_transposed(args[1].get_shape());
     auto n_dim  = output_shape.lens().size();
     auto dim_1  = n_dim - 1;
     auto dim_0  = n_dim - 2;

@@ -65,6 +95,11 @@ void gemm_impl(context& ctx,
         output_type = rocblas_datatype_i32_r;
     }
     auto compute_type = output_type;
+    if(compute_fp32)
+    {
+        if(arg_type == rocblas_datatype_f16_r)
+            compute_type = rocblas_datatype_f32_r;
+    }
 #if ROCBLAS_VERSION_MAJOR >= 2 && ROCBLAS_VERSION_MINOR >= 38
     rocblas_gemm_flags flag =

@@ -77,8 +112,19 @@ void gemm_impl(context& ctx,
     auto a_lens = args[0].get_shape().lens();
     auto b_lens = args[1].get_shape().lens();
     output_shape.visit_type([&](auto as) {
         auto alpha_r = as(alpha);
         auto beta_r  = as(beta);
+        // use void pointer to select different data type if using fp32 mode
+        void* alpha_v = &alpha_r;
+        void* beta_v  = &beta_r;
+        if(compute_fp32)
+        {
+            alpha_v = &alpha;
+            beta_v  = &beta;
+        }
         auto out_lens = output_shape.lens();
         rocblas_int m = out_lens[dim_0];
         rocblas_int n = out_lens[dim_1];

@@ -104,14 +150,14 @@ void gemm_impl(context& ctx,
                            n,
                            m,
                            k,
-                           &alpha_r,
+                           alpha_v,
                            to_pointer(args.at(1)),
                            arg_type,
                            ldb,
                            to_pointer(args.at(0)),
                            arg_type,
                            lda,
-                           &beta_r,
+                           beta_v,
                            to_pointer(args[2]),
                            output_type,
                            ldc,

@@ -125,6 +171,9 @@ void gemm_impl(context& ctx,
         }
         else
         {
+            auto a_stride = get_batch_stride(args[0]);
+            auto b_stride = get_batch_stride(args[1]);
+            auto c_stride = get_batch_stride(args[2]);
             rocblas_invoke(&rocblas_gemm_strided_batched_ex,
                            ctx.get_stream().get_rocblas(),
                            transb ? rocblas_operation_transpose : rocblas_operation_none,

@@ -132,24 +181,24 @@ void gemm_impl(context& ctx,
                            n,
                            m,
                            k,
-                           &alpha_r,
+                           alpha_v,
                            to_pointer(args.at(1)),
                            arg_type,
                            ldb,
-                           k * n,
+                           b_stride,
                            to_pointer(args.at(0)),
                            arg_type,
                            lda,
-                           m * k,
+                           a_stride,
-                           &beta_r,
+                           beta_v,
                            to_pointer(args[2]),
                            output_type,
                            ldc,
-                           m * n,
+                           c_stride,
                            is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]),
                            output_type,
                            ldc,
-                           m * n,
+                           c_stride,
                            num_matrices,
                            compute_type,
                            rocblas_gemm_algo_standard,

@@ -164,9 +213,10 @@ void gemm(context& ctx,
           const std::vector<argument>& args,
           float alpha,
           float beta,
-          bool int8_x4_format)
+          bool int8_x4_format,
+          bool compute_fp32)
 {
-    gemm_impl(ctx, output_shape, args, alpha, beta, int8_x4_format);
+    gemm_impl(ctx, output_shape, args, alpha, beta, int8_x4_format, compute_fp32);
 }

 void gemm(context& ctx,

@@ -174,9 +224,10 @@ void gemm(context& ctx,
           const std::vector<argument>& args,
           int32_t alpha,
           int32_t beta,
-          bool int8_x4_format)
+          bool int8_x4_format,
+          bool compute_fp32)
 {
-    gemm_impl(ctx, output_shape, args, alpha, beta, int8_x4_format);
+    gemm_impl(ctx, output_shape, args, alpha, beta, int8_x4_format, compute_fp32);
 }

 } // namespace gpu
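
The void-pointer selection exists because rocblas_gemm_ex and rocblas_gemm_strided_batched_ex dereference alpha and beta in the compute type, not the argument type. When compute_fp32 promotes an f16 GEMM to an f32 compute type, the scalars must be the original float values rather than the half-typed as(alpha) copies. A minimal restatement of that selection (helper name hypothetical):

    // Hypothetical helper restating the alpha_v/beta_v logic above: rocBLAS
    // reads the scalar as the *compute* type, so fp32-compute GEMMs must pass
    // a pointer to a float even when the matrices themselves are f16.
    template <class ArgT>
    static const void* gemm_scalar(const ArgT& as_arg, const float& as_f32, bool compute_fp32)
    {
        return compute_fp32 ? static_cast<const void*>(&as_f32)
                            : static_cast<const void*>(&as_arg);
    }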
src/targets/gpu/hip.cpp  (+19 −6)

@@ -27,6 +27,15 @@ using hip_host_ptr = MIGRAPHX_MANAGE_PTR(void, hipHostUnregister);
 std::string hip_error(int error) { return hipGetErrorString(static_cast<hipError_t>(error)); }

+bool is_device_ptr(const void* ptr)
+{
+    hipPointerAttribute_t attr;
+    auto status = hipPointerGetAttributes(&attr, ptr);
+    if(status != hipSuccess)
+        return false;
+    return attr.memoryType == hipMemoryTypeDevice;
+}
+
 std::size_t get_available_gpu_memory()
 {
     size_t free;

@@ -50,8 +59,8 @@ hip_ptr allocate_gpu(std::size_t sz, bool host = false)
 {
     if(sz > get_available_gpu_memory())
         MIGRAPHX_THROW("Memory not available to allocate buffer: " + std::to_string(sz));
-    void* result;
+    void* result = nullptr;
     auto status = host ? hipHostMalloc(&result, sz) : hipMalloc(&result, sz);
     if(status != hipSuccess)
     {
         if(host)

@@ -59,6 +68,7 @@ hip_ptr allocate_gpu(std::size_t sz, bool host = false)
         else
             return allocate_gpu(sz, true);
     }
+    assert(result != nullptr);
     return hip_ptr{result};
 }

@@ -75,6 +85,8 @@ std::vector<T> read_from_gpu(const void* x, std::size_t sz)
 {
     gpu_sync();
     std::vector<T> result(sz);
+    assert(not is_device_ptr(result.data()));
+    assert(is_device_ptr(x));
     auto status = hipMemcpy(result.data(), x, sz * sizeof(T), hipMemcpyDeviceToHost);
     if(status != hipSuccess)
         MIGRAPHX_THROW("Copy from gpu failed: " + hip_error(status)); // NOLINT

@@ -85,6 +97,8 @@ hip_ptr write_to_gpu(const void* x, std::size_t sz, bool host = false)
 {
     gpu_sync();
     auto result = allocate_gpu(sz, host);
+    assert(is_device_ptr(result.get()));
+    assert(not is_device_ptr(x));
     auto status = hipMemcpy(result.get(), x, sz, hipMemcpyHostToDevice);
     if(status != hipSuccess)
         MIGRAPHX_THROW("Copy to gpu failed: " + hip_error(status));

@@ -109,10 +123,9 @@ argument register_on_gpu(const argument& arg)
 {
     auto arg_shared = arg.share();
     auto p = share(register_on_gpu(arg_shared.data(), arg_shared.get_shape().bytes()));
-    return {arg_shared.get_shape(), [p, a = std::move(arg_shared)]() mutable {
-        return get_device_ptr(p.get());
-    }}; // namespace gpu
+    return {arg_shared.get_shape(), [p, a = std::move(arg_shared)]() mutable {
+                return get_device_ptr(p.get());
+            }};
 }

 argument to_gpu(const argument& arg, bool host)
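
The new asserts catch swapped hipMemcpy directions at the call site instead of producing silent corruption. A standalone sketch of using the same guard (assuming the is_device_ptr helper above is visible):

    #include <cassert>
    #include <vector>
    #include <hip/hip_runtime.h>

    int main()
    {
        std::size_t n = 16;
        float* dev    = nullptr;
        hipMalloc(&dev, n * sizeof(float));
        std::vector<float> host(n);
        assert(not is_device_ptr(host.data())); // destination is host memory
        assert(is_device_ptr(dev));             // source lives on the GPU
        hipMemcpy(host.data(), dev, n * sizeof(float), hipMemcpyDeviceToHost);
        hipFree(dev);
    }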
src/targets/gpu/include/migraphx/gpu/analyze_streams.hpp  (+1 −1)

@@ -11,7 +11,7 @@ struct module;
 namespace gpu {

-std::vector<stream_race> analyze_streams(const module& p);
+std::vector<stream_race> analyze_streams(const module& m);

 } // namespace gpu
 } // namespace MIGRAPHX_INLINE_NS
src/targets/gpu/include/migraphx/gpu/code_object_op.hpp  (+4 −0)

@@ -34,6 +34,10 @@ struct code_object_op
                  f(self.output, "output"));
     }

+    value attributes() const { return {{"group", group()}}; }
+
+    std::string group() const { return "gpu::code_object::" + symbol_name; }
+
     std::string name() const { return "gpu::code_object"; }
     shape compute_shape(std::vector<shape> inputs) const;
     argument
src/targets/gpu/include/migraphx/gpu/compile_gen.hpp  (new file, 0 → 100644, +46 −0)

#ifndef MIGRAPHX_GUARD_GPU_COMPILE_GEN_HPP
#define MIGRAPHX_GUARD_GPU_COMPILE_GEN_HPP

#include <migraphx/config.hpp>
#include <string>
#include <unordered_map>
#include <vector>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

struct shape;

namespace gpu {
namespace gen {

struct vectorize
{
    std::size_t size = 1;
    std::size_t axis = 0;
    static vectorize elements(std::size_t axis, const std::vector<shape>& inputs);
    std::string str() const;
};

struct preload
{
    std::vector<bool> args = {};
    static preload broadcasts(std::size_t axis, const std::vector<shape>& inputs);
    bool is_preloading() const;
    std::string str() const;
};

std::size_t find_fast_axis(const std::vector<shape>& inputs);

std::string make_transformer_args(std::vector<std::string> transformers);

template <class... Ts>
std::string make_transformer_args(Ts... xs)
{
    return make_transformer_args({xs.str()...});
}

} // namespace gen
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif // MIGRAPHX_GUARD_GPU_COMPILE_GEN_HPP
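
A hedged sketch of how these helpers compose in a compiler (the call pattern is inferred from the declarations above; the exact string each str() emits is not shown in this diff):

    auto axis = gen::find_fast_axis(inputs);          // fastest-varying input axis
    auto vec  = gen::vectorize::elements(axis, inputs);
    auto pre  = gen::preload::broadcasts(axis, inputs);
    std::string args = gen::make_transformer_args(vec, pre); // joined transformer list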
src/targets/gpu/include/migraphx/gpu/compile_hip.hpp  (+0 −2)

@@ -17,8 +17,6 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
 std::string enum_params(std::size_t count, std::string param);

-std::size_t compute_global(std::size_t n, std::size_t local = 1024);
-
 } // namespace gpu
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
...
src/targets/gpu/include/migraphx/gpu/compile_hip_code_object.hpp
View file @
11e155c2
...
@@ -8,6 +8,8 @@ namespace migraphx {
...
@@ -8,6 +8,8 @@ namespace migraphx {
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
namespace
gpu
{
struct
context
;
struct
hip_compile_options
struct
hip_compile_options
{
{
std
::
size_t
global
;
std
::
size_t
global
;
...
@@ -17,10 +19,35 @@ struct hip_compile_options
...
@@ -17,10 +19,35 @@ struct hip_compile_options
std
::
string
kernel_name
=
"kernel"
;
std
::
string
kernel_name
=
"kernel"
;
std
::
string
params
=
""
;
std
::
string
params
=
""
;
std
::
vector
<
shape
>
virtual_inputs
=
{};
std
::
vector
<
shape
>
virtual_inputs
=
{};
/**
* @brief Set the launch parameters but allow v to override the values
*
* @param v A value class which can have a "global" and/or "local" keys to override the default
* global and local
* @param compute_global A function used to compute the global based on the local
* @param default_local The defaul local to use if its missing from the v parameter
*/
void
set_launch_params
(
const
value
&
v
,
const
std
::
function
<
std
::
size_t
(
std
::
size_t
local
)
>&
compute_global
,
std
::
size_t
default_local
=
1024
);
void
set_launch_params
(
const
value
&
v
,
std
::
size_t
default_global
,
std
::
size_t
default_local
=
1024
)
{
set_launch_params
(
v
,
[
=
](
auto
)
{
return
default_global
;
},
default_local
);
}
};
};
/// Compute global for n elements, but max out on target-specific upper limit
std
::
function
<
std
::
size_t
(
std
::
size_t
local
)
>
compute_global_for
(
context
&
ctx
,
std
::
size_t
n
,
std
::
size_t
over
=
1
);
operation
compile_hip_code_object
(
const
std
::
string
&
content
,
hip_compile_options
options
);
operation
compile_hip_code_object
(
const
std
::
string
&
content
,
hip_compile_options
options
);
std
::
size_t
compute_block_size
(
std
::
size_t
n
,
std
::
size_t
max_block_size
=
1024
);
}
// namespace gpu
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
}
// namespace migraphx
...
...
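
A sketch of typical usage of these launch-parameter helpers (inferred from the declarations; the 256 default and the out_elements name are illustrative):

    hip_compile_options options;
    // v may carry "global"/"local" overrides from the operation's value;
    // otherwise global is derived from the chosen local by the context-aware,
    // device-capped helper.
    options.set_launch_params(v, compute_global_for(ctx, out_elements), 256);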
src/targets/gpu/include/migraphx/gpu/compile_pointwise.hpp  (deleted, 100644 → 0, +0 −24)

-#ifndef MIGRAPHX_GUARD_GPU_COMPILE_POINTWISE_HPP
-#define MIGRAPHX_GUARD_GPU_COMPILE_POINTWISE_HPP
-
-#include <migraphx/config.hpp>
-#include <migraphx/operation.hpp>
-
-namespace migraphx {
-inline namespace MIGRAPHX_INLINE_NS {
-
-struct module;
-
-namespace gpu {
-
-struct context;
-
-operation compile_pointwise(context& ctx,
-                            const std::vector<shape>& inputs,
-                            const std::string& lambda,
-                            const std::string& preamble = "");
-
-operation compile_pointwise(context& ctx, const std::vector<shape>& inputs, module m);
-
-} // namespace gpu
-} // namespace MIGRAPHX_INLINE_NS
-} // namespace migraphx
-
-#endif // MIGRAPHX_GUARD_GPU_COMPILE_POINTWISE_HPP
src/targets/gpu/include/migraphx/gpu/compiler.hpp  (new file, 0 → 100644, +70 −0)

#ifndef MIGRAPHX_GUARD_GPU_COMPILER_HPP
#define MIGRAPHX_GUARD_GPU_COMPILER_HPP

#include <migraphx/config.hpp>
#include <migraphx/auto_register.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/value.hpp>
#include <migraphx/module.hpp>
#include <migraphx/instruction.hpp>
#include <functional>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {

struct context;

using compiler_replace = std::function<void(module& m, instruction_ref ins)>;
using compiler_compile = std::function<compiler_replace(context&, instruction_ref, operation)>;
using compiler_compile_op =
    std::function<operation(context&, const std::vector<shape>& inputs, const value&)>;

void register_compiler(const std::string& name, compiler_compile c, compiler_compile_op cop);

bool has_compiler_for(const std::string& name);
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op);
operation
compile_op(const std::string& name, context& ctx, const std::vector<shape>& inputs, const value& v);

template <class T>
void register_compiler()
{
    T c;
    for(auto&& name : c.names())
    {
        register_compiler(
            name,
            [=](auto&&... xs) { return c.compile(std::forward<decltype(xs)>(xs)...); },
            [=](auto&&... xs) { return c.compile_op(std::forward<decltype(xs)>(xs)...); });
    }
}

struct register_compiler_action
{
    template <class T>
    static void apply()
    {
        register_compiler<T>();
    }
};

template <class T>
using auto_register_compiler = auto_register<register_compiler_action, T>;

template <class Derived>
struct compiler : auto_register_compiler<Derived>
{
    auto replace(const operation& op) const
    {
        return [=](module& m, instruction_ref ins) {
            m.replace_instruction(ins, op, ins->inputs());
        };
    }

    operation compile_op(context&, const std::vector<shape>&, const value&) const { return {}; }
};

} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif // MIGRAPHX_GUARD_GPU_COMPILER_HPP
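
A hedged sketch of a plugin built on the CRTP helper above (the "gpu::example" name and the trivial body are hypothetical; a real compiler would build a code object rather than return the operation unchanged):

    struct example_compiler : compiler<example_compiler>
    {
        // Operator names this compiler claims in the registry
        std::vector<std::string> names() const { return {"gpu::example"}; }

        operation compile_op(context&, const std::vector<shape>&, const value&) const
        {
            return {}; // placeholder: real compilers return a compiled operation here
        }

        compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
        {
            // replace() yields the callback that swaps the instruction in place
            return replace(compile_op(ctx, to_shapes(ins->inputs()), op.to_value()));
        }
    };
    // Registration happens automatically through auto_register_compiler when
    // this translation unit is linked in.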