gaoqiong / MIGraphX · Commit 9dc0b779 (unverified)

Merge branch 'develop' into jit-concat

Authored Sep 12, 2022 by Paul Fultz II; committed by GitHub on Sep 12, 2022
Parents: 30cde6df, d78bcdfb
Changes: 146 files in this merge; this page shows 20 changed files with 271 additions and 92 deletions (+271 / -92).
src/targets/gpu/device/multinomial.cpp                        +1    -1
src/targets/gpu/driver/compile_op.cpp                         +5    -2
src/targets/gpu/driver/include/migraphx/gpu/driver/perf.hpp   +2    -1
src/targets/gpu/driver/perf.cpp                               +19   -10
src/targets/gpu/driver/run_op.cpp                             +2    -2
src/targets/gpu/fuse_ops.cpp                                  +137  -20
src/targets/gpu/gemm_impl.cpp                                 +25   -8
src/targets/gpu/include/migraphx/gpu/context.hpp              +48   -0
src/targets/gpu/include/migraphx/gpu/gather.hpp               +1    -1
src/targets/gpu/include/migraphx/gpu/gemm.hpp                 +13   -8
src/targets/gpu/include/migraphx/gpu/hip.hpp                  +2    -0
src/targets/gpu/include/migraphx/gpu/int8_conv_pack.hpp       +0    -1
src/targets/gpu/include/migraphx/gpu/int8_gemm_pack.hpp       +0    -1
src/targets/gpu/include/migraphx/gpu/kernel.hpp               +9    -4
src/targets/gpu/include/migraphx/gpu/logsoftmax.hpp           +2    -14
src/targets/gpu/include/migraphx/gpu/lrn.hpp                  +1    -1
src/targets/gpu/include/migraphx/gpu/prefuse_ops.hpp          +1    -1
src/targets/gpu/include/migraphx/gpu/reverse.hpp              +1    -1
src/targets/gpu/include/migraphx/gpu/softmax.hpp              +2    -14
src/targets/gpu/include/migraphx/gpu/sync_device.hpp          +0    -2
src/targets/gpu/device/multinomial.cpp

@@ -47,7 +47,7 @@ constexpr Iterator upper_bound(Iterator first, Iterator last, const T& value)
         it   = first;
         step = count / 2;
         std::advance(it, step);
-        if(!(value < *it))
+        if(not(value < *it))
         {
             first = ++it;
             count -= step + 1;
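The only change swaps `!` for the alternative token `not`, matching the spelling used elsewhere in the codebase. Both are standard C++ and compile identically; a minimal standalone check:

#include <cassert>

int main()
{
    int value = 3;
    int it    = 5; // stands in for *it in the binary search above
    // not/and/or are ISO C++ alternative tokens for !/&&/||
    assert((not(value < it)) == (!(value < it)));
    return 0;
}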
src/targets/gpu/driver/compile_op.cpp

@@ -38,8 +38,11 @@ struct compile_op : action<compile_op>
         context ctx;
         auto inputs = p.parse_shapes(v.at("inputs"));
         auto op     = gpu::compile_op(v.at("name").to<std::string>(), ctx, inputs, v);
-        double t    = time_op(ctx, op, inputs, p.get(v, "iterations", 100));
-        std::cout << op << ": " << t << "ms" << std::endl;
+        auto [host_time, device_time] = time_op(ctx, op, inputs, p.get(v, "iterations", 100));
+        std::cout << op << ": " << host_time << "ms";
+        if(device_time > 0)
+            std::cout << ", " << device_time << "ms";
+        std::cout << std::endl;
     }
 };
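time_op now returns a (host, device) pair instead of a single double. A minimal sketch of how the new interface is consumed, with a hypothetical time_op_stub in place of driver::time_op; the device time is printed only when it was actually measured:

#include <iostream>
#include <utility>

// Hypothetical stand-in for driver::time_op: {host ms, device ms}.
std::pair<double, double> time_op_stub() { return {0.42, 0.37}; }

int main()
{
    auto [host_time, device_time] = time_op_stub();
    std::cout << "op: " << host_time << "ms";
    if(device_time > 0)
        std::cout << ", " << device_time << "ms";
    std::cout << std::endl;
    return 0;
}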
src/targets/gpu/driver/include/migraphx/gpu/driver/perf.hpp

@@ -33,7 +33,8 @@ inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {
 namespace driver {

-double time_op(context& ctx, operation op, const std::vector<shape>& inputs, int n = 100);
+std::pair<double, double>
+time_op(context& ictx, operation op, const std::vector<shape>& inputs, int n = 100);

 } // namespace driver
 } // namespace gpu
src/targets/gpu/driver/perf.cpp

@@ -42,22 +42,31 @@ std::vector<argument> generate_arguments(const std::vector<shape>& shapes, unsig
 }

 using milliseconds = std::chrono::duration<double, std::milli>;
-double time_op(context& ctx, operation op, const std::vector<shape>& inputs, int n)
+std::pair<double, double>
+time_op(context& ictx, operation op, const std::vector<shape>& inputs, int n)
 {
     // TODO: Use std::ref
-    migraphx::context gctx = ctx;
-    auto output            = op.compute_shape(inputs);
-    op.finalize(gctx, output, inputs);
+    migraphx::context ctx = ictx;
+    auto& gctx            = any_cast<migraphx::gpu::context>(ctx);
+    auto output           = op.compute_shape(inputs);
+    op.finalize(ctx, output, inputs);
     auto args = generate_arguments(inputs);
     auto run  = [&] {
-        op.compute(gctx, output, args);
-        gctx.finish();
+        op.compute(ctx, output, args);
+        ctx.finish();
     };
+    gctx.enable_perf_measurement();
     run();
-    auto r   = range(n);
-    double t = std::accumulate(
-        r.begin(), r.end(), double{0.0}, [&](auto x, auto) { return x + time<milliseconds>(run); });
-    return t / n;
+    double host_time   = 0.0;
+    double device_time = 0.0;
+    for(auto i : range(n))
+    {
+        (void)i;
+        host_time += time<milliseconds>(run);
+        device_time += gctx.get_elapsed_ms();
+    }
+    return std::make_pair(host_time / n, device_time / n);
 }

 } // namespace driver
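The accumulate call is replaced by an explicit loop because each iteration now produces two numbers: wall-clock time on the host and, via the context's HIP events, elapsed time on the device. A self-contained sketch of the loop's shape, with a trivial stand-in for the op and a hypothetical fake_device_ms() in place of gctx.get_elapsed_ms():

#include <chrono>
#include <functional>
#include <iostream>
#include <utility>

using milliseconds = std::chrono::duration<double, std::milli>;

// Wall-clock timing of one invocation, as time<milliseconds>(run) does.
template <class Duration, class F>
double time(F f)
{
    auto start = std::chrono::steady_clock::now();
    f();
    return Duration(std::chrono::steady_clock::now() - start).count();
}

// Hypothetical device timer; the real code reads gctx.get_elapsed_ms().
double fake_device_ms() { return 0.0; }

std::pair<double, double> time_op_like(const std::function<void()>& run, int n)
{
    run(); // warm-up, matching the diff
    double host_time   = 0.0;
    double device_time = 0.0;
    for(int i = 0; i < n; i++)
    {
        host_time += time<milliseconds>(run);
        device_time += fake_device_ms();
    }
    return std::make_pair(host_time / n, device_time / n);
}

int main()
{
    auto [host, device] = time_op_like([] {}, 100);
    std::cout << host << "ms, " << device << "ms\n";
    return 0;
}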
src/targets/gpu/driver/run_op.cpp

@@ -43,8 +43,8 @@ struct run_op : action<run_op>
         auto op = make_op(name);
         if(v.contains("fields"))
             op.from_value(v.at("fields"));
-        double t = time_op(ctx, op, inputs, p.get(v, "iterations", 100));
-        std::cout << op << ": " << t << "ms" << std::endl;
+        auto [host_time, device_time] = time_op(ctx, op, inputs, p.get(v, "iterations", 100));
+        std::cout << op << ": " << host_time << "ms" << std::endl;
     }
 };
src/targets/gpu/fuse_ops.cpp

@@ -26,7 +26,6 @@
 #include <migraphx/gpu/fuse_ops.hpp>
 #include <migraphx/matcher.hpp>
 #include <migraphx/gpu/miopen.hpp>
-#include <migraphx/gpu/clip.hpp>
 #include <migraphx/gpu/convolution.hpp>
 #include <migraphx/gpu/device_name.hpp>
 #include <migraphx/gpu/oper.hpp>
@@ -48,8 +47,8 @@
 #include <migraphx/instruction.hpp>
 #include <migraphx/register_op.hpp>
 #include <migraphx/array.hpp>
+#include <migraphx/permutation.hpp>
 #include <migraphx/make_op.hpp>
-#include <migraphx/op/clip.hpp>
 #include <cmath>
 #include <set>
@@ -279,6 +278,11 @@ MIGRAPHX_REGISTER_OP(hip_layernorm)
 struct hip_triadd_layernorm : ternary_device<hip_triadd_layernorm, &device::triadd_layernorm>
 {
+    shape compute_shape(const std::vector<shape>& inputs) const
+    {
+        check_shapes{inputs, *this}.has(4).standard();
+        return inputs[0];
+    }
     // Empty finalize to skip dimension reduction
     void finalize(context&, const shape&, const std::vector<shape>&) {}
 };
@@ -943,28 +947,70 @@ struct find_gemm_add
     }
 };

+auto pointwise_name(const std::string& s)
+{
+    return precompile_name("pointwise")(match::make_basic_pred_matcher([=](auto ins) {
+        module_ref pm = ins->module_inputs().front();
+        auto n = std::count_if(pm->begin(), pm->end(), [&](auto& i) { return i.name() == s; });
+        if(n != 1)
+            return false;
+        return std::all_of(pm->begin(), pm->end(), [&](auto& i) {
+            return starts_with(i.name(), "@") or i.name() == s;
+        });
+    }));
+}
+
 struct find_gemm_pointwise
 {
     auto matcher() const
     {
-        return precompile_name("pointwise")(
+        return pointwise_name("add")(
             match::nargs(3),
-            match::all_of[match::inputs()](match::standard_shape()),
             match::either_arg(0, 1)(
-                match::used_once().bind("c"),
-                match::name("gpu::gemm")(match::nargs(3)).bind("gemm")));
+                match::any_of(match::standard_shape(), match::is_constant()).bind("c"),
+                match::name("gpu::gemm")(match::nargs(3), match::used_once()).bind("gemm")));
     }

+    // TODO: Move to matcher.hpp
+    static auto match_param(const std::string& name)
+    {
+        return match::make_basic_pred_matcher([=](auto ins) {
+            if(ins->name() != "@param")
+                return false;
+            auto p = any_cast<builtin::param>(ins->get_operator());
+            return p.parameter == name;
+        });
+    }
+
+    template <class M>
+    static auto match_mul_const(M m, const std::string& var)
+    {
+        return match::name("mul")(match::either_arg(0, 1)(match::name("@literal").bind(var), m))
+            .bind(var + "_mul");
+    }
+
+    static auto match_add(const std::string& input, const std::string& output)
+    {
+        auto param     = match::name("@param");
+        auto add       = match::name("add")(match::args(param, param));
+        auto inner_mul = match::any_of(match_mul_const(match_param(input), "alpha"),
+                                       match_mul_const(match_param(output), "beta"));
+        auto mul_add   = match::name("add")(match::either_arg(0, 1)(inner_mul, param));
+        auto add_mul   = match_mul_const(add, "gamma");
+        return match::name("@return")(match::args(match::any_of(add, mul_add, add_mul)));
+    }
+
+    static float get_float(instruction_ref ins) { return ins->get_literal().at<float>(); }
+
+    template <class Gemm>
+    static bool update_gemm(Gemm& gemm, module_ref pm, unsigned input)
+    {
+        auto names = pm->get_parameter_names();
+        if(names.size() != 2)
+            return false;
+        std::sort(names.begin(), names.end());
+        unsigned output = input == 0 ? 1 : 0;
+        auto mr         = match::match_instruction(
+            *pm, std::prev(pm->end()), match_add(names[input], names[output]));
+        if(mr.result == pm->end())
+            return false;
+        if(contains(mr.instructions, "alpha_mul"))
+            gemm.alpha *= get_float(mr.instructions["alpha"]);
+        else if(contains(mr.instructions, "beta_mul"))
+            gemm.beta *= get_float(mr.instructions["beta"]);
+        else if(contains(mr.instructions, "gamma_mul"))
+        {
+            gemm.alpha *= get_float(mr.instructions["gamma"]);
+            gemm.beta *= get_float(mr.instructions["gamma"]);
+        }
+        return true;
+    }
+
     void apply(module& m, const match::matcher_result& r) const
@@ -978,6 +1024,19 @@ struct find_gemm_pointwise
         // Already fused gemm
         if(not float_equal(gemm.beta, 0))
             return;
+        gemm.beta = 1;
+        if(not update_gemm(
+               gemm, ins->module_inputs().front(), ins->inputs().front() == gemm_ins ? 0 : 1))
+            return;
+        // const-fold input if not standard shape since rocblas can't handle it
+        if(not c_ins->get_shape().standard())
+        {
+            auto c = make_op("contiguous");
+            auto l = c.compute(c.compute_shape({c_ins->get_shape()}), {c_ins->eval()});
+            c_ins  = m.add_literal(l.get_shape(), l.data());
+        }
         auto inputs = gemm_ins->inputs();
         inputs.pop_back();
@@ -985,11 +1044,68 @@ struct find_gemm_pointwise
         inputs.push_back(c_ins);
         inputs.push_back(ins->inputs().back());
-        gemm.beta = 1;
         m.replace_instruction(ins, gemm, inputs);
     }
 };

+struct find_contiguous_tranpose_gemm
+{
+    auto matcher() const
+    {
+        return match::name("gpu::contiguous")(match::arg(0)(
+            match::name("transpose")(
+                match::arg(0)(match::name("gpu::gemm")(match::used_once()).bind("gemm")))
+                .bind("transpose")));
+    }
+
+    template <class Vector>
+    static bool is_swapped(const Vector& perm, std::size_t i, std::size_t j)
+    {
+        if(i >= perm.size() or j >= perm.size())
+            return false;
+        auto perm2 = perm;
+        std::iota(perm2.begin(), perm2.end(), 0);
+        std::swap(perm2[i], perm2[j]);
+        return perm2 == perm;
+    }
+
+    void apply(module& m, const match::matcher_result& r) const
+    {
+        auto ins       = r.result;
+        auto gemm      = r.instructions["gemm"];
+        auto alloc     = gemm->inputs().back();
+        auto transpose = r.instructions["transpose"];
+        auto perm      = transpose->get_operator().to_value()["permutation"].to_vector<int64_t>();
+        auto iperm     = invert_permutation(perm);
+
+        if(perm.size() < 3)
+            return;
+
+        if(not is_swapped(perm, perm.size() - 3, perm.size() - 2))
+            return;
+
+        auto lens = gemm->get_shape().lens();
+        if(lens.size() > 3 and
+           not std::all_of(lens.begin(), lens.end() - 3, [](auto i) { return i == 1; }))
+            return;
+
+        auto gemmv           = gemm->get_operator().to_value();
+        gemmv["trans_batch"] = 1;
+
+        auto s = shape{alloc->get_shape().type(), reorder_dims(alloc->get_shape().lens(), iperm)};
+        auto new_alloc = m.insert_instruction(gemm, make_op("allocate", {{"shape", to_value(s)}}));
+
+        auto alloc_transpose =
+            m.insert_instruction(gemm, make_op("transpose", {{"permutation", perm}}), new_alloc);
+
+        auto inputs   = gemm->inputs();
+        inputs.back() = alloc_transpose;
+
+        auto new_gemm = m.insert_instruction(gemm, make_op("gpu::gemm", gemmv), inputs);
+
+        auto gemm_transpoe = m.insert_instruction(gemm, transpose->get_operator(), new_gemm);
+
+        m.replace_instruction(ins, gemm_transpoe);
+    }
+};
+
 struct find_commutative_broadcast
 {
     auto matcher() const
@@ -1091,6 +1207,7 @@ void fuse_ops::apply(module& m) const
                         find_gemm_add{},
                         find_layernorm_pointwise{},
                         find_gemm_pointwise{},
+                        find_contiguous_tranpose_gemm{},
                         find_commutative_broadcast{});
     match::find_matches(m, find_contiguous{});
 }
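The new find_contiguous_tranpose_gemm matcher only fires when the transpose permutation is the identity with exactly the two dimensions preceding the innermost one exchanged, which is what is_swapped tests. The helper copied out of the struct and exercised standalone:

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <numeric>
#include <vector>

template <class Vector>
bool is_swapped(const Vector& perm, std::size_t i, std::size_t j)
{
    if(i >= perm.size() or j >= perm.size())
        return false;
    auto perm2 = perm;
    std::iota(perm2.begin(), perm2.end(), 0);
    std::swap(perm2[i], perm2[j]);
    return perm2 == perm;
}

int main()
{
    std::vector<int> p = {0, 2, 1, 3};
    assert(is_swapped(p, 1, 2));     // identity with positions 1 and 2 exchanged
    assert(not is_swapped(p, 2, 3)); // not a swap of positions 2 and 3
    return 0;
}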
src/targets/gpu/gemm_impl.cpp

@@ -24,6 +24,7 @@
 #include <rocblas.h>
 #include <migraphx/gpu/gemm_impl.hpp>
 #include <migraphx/reduce_dims.hpp>
+#include <migraphx/permutation.hpp>

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
@@ -67,6 +68,19 @@ void blas_shape(const shape& s)
         MIGRAPHX_THROW("GPU_GEMM: Batch dimension is not collapsible");
 }

+shape transpose_batch(const shape& s, unsigned trans_batch)
+{
+    if(trans_batch == 0)
+        return s;
+    if(s.lens().size() < 3)
+        return s;
+    auto batch = s.lens().size() - 3;
+    std::vector<int64_t> perm(s.lens().size());
+    std::iota(perm.begin(), perm.end(), 0);
+    std::swap(perm[batch], perm[batch + trans_batch]);
+    return shape::from_permutation(s.type(), s.lens(), perm);
+}
+
 template <class R, class... Ts, class... Us>
 R rocblas_invoke(R (*f)(Ts...), Us... xs)
 {
@@ -97,6 +111,12 @@ void gemm_impl(context& ctx,
                bool int8_x4_format,
                bool compute_fp32)
 {
+    const bool is_3inputs = (args.size() == 4);
+    if(not is_3inputs)
+    {
+        beta = 0;
+    }
+
     bool transa = is_transposed(args[0].get_shape());
     bool transb = is_transposed(args[1].get_shape());
     auto n_dim  = output_shape.lens().size();
@@ -105,12 +125,8 @@ void gemm_impl(context& ctx,
     rocblas_int lda = args[0].get_shape().strides()[transa ? dim_1 : dim_0];
     rocblas_int ldb = args[1].get_shape().strides()[transb ? dim_1 : dim_0];
     rocblas_int ldc = args[2].get_shape().strides()[dim_0];
-
-    bool is_3inputs = (args.size() == 4);
-    if(!is_3inputs)
-    {
-        beta = 0;
-    }
+    rocblas_int ldd = is_3inputs ? args[3].get_shape().strides()[dim_0] : ldc;

     rocblas_datatype arg_type = get_type(args[0].get_shape().type());
     auto output_type          = arg_type;
     if(output_type == rocblas_datatype_i8_r)
@@ -186,7 +202,7 @@ void gemm_impl(context& ctx,
                        ldc,
                        is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]),
                        output_type,
-                       ldc,
+                       ldd,
                        compute_type,
                        rocblas_gemm_algo_standard,
                        0,
@@ -197,6 +213,7 @@ void gemm_impl(context& ctx,
         auto a_stride = get_batch_stride(args[0]);
         auto b_stride = get_batch_stride(args[1]);
         auto c_stride = get_batch_stride(args[2]);
+        auto d_stride = is_3inputs ? get_batch_stride(args[3]) : c_stride;
         rocblas_invoke(&rocblas_gemm_strided_batched_ex,
                        ctx.get_stream().get_rocblas(),
                        transb ? rocblas_operation_transpose : rocblas_operation_none,
@@ -220,8 +237,8 @@ void gemm_impl(context& ctx,
                        c_stride,
                        is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]),
                        output_type,
-                       ldc,
-                       c_stride,
+                       ldd,
+                       d_stride,
                        num_matrices,
                        compute_type,
                        rocblas_gemm_algo_standard,
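transpose_batch builds an identity permutation, swaps the dimension at rank - 3 with the one trans_batch positions after it, and lets shape::from_permutation produce a shape whose strides encode the reordering; the companion ldd/d_stride changes stop the D matrix from reusing C's leading dimension and batch stride. The permutation arithmetic checked standalone with illustrative lens values:

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <numeric>
#include <vector>

int main()
{
    std::vector<int64_t> lens = {2, 3, 4, 5}; // rank 4, so batch index = 4 - 3 = 1
    unsigned trans_batch      = 1;
    auto batch                = lens.size() - 3;
    std::vector<int64_t> perm(lens.size());
    std::iota(perm.begin(), perm.end(), 0);
    std::swap(perm[batch], perm[batch + trans_batch]);
    for(auto i : perm)
        std::cout << i << ' '; // prints: 0 2 1 3
    std::cout << '\n';
    return 0;
}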
src/targets/gpu/include/migraphx/gpu/context.hpp

@@ -244,6 +244,15 @@ struct context
         return hip_event_ptr{event};
     }

+    static hip_event_ptr create_event_for_timing()
+    {
+        hipEvent_t event;
+        auto status = hipEventCreate(&event);
+        if(status != hipSuccess)
+            MIGRAPHX_THROW("Failed to create event");
+        return hip_event_ptr{event};
+    }
+
     value to_value() const
     {
         value result;
@@ -267,10 +276,49 @@ struct context
     any_ptr get_queue() { return get_stream().get(); }

+    void enable_perf_measurement(bool b = true)
+    {
+        if(b)
+        {
+            start_event = create_event_for_timing();
+            stop_event  = create_event_for_timing();
+            get_stream().record(start_event.get());
+            get_stream().record(stop_event.get());
+        }
+        else
+        {
+            start_event = nullptr;
+            stop_event  = nullptr;
+        }
+        measure_perf = b;
+    }
+
+    std::pair<hipEvent_t, hipEvent_t> get_perf_events() const
+    {
+        if(measure_perf)
+            return std::make_pair(start_event.get(), stop_event.get());
+        return std::make_pair(nullptr, nullptr);
+    }
+
+    float get_elapsed_ms() const
+    {
+        float result = 0;
+        if(start_event != nullptr and stop_event != nullptr)
+        {
+            auto status = hipEventElapsedTime(&result, start_event.get(), stop_event.get());
+            if(status != hipSuccess)
+                MIGRAPHX_THROW("Failed hipEventElapsedTime: " + hip_error(status));
+        }
+        return result;
+    }
+
     private:
     // TODO: Make this a vector to support multiple devices
     std::shared_ptr<hip_device> current_device;
     std::vector<shared<hip_event_ptr>> events;
+    bool measure_perf                 = false;
+    shared<hip_event_ptr> start_event = nullptr;
+    shared<hip_event_ptr> stop_event  = nullptr;
 };

 inline void migraphx_to_value(value& v, const context& ctx) { v = ctx.to_value(); }
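The new members wrap the standard HIP event-timing pattern: create two events, record them around work on a stream, then read back the elapsed milliseconds. A bare-HIP sketch of that pattern, independent of the context class (error checking elided for brevity):

#include <hip/hip_runtime.h>
#include <iostream>

int main()
{
    hipStream_t stream;
    hipStreamCreate(&stream);
    hipEvent_t start;
    hipEvent_t stop;
    hipEventCreate(&start);
    hipEventCreate(&stop);

    hipEventRecord(start, stream);
    // ... enqueue kernels on the stream here ...
    hipEventRecord(stop, stream);
    hipEventSynchronize(stop);

    float ms = 0;
    hipEventElapsedTime(&ms, start, stop);
    std::cout << ms << "ms\n";

    hipEventDestroy(start);
    hipEventDestroy(stop);
    hipStreamDestroy(stream);
    return 0;
}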
src/targets/gpu/include/migraphx/gpu/gather.hpp

@@ -27,7 +27,7 @@
 #include <migraphx/argument.hpp>
 #include <migraphx/reflect.hpp>
 #include <migraphx/op/gather.hpp>
-#include <migraphx/gpu/miopen.hpp>
+#include <migraphx/gpu/context.hpp>

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
src/targets/gpu/include/migraphx/gpu/gemm.hpp

@@ -42,15 +42,17 @@ namespace gpu {
 struct context;

 void blas_shape(const shape& s);
+shape transpose_batch(const shape& s, unsigned trans_batch);

 template <class Op>
 struct rocblas_gemm
 {
     Op op;
     float alpha          = 1;
     float beta           = 0;
     bool int8_x4_format  = true;
     bool compute_fp32    = false;
+    unsigned trans_batch = 0;

     template <class Self, class F>
     static auto reflect(Self& self, F f)
@@ -58,7 +60,9 @@ struct rocblas_gemm
         return pack_join(migraphx::reflect(self.op, f),
                          pack(f(self.alpha, "alpha"),
                               f(self.beta, "beta"),
-                              f(self.int8_x4_format, "int8_x4_format")));
+                              f(self.int8_x4_format, "int8_x4_format"),
+                              f(self.compute_fp32, "compute_fp32"),
+                              f(self.trans_batch, "trans_batch")));
     }

     std::string name() const
@@ -74,13 +78,14 @@ struct rocblas_gemm
     {
         std::vector<shape> in_shapes(inputs);
         in_shapes.pop_back();
-        check_shapes{in_shapes, *this}.not_broadcasted();
+        check_shapes{in_shapes, *this}.has(2, 3);
         blas_shape(inputs[0]);
         blas_shape(inputs[1]);
         // if gemm and add are fused
         if(in_shapes.size() > 2)
         {
             auto cmat_shape = in_shapes.back();
+            check_shapes{{cmat_shape}, *this}.not_transposed().not_broadcasted();
             in_shapes.pop_back();
             blas_shape(cmat_shape);
             auto op_out_shape = op.compute_shape(in_shapes);
@@ -97,10 +102,10 @@ struct rocblas_gemm
                               to_string(cmat_shape.type()) +
                               ", it must be: " + to_string(op_out_shape.type()));
             }
-            return op_out_shape;
+            return transpose_batch(op_out_shape, trans_batch);
         }

-        return op.compute_shape(in_shapes);
+        return transpose_batch(op.compute_shape(in_shapes), trans_batch);
     }

     argument
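Besides the new trans_batch field, note that reflect previously omitted compute_fp32, so both it and trans_batch would have been lost when the op was converted to and from a value; serialization sees only what reflect enumerates. A minimal sketch of the idiom with a hypothetical gemm_params struct and a plain callback in place of pack/pack_join:

#include <iostream>
#include <string>

struct gemm_params // hypothetical stand-in for rocblas_gemm's reflected fields
{
    float alpha          = 1;
    float beta           = 0;
    bool int8_x4_format  = true;
    bool compute_fp32    = false;
    unsigned trans_batch = 0;

    template <class Self, class F>
    static void reflect_each(Self& self, F f)
    {
        f(self.alpha, "alpha");
        f(self.beta, "beta");
        f(self.int8_x4_format, "int8_x4_format");
        f(self.compute_fp32, "compute_fp32");
        f(self.trans_batch, "trans_batch");
    }
};

int main()
{
    gemm_params p;
    gemm_params::reflect_each(p, [](const auto& x, const std::string& name) {
        std::cout << name << " = " << x << '\n';
    });
    return 0;
}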
src/targets/gpu/include/migraphx/gpu/hip.hpp

@@ -37,6 +37,8 @@ namespace gpu {
 struct context;

+std::string hip_error(int error);
+
 argument allocate_gpu(const shape& s, bool host = false);
 argument register_on_gpu(const argument& arg);
src/targets/gpu/include/migraphx/gpu/int8_conv_pack.hpp

@@ -25,7 +25,6 @@
 #define MIGRAPHX_GUARD_RTGLIB_INT8_CONV_PACK_HPP

 #include <migraphx/argument.hpp>
-#include <migraphx/op/quant_dot.hpp>
 #include <migraphx/config.hpp>
 #include <utility>
src/targets/gpu/include/migraphx/gpu/int8_gemm_pack.hpp

@@ -25,7 +25,6 @@
 #define MIGRAPHX_GUARD_RTGLIB_INT8_GEMM_PACK_HPP

 #include <migraphx/argument.hpp>
-#include <migraphx/op/quant_dot.hpp>
 #include <migraphx/config.hpp>
 #include <utility>
src/targets/gpu/include/migraphx/gpu/kernel.hpp

@@ -50,17 +50,22 @@ struct kernel
     void launch(hipStream_t stream,
                 std::size_t global,
                 std::size_t local,
-                const std::vector<kernel_argument>& args) const;
+                const std::vector<kernel_argument>& args,
+                hipEvent_t start = nullptr,
+                hipEvent_t stop  = nullptr) const;

     void launch(hipStream_t stream,
                 std::size_t global,
                 std::size_t local,
-                std::vector<void*> args) const;
+                std::vector<void*> args,
+                hipEvent_t start = nullptr,
+                hipEvent_t stop  = nullptr) const;

-    auto launch(hipStream_t stream, std::size_t global, std::size_t local) const
+    template <class... Ts>
+    auto launch(hipStream_t stream, std::size_t global, std::size_t local, Ts... zs) const
     {
         return [=](auto&&... xs) {
-            launch(stream, global, local, std::vector<kernel_argument>{xs...});
+            launch(stream, global, local, std::vector<kernel_argument>{xs...}, zs...);
         };
     }
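Making the convenience overload variadic lets callers append optional timing events that the returned lambda forwards to the real launch, whose new parameters default to nullptr. A simplified standalone model of the forwarding, with a hypothetical kernel_stub and const char* labels standing in for hipEvent_t:

#include <iostream>
#include <vector>

struct kernel_stub // hypothetical stand-in for migraphx::gpu::kernel
{
    void launch_impl(const std::vector<int>& args,
                     const char* start = nullptr,
                     const char* stop  = nullptr) const
    {
        std::cout << args.size() << " args, start=" << (start ? start : "none")
                  << ", stop=" << (stop ? stop : "none") << '\n';
    }

    template <class... Ts>
    auto launch(Ts... zs) const
    {
        // zs... (the optional events) ride along after the packed arguments
        return [=](auto&&... xs) { launch_impl(std::vector<int>{xs...}, zs...); };
    }
};

int main()
{
    kernel_stub k;
    k.launch()(1, 2, 3);             // no events
    k.launch("ev0", "ev1")(4, 5, 6); // events forwarded through the lambda
    return 0;
}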
src/targets/gpu/include/migraphx/gpu/logsoftmax.hpp

@@ -24,22 +24,10 @@
 #ifndef MIGRAPHX_GUARD_RTGLIB_LOGSOFTMAX_HPP
 #define MIGRAPHX_GUARD_RTGLIB_LOGSOFTMAX_HPP

-#include <migraphx/gpu/lowering.hpp>
-#include <migraphx/manage_ptr.hpp>
-#include <migraphx/instruction.hpp>
 #include <migraphx/op/logsoftmax.hpp>
-#include <migraphx/generate.hpp>
-#include <migraphx/shape_for_each.hpp>
-#include <migraphx/config.hpp>
-#include <migraphx/gpu/miopen.hpp>
-#include <migraphx/gpu/hip.hpp>
-#include <migraphx/dfor.hpp>
-#include <migraphx/gpu/device/contiguous.hpp>
-#include <migraphx/gpu/device/add.hpp>
-#include <migraphx/iterator_for.hpp>
-#include <migraphx/gpu/rocblas.hpp>
+#include <migraphx/shape.hpp>
+#include <migraphx/reflect.hpp>
 #include <migraphx/gpu/context.hpp>
-#include <utility>

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
src/targets/gpu/include/migraphx/gpu/lrn.hpp

@@ -26,7 +26,7 @@
 #include <migraphx/shape.hpp>
 #include <migraphx/reflect.hpp>
-#include <migraphx/gpu/miopen.hpp>
+#include <migraphx/gpu/context.hpp>

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
src/targets/gpu/include/migraphx/gpu/prefuse_ops.hpp

@@ -25,7 +25,7 @@
 #define MIGRAPHX_GUARD_GPU_PREFUSE_OPS_HPP

 #include <migraphx/config.hpp>
-#include <migraphx/gpu/context.hpp>
+#include <string>

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
src/targets/gpu/include/migraphx/gpu/reverse.hpp

@@ -27,7 +27,7 @@
 #include <migraphx/argument.hpp>
 #include <migraphx/reflect.hpp>
 #include <migraphx/op/reverse.hpp>
-#include <migraphx/gpu/miopen.hpp>
+#include <migraphx/gpu/context.hpp>

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
src/targets/gpu/include/migraphx/gpu/softmax.hpp

@@ -24,22 +24,10 @@
 #ifndef MIGRAPHX_GUARD_RTGLIB_SOFTMAX_HPP
 #define MIGRAPHX_GUARD_RTGLIB_SOFTMAX_HPP

-#include <migraphx/gpu/lowering.hpp>
-#include <migraphx/manage_ptr.hpp>
-#include <migraphx/instruction.hpp>
 #include <migraphx/op/softmax.hpp>
-#include <migraphx/generate.hpp>
-#include <migraphx/shape_for_each.hpp>
-#include <migraphx/config.hpp>
-#include <migraphx/gpu/miopen.hpp>
-#include <migraphx/gpu/hip.hpp>
-#include <migraphx/dfor.hpp>
-#include <migraphx/gpu/device/contiguous.hpp>
-#include <migraphx/gpu/device/add.hpp>
-#include <migraphx/iterator_for.hpp>
-#include <migraphx/gpu/rocblas.hpp>
+#include <migraphx/shape.hpp>
+#include <migraphx/reflect.hpp>
 #include <migraphx/gpu/context.hpp>
-#include <utility>

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
src/targets/gpu/include/migraphx/gpu/sync_device.hpp

@@ -25,8 +25,6 @@
 #define MIGRAPHX_GUARD_RTGLIB_GPU_SYNC_DEVICE_HPP

 #include <string>
-#include <migraphx/instruction_ref.hpp>
-#include <migraphx/gpu/context.hpp>
 #include <migraphx/config.hpp>

 namespace migraphx {