gaoqiong / MIGraphX · Commits · 37ddce62

Unverified commit 37ddce62, authored Sep 06, 2022 by kahmed10, committed by GitHub on Sep 06, 2022.

    Merge branch 'develop' into jit-layernorm-merge

Parents: d705e483, d37a4df9

Changes: 119 files in the full commit; this page shows 20 changed files with 305 additions and 73 deletions (+305 −73).
Files shown on this page:

    src/targets/gpu/driver/perf.cpp (+19 −10)
    src/targets/gpu/driver/run_op.cpp (+2 −2)
    src/targets/gpu/fuse_ops.cpp (+138 −18)
    src/targets/gpu/gemm_impl.cpp (+25 −8)
    src/targets/gpu/include/migraphx/gpu/context.hpp (+48 −0)
    src/targets/gpu/include/migraphx/gpu/gemm.hpp (+13 −8)
    src/targets/gpu/include/migraphx/gpu/hip.hpp (+2 −0)
    src/targets/gpu/include/migraphx/gpu/kernel.hpp (+9 −4)
    src/targets/gpu/kernel.cpp (+30 −7)
    src/targets/gpu/kernels/include/migraphx/kernels/algorithm.hpp (+1 −1)
    src/targets/gpu/kernels/include/migraphx/kernels/array.hpp (+1 −1)
    src/targets/gpu/kernels/include/migraphx/kernels/integral_constant.hpp (+3 −3)
    src/targets/gpu/lowering.cpp (+1 −1)
    src/targets/gpu/mlir.cpp (+2 −2)
    src/targets/gpu/pack_int8_args.cpp (+1 −1)
    src/targets/gpu/target.cpp (+5 −2)
    src/tf/parse_conv.cpp (+1 −1)
    src/tf/parse_depthwiseconv.cpp (+1 −1)
    src/tf/parse_pooling.cpp (+1 −1)
    src/tf/tf_parser.cpp (+2 −2)
src/targets/gpu/driver/perf.cpp

```diff
@@ -42,22 +42,31 @@ std::vector<argument> generate_arguments(const std::vector<shape>& shapes, unsig
 }
 
 using milliseconds = std::chrono::duration<double, std::milli>;
-double time_op(context& ctx, operation op, const std::vector<shape>& inputs, int n)
+std::pair<double, double>
+time_op(context& ictx, operation op, const std::vector<shape>& inputs, int n)
 {
     // TODO: Use std::ref
-    migraphx::context gctx = ctx;
-    auto output            = op.compute_shape(inputs);
-    op.finalize(gctx, output, inputs);
+    migraphx::context ctx = ictx;
+    auto& gctx            = any_cast<migraphx::gpu::context>(ctx);
+    auto output           = op.compute_shape(inputs);
+    op.finalize(ctx, output, inputs);
     auto args = generate_arguments(inputs);
     auto run  = [&] {
-        op.compute(gctx, output, args);
-        gctx.finish();
+        op.compute(ctx, output, args);
+        ctx.finish();
     };
+    gctx.enable_perf_measurement();
     run();
-    auto r   = range(n);
-    double t = std::accumulate(r.begin(), r.end(), double{0.0}, [&](auto x, auto) {
-        return x + time<milliseconds>(run);
-    });
-    return t / n;
+    double host_time   = 0.0;
+    double device_time = 0.0;
+    for(auto i : range(n))
+    {
+        (void)i;
+        host_time += time<milliseconds>(run);
+        device_time += gctx.get_elapsed_ms();
+    }
+    return std::make_pair(host_time / n, device_time / n);
 }
 
 } // namespace driver
```
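The reworked `time_op` reports host (wall-clock) and device (HIP-event) time separately. A hedged sketch of a caller of the new API; the op name, shapes, and iteration count below are placeholders, not from the commit:

```cpp
// Hypothetical caller: "gpu::add" and the input shapes are illustrative only.
migraphx::context ctx               = migraphx::gpu::context{};
std::vector<migraphx::shape> inputs = {/* input shapes for the op */};
auto [host_ms, device_ms] = time_op(ctx, migraphx::make_op("gpu::add"), inputs, /*n=*/100);
std::cout << "host: " << host_ms << "ms, device: " << device_ms << "ms\n";
```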
src/targets/gpu/driver/run_op.cpp

```diff
@@ -43,8 +43,8 @@ struct run_op : action<run_op>
         auto op = make_op(name);
         if(v.contains("fields"))
             op.from_value(v.at("fields"));
-        double t = time_op(ctx, op, inputs, p.get(v, "iterations", 100));
-        std::cout << op << ": " << t << "ms" << std::endl;
+        auto [host_time, device_time] = time_op(ctx, op, inputs, p.get(v, "iterations", 100));
+        std::cout << op << ": " << host_time << "ms" << std::endl;
     }
 };
```
src/targets/gpu/fuse_ops.cpp

```diff
@@ -48,8 +48,10 @@
 #include <migraphx/instruction.hpp>
 #include <migraphx/register_op.hpp>
 #include <migraphx/array.hpp>
+#include <migraphx/permutation.hpp>
 #include <migraphx/make_op.hpp>
 #include <migraphx/op/clip.hpp>
+#include <migraphx/op/contiguous.hpp>
 #include <cmath>
 #include <set>
```
@@ -279,6 +281,11 @@ MIGRAPHX_REGISTER_OP(hip_layernorm) — resulting code:

```cpp
struct hip_triadd_layernorm : ternary_device<hip_triadd_layernorm, &device::triadd_layernorm>
{
    shape compute_shape(const std::vector<shape>& inputs) const
    {
        check_shapes{inputs, *this}.has(4).standard();
        return inputs[0];
    }

    // Empty finalize to skip dimension reduction
    void finalize(context&, const shape&, const std::vector<shape>&) {}
};
```
```diff
@@ -943,28 +950,70 @@ struct find_gemm_add
     }
 };
 
 auto pointwise_name(const std::string& s)
 {
     return precompile_name("pointwise")(match::make_basic_pred_matcher([=](auto ins) {
         module_ref pm = ins->module_inputs().front();
         auto n = std::count_if(pm->begin(), pm->end(), [&](auto& i) { return i.name() == s; });
         if(n != 1)
             return false;
         return std::all_of(pm->begin(), pm->end(), [&](auto& i) {
             return starts_with(i.name(), "@") or i.name() == s;
         });
     }));
 }
 
 struct find_gemm_pointwise
 {
     auto matcher() const
     {
-        return pointwise_name("add")(
+        return precompile_name("pointwise")(
             match::nargs(3),
-            match::all_of[match::inputs()](match::standard_shape()),
-            match::either_arg(0, 1)(match::used_once().bind("c"),
-                                    match::name("gpu::gemm")(match::nargs(3)).bind("gemm")));
+            match::either_arg(0, 1)(
+                match::any_of(match::standard_shape(), match::is_constant()).bind("c"),
+                match::name("gpu::gemm")(match::nargs(3), match::used_once()).bind("gemm")));
     }
 
+    // TODO: Move to matcher.hpp
+    static auto match_param(const std::string& name)
+    {
+        return match::make_basic_pred_matcher([=](auto ins) {
+            if(ins->name() != "@param")
+                return false;
+            auto p = any_cast<builtin::param>(ins->get_operator());
+            return p.parameter == name;
+        });
+    }
+
+    template <class M>
+    static auto match_mul_const(M m, const std::string& var)
+    {
+        return match::name("mul")(match::either_arg(0, 1)(match::name("@literal").bind(var), m))
+            .bind(var + "_mul");
+    }
+
+    static auto match_add(const std::string& input, const std::string& output)
+    {
+        auto param     = match::name("@param");
+        auto add       = match::name("add")(match::args(param, param));
+        auto inner_mul = match::any_of(match_mul_const(match_param(input), "alpha"),
+                                       match_mul_const(match_param(output), "beta"));
+        auto mul_add   = match::name("add")(match::either_arg(0, 1)(inner_mul, param));
+        auto add_mul   = match_mul_const(add, "gamma");
+        return match::name("@return")(match::args(match::any_of(add, mul_add, add_mul)));
+    }
+
+    static float get_float(instruction_ref ins) { return ins->get_literal().at<float>(); }
+
+    template <class Gemm>
+    static bool update_gemm(Gemm& gemm, module_ref pm, unsigned input)
+    {
+        auto names = pm->get_parameter_names();
+        if(names.size() != 2)
+            return false;
+        std::sort(names.begin(), names.end());
+        unsigned output = input == 0 ? 1 : 0;
+        auto mr         = match::match_instruction(
+            *pm, std::prev(pm->end()), match_add(names[input], names[output]));
+        if(mr.result == pm->end())
+            return false;
+        if(contains(mr.instructions, "alpha_mul"))
+            gemm.alpha *= get_float(mr.instructions["alpha"]);
+        else if(contains(mr.instructions, "beta_mul"))
+            gemm.beta *= get_float(mr.instructions["beta"]);
+        else if(contains(mr.instructions, "gamma_mul"))
+        {
+            gemm.alpha *= get_float(mr.instructions["gamma"]);
+            gemm.beta *= get_float(mr.instructions["gamma"]);
+        }
+        return true;
+    }
+
     void apply(module& m, const match::matcher_result& r) const
```
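For orientation: `match_add` accepts a two-parameter pointwise module whose return is `x + y`, `alpha*x + y` (or `x + beta*y`), or `gamma*(x + y)`, and `update_gemm` folds the matched literal into the gemm's alpha/beta. A hedged sketch (not from the commit) of a module shape it would accept, using the public module-building API with placeholder names:

```cpp
// Illustrative pointwise module that match_add("x0", "x1") would match:
//   @return(add(mul(@literal 2.0f, x0), x1))  ==  2.0f * gemm_output + c
migraphx::module pm("pointwise");
auto x0  = pm.add_parameter("x0", migraphx::shape{migraphx::shape::float_type});
auto x1  = pm.add_parameter("x1", migraphx::shape{migraphx::shape::float_type});
auto a   = pm.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type}, {2.0f}});
auto mul = pm.add_instruction(migraphx::make_op("mul"), a, x0);
auto sum = pm.add_instruction(migraphx::make_op("add"), mul, x1);
pm.add_return({sum});
// update_gemm would multiply gemm.alpha by 2.0f and drop the mul from the fusion.
```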
```diff
@@ -978,6 +1027,19 @@ struct find_gemm_pointwise
         // Already fused gemm
         if(not float_equal(gemm.beta, 0))
             return;
-        gemm.beta = 1;
+        if(not update_gemm(
+               gemm, ins->module_inputs().front(), ins->inputs().front() == gemm_ins ? 0 : 1))
+            return;
+        // const-fold input if not standard shape since rocblas can't handle it
+        if(not c_ins->get_shape().standard())
+        {
+            auto c = op::contiguous{};
+            auto l = c.compute(c.compute_shape({c_ins->get_shape()}), {c_ins->eval()});
+            c_ins  = m.add_literal(l.get_shape(), l.data());
+        }
         auto inputs = gemm_ins->inputs();
         inputs.pop_back();
@@ -985,11 +1047,68 @@ struct find_gemm_pointwise
         inputs.push_back(c_ins);
         inputs.push_back(ins->inputs().back());
+        gemm.beta = 1;
         m.replace_instruction(ins, gemm, inputs);
     }
 };
 
+struct find_contiguous_tranpose_gemm
+{
+    auto matcher() const
+    {
+        return match::name("gpu::contiguous")(match::arg(0)(
+            match::name("transpose")(
+                match::arg(0)(match::name("gpu::gemm")(match::used_once()).bind("gemm")))
+                .bind("transpose")));
+    }
+
+    template <class Vector>
+    static bool is_swapped(const Vector& perm, std::size_t i, std::size_t j)
+    {
+        if(i >= perm.size() or j >= perm.size())
+            return false;
+        auto perm2 = perm;
+        std::iota(perm2.begin(), perm2.end(), 0);
+        std::swap(perm2[i], perm2[j]);
+        return perm2 == perm;
+    }
+
+    void apply(module& m, const match::matcher_result& r) const
+    {
+        auto ins       = r.result;
+        auto gemm      = r.instructions["gemm"];
+        auto alloc     = gemm->inputs().back();
+        auto transpose = r.instructions["transpose"];
+        auto perm      = transpose->get_operator().to_value()["permutation"].to_vector<int64_t>();
+        auto iperm     = invert_permutation(perm);
+
+        if(perm.size() < 3)
+            return;
+
+        if(not is_swapped(perm, perm.size() - 3, perm.size() - 2))
+            return;
+
+        auto lens = gemm->get_shape().lens();
+        if(lens.size() > 3 and
+           not std::all_of(lens.begin(), lens.end() - 3, [](auto i) { return i == 1; }))
+            return;
+
+        auto gemmv           = gemm->get_operator().to_value();
+        gemmv["trans_batch"] = 1;
+
+        auto s = shape{alloc->get_shape().type(), reorder_dims(alloc->get_shape().lens(), iperm)};
+        auto new_alloc = m.insert_instruction(gemm, make_op("allocate", {{"shape", to_value(s)}}));
+
+        auto alloc_transpose =
+            m.insert_instruction(gemm, make_op("transpose", {{"permutation", perm}}), new_alloc);
+
+        auto inputs   = gemm->inputs();
+        inputs.back() = alloc_transpose;
+        auto new_gemm = m.insert_instruction(gemm, make_op("gpu::gemm", gemmv), inputs);
+
+        auto gemm_transpoe = m.insert_instruction(gemm, transpose->get_operator(), new_gemm);
+
+        m.replace_instruction(ins, gemm_transpoe);
+    }
+};
+
 struct find_commutative_broadcast
 {
     auto matcher() const
```
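The `is_swapped` helper above checks that a permutation is exactly the identity with positions i and j exchanged. A self-contained illustration of the check (not part of the commit):

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <numeric>
#include <vector>

int main()
{
    // perm = {1, 0, 2}: identity with positions 0 and 1 exchanged, i.e. the
    // batch and M dimensions of a 3-d gemm output were transposed.
    std::vector<int64_t> perm{1, 0, 2};
    std::vector<int64_t> perm2(perm.size());
    std::iota(perm2.begin(), perm2.end(), 0); // {0, 1, 2}
    std::swap(perm2[0], perm2[1]);            // {1, 0, 2}
    assert(perm2 == perm);                    // so is_swapped(perm, 0, 1) holds
}
```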
```diff
@@ -1091,6 +1210,7 @@ void fuse_ops::apply(module& m) const
                         find_gemm_add{},
                         find_layernorm_pointwise{},
                         find_gemm_pointwise{},
+                        find_contiguous_tranpose_gemm{},
                         find_commutative_broadcast{});
     match::find_matches(m, find_contiguous{});
 }
```
src/targets/gpu/gemm_impl.cpp

```diff
@@ -24,6 +24,7 @@
 #include <rocblas.h>
 #include <migraphx/gpu/gemm_impl.hpp>
 #include <migraphx/reduce_dims.hpp>
+#include <migraphx/permutation.hpp>
 
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
```
```diff
@@ -67,6 +68,19 @@ void blas_shape(const shape& s)
         MIGRAPHX_THROW("GPU_GEMM: Batch dimension is not collapsible");
 }
 
+shape transpose_batch(const shape& s, unsigned trans_batch)
+{
+    if(trans_batch == 0)
+        return s;
+    if(s.lens().size() < 3)
+        return s;
+    auto batch = s.lens().size() - 3;
+    std::vector<int64_t> perm(s.lens().size());
+    std::iota(perm.begin(), perm.end(), 0);
+    std::swap(perm[batch], perm[batch + trans_batch]);
+    return shape::from_permutation(s.type(), s.lens(), perm);
+}
+
 template <class R, class... Ts, class... Us>
 R rocblas_invoke(R (*f)(Ts...), Us... xs)
 {
```
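For a 3-d output {B, M, N} with trans_batch = 1, the permutation becomes {1, 0, 2}, so `shape::from_permutation` keeps the lens but lays out strides as if the batch and M dimensions were swapped. A standalone sketch of just the permutation arithmetic (illustrative, not from the commit):

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <numeric>
#include <vector>

int main()
{
    std::size_t ndim     = 3; // e.g. lens = {B, M, N}
    unsigned trans_batch = 1;
    std::size_t batch    = ndim - 3; // index of the batch dimension

    std::vector<int64_t> perm(ndim);
    std::iota(perm.begin(), perm.end(), 0);            // {0, 1, 2}
    std::swap(perm[batch], perm[batch + trans_batch]); // {1, 0, 2}

    for(auto p : perm)
        std::cout << p << ' '; // prints: 1 0 2
}
```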
```diff
@@ -97,6 +111,12 @@ void gemm_impl(context& ctx,
                bool int8_x4_format,
                bool compute_fp32)
 {
+    const bool is_3inputs = (args.size() == 4);
+    if(not is_3inputs)
+    {
+        beta = 0;
+    }
+
     bool transa = is_transposed(args[0].get_shape());
     bool transb = is_transposed(args[1].get_shape());
     auto n_dim  = output_shape.lens().size();
```
```diff
@@ -105,12 +125,8 @@ void gemm_impl(context& ctx,
     rocblas_int lda = args[0].get_shape().strides()[transa ? dim_1 : dim_0];
     rocblas_int ldb = args[1].get_shape().strides()[transb ? dim_1 : dim_0];
     rocblas_int ldc = args[2].get_shape().strides()[dim_0];
+    rocblas_int ldd = is_3inputs ? args[3].get_shape().strides()[dim_0] : ldc;
-    bool is_3inputs = (args.size() == 4);
-    if(!is_3inputs)
-    {
-        beta = 0;
-    }
     rocblas_datatype arg_type = get_type(args[0].get_shape().type());
     auto output_type          = arg_type;
     if(output_type == rocblas_datatype_i8_r)
```
```diff
@@ -186,7 +202,7 @@ void gemm_impl(context& ctx,
                    ldc,
                    is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]),
                    output_type,
-                   ldc,
+                   ldd,
                    compute_type,
                    rocblas_gemm_algo_standard,
                    0,
```
```diff
@@ -197,6 +213,7 @@ void gemm_impl(context& ctx,
         auto a_stride = get_batch_stride(args[0]);
         auto b_stride = get_batch_stride(args[1]);
         auto c_stride = get_batch_stride(args[2]);
+        auto d_stride = is_3inputs ? get_batch_stride(args[3]) : c_stride;
         rocblas_invoke(&rocblas_gemm_strided_batched_ex,
                        ctx.get_stream().get_rocblas(),
                        transb ? rocblas_operation_transpose : rocblas_operation_none,
```
```diff
@@ -220,8 +237,8 @@ void gemm_impl(context& ctx,
                        c_stride,
                        is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]),
                        output_type,
-                       ldc,
-                       c_stride,
+                       ldd,
+                       d_stride,
                        num_matrices,
                        compute_type,
                        rocblas_gemm_algo_standard,
```
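rocBLAS's gemm_ex computes D = alpha·op(A)·op(B) + beta·C and allows C and D to be distinct buffers. Previously the code reused C's leading dimension and batch stride for D, which only works when the two share a layout; with a fused add (`is_3inputs`), D is args[3] and needs its own `ldd`/`d_stride`. A small illustration (plain C++, not rocBLAS itself) of why a separate leading dimension matters:

```cpp
#include <iostream>

// Column-major element access through a leading dimension, as rocblas_gemm_ex
// indexes C and D. Two matrices with the same logical size can still need
// different leading dimensions, hence ldd/d_stride instead of reusing
// ldc/c_stride in the diff above.
static double at(const double* m, int ld, int row, int col) { return m[col * ld + row]; }

int main()
{
    double c[6] = {1, 2, 0, 3, 4, 0}; // 2x2 matrix padded to ldc = 3
    double d[4] = {1, 2, 3, 4};       // 2x2 matrix with ldd = 2
    std::cout << at(c, 3, 1, 1) << ' ' << at(d, 2, 1, 1) << '\n'; // both 4
}
```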
src/targets/gpu/include/migraphx/gpu/context.hpp

```diff
@@ -244,6 +244,15 @@ struct context
         return hip_event_ptr{event};
     }
 
+    static hip_event_ptr create_event_for_timing()
+    {
+        hipEvent_t event;
+        auto status = hipEventCreate(&event);
+        if(status != hipSuccess)
+            MIGRAPHX_THROW("Failed to create event");
+        return hip_event_ptr{event};
+    }
+
     value to_value() const
     {
         value result;
```
```diff
@@ -267,10 +276,49 @@ struct context
     any_ptr get_queue() { return get_stream().get(); }
 
+    void enable_perf_measurement(bool b = true)
+    {
+        if(b)
+        {
+            start_event = create_event_for_timing();
+            stop_event  = create_event_for_timing();
+            get_stream().record(start_event.get());
+            get_stream().record(stop_event.get());
+        }
+        else
+        {
+            start_event = nullptr;
+            stop_event  = nullptr;
+        }
+        measure_perf = b;
+    }
+
+    std::pair<hipEvent_t, hipEvent_t> get_perf_events() const
+    {
+        if(measure_perf)
+            return std::make_pair(start_event.get(), stop_event.get());
+        return std::make_pair(nullptr, nullptr);
+    }
+
+    float get_elapsed_ms() const
+    {
+        float result = 0;
+        if(start_event != nullptr and stop_event != nullptr)
+        {
+            auto status = hipEventElapsedTime(&result, start_event.get(), stop_event.get());
+            if(status != hipSuccess)
+                MIGRAPHX_THROW("Failed hipEventElapsedTime: " + hip_error(status));
+        }
+        return result;
+    }
+
     private:
     // TODO: Make this a vector to support multiple devices
     std::shared_ptr<hip_device> current_device;
     std::vector<shared<hip_event_ptr>> events;
+    bool measure_perf                 = false;
+    shared<hip_event_ptr> start_event = nullptr;
+    shared<hip_event_ptr> stop_event  = nullptr;
 };
 
 inline void migraphx_to_value(value& v, const context& ctx) { v = ctx.to_value(); }
```
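The pattern behind `get_elapsed_ms` is the standard HIP event-timing idiom: record start and stop events on a stream around the work, synchronize on the stop event, then query the elapsed milliseconds. A minimal standalone sketch (error handling elided; `my_kernel` is a placeholder workload):

```cpp
#include <hip/hip_runtime.h>

__global__ void my_kernel() {} // placeholder workload

int main()
{
    hipEvent_t start, stop;
    hipEventCreate(&start);
    hipEventCreate(&stop);

    hipEventRecord(start, /*stream=*/nullptr);
    hipLaunchKernelGGL(my_kernel, dim3(1), dim3(64), 0, /*stream=*/nullptr);
    hipEventRecord(stop, nullptr);

    hipEventSynchronize(stop); // wait until the device reaches `stop`
    float ms = 0;
    hipEventElapsedTime(&ms, start, stop); // device-side time between events

    hipEventDestroy(start);
    hipEventDestroy(stop);
}
```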
src/targets/gpu/include/migraphx/gpu/gemm.hpp

```diff
@@ -42,15 +42,17 @@ namespace gpu {
 struct context;
 
 void blas_shape(const shape& s);
+shape transpose_batch(const shape& s, unsigned trans_batch);
 
 template <class Op>
 struct rocblas_gemm
 {
     Op op;
-    float alpha         = 1;
-    float beta          = 0;
-    bool int8_x4_format = true;
-    bool compute_fp32   = false;
+    float alpha          = 1;
+    float beta           = 0;
+    bool int8_x4_format  = true;
+    bool compute_fp32    = false;
+    unsigned trans_batch = 0;
 
     template <class Self, class F>
     static auto reflect(Self& self, F f)
```
```diff
@@ -58,7 +60,9 @@ struct rocblas_gemm
         return pack_join(migraphx::reflect(self.op, f),
                          pack(f(self.alpha, "alpha"),
                               f(self.beta, "beta"),
-                              f(self.int8_x4_format, "int8_x4_format")));
+                              f(self.int8_x4_format, "int8_x4_format"),
+                              f(self.compute_fp32, "compute_fp32"),
+                              f(self.trans_batch, "trans_batch")));
     }
 
     std::string name() const
```
```diff
@@ -74,13 +78,14 @@ struct rocblas_gemm
     {
         std::vector<shape> in_shapes(inputs);
         in_shapes.pop_back();
-        check_shapes{in_shapes, *this}.not_broadcasted();
+        check_shapes{in_shapes, *this}.has(2, 3);
         blas_shape(inputs[0]);
         blas_shape(inputs[1]);
         // if gemm and add are fused
         if(in_shapes.size() > 2)
         {
             auto cmat_shape = in_shapes.back();
             check_shapes{{cmat_shape}, *this}.not_transposed().not_broadcasted();
             in_shapes.pop_back();
             blas_shape(cmat_shape);
             auto op_out_shape = op.compute_shape(in_shapes);
```
```diff
@@ -97,10 +102,10 @@ struct rocblas_gemm
                     to_string(cmat_shape.type()) +
                     ", it must be: " + to_string(op_out_shape.type()));
             }
-            return op_out_shape;
+            return transpose_batch(op_out_shape, trans_batch);
         }
 
-        return op.compute_shape(in_shapes);
+        return transpose_batch(op.compute_shape(in_shapes), trans_batch);
     }
 
     argument
```
src/targets/gpu/include/migraphx/gpu/hip.hpp

```diff
@@ -37,6 +37,8 @@ namespace gpu {
 struct context;
 
+std::string hip_error(int error);
+
 argument allocate_gpu(const shape& s, bool host = false);
 
 argument register_on_gpu(const argument& arg);
```
src/targets/gpu/include/migraphx/gpu/kernel.hpp

```diff
@@ -50,17 +50,22 @@ struct kernel
     void launch(hipStream_t stream,
                 std::size_t global,
                 std::size_t local,
-                const std::vector<kernel_argument>& args) const;
+                const std::vector<kernel_argument>& args,
+                hipEvent_t start = nullptr,
+                hipEvent_t stop  = nullptr) const;
 
     void launch(hipStream_t stream,
                 std::size_t global,
                 std::size_t local,
-                std::vector<void*> args) const;
+                std::vector<void*> args,
+                hipEvent_t start = nullptr,
+                hipEvent_t stop  = nullptr) const;
 
-    auto launch(hipStream_t stream, std::size_t global, std::size_t local) const
+    template <class... Ts>
+    auto launch(hipStream_t stream, std::size_t global, std::size_t local, Ts... zs) const
     {
         return [=](auto&&... xs) {
-            launch(stream, global, local, std::vector<kernel_argument>{xs...});
+            launch(stream, global, local, std::vector<kernel_argument>{xs...}, zs...);
         };
     }
```
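The currying overload now forwards any extra trailing arguments (intended for the timing events) to the underlying launch. A hedged usage fragment, assuming a `kernel k`, a stream, and two already-created events, none of which appear in the diff itself:

```cpp
// As before, no timing: kernel arguments are supplied through the returned lambda.
k.launch(stream, global, local)(arg0, arg1);

// With timing: the events ride along as zs... and become the start/stop
// parameters of the full launch overload.
k.launch(stream, global, local, start_event, stop_event)(arg0, arg1);
```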
src/targets/gpu/kernel.cpp

```diff
@@ -80,7 +80,9 @@ void launch_kernel(hipFunction_t fun,
                    std::size_t global,
                    std::size_t local,
                    void* kernargs,
-                   std::size_t size)
+                   std::size_t size,
+                   hipEvent_t start,
+                   hipEvent_t stop)
 {
     assert(global > 0);
     assert(local > 0);
@@ -97,34 +99,55 @@ void launch_kernel(hipFunction_t fun,
 #endif
     };
 
-    auto status = hipExtModuleLaunchKernel(
-        fun, global, 1, 1, local, 1, 1, 0, stream, nullptr, reinterpret_cast<void**>(&config));
+    auto status = hipExtModuleLaunchKernel(fun,
+                                           global,
+                                           1,
+                                           1,
+                                           local,
+                                           1,
+                                           1,
+                                           0,
+                                           stream,
+                                           nullptr,
+                                           reinterpret_cast<void**>(&config),
+                                           start,
+                                           stop);
     if(status != hipSuccess)
         MIGRAPHX_THROW("Failed to launch kernel: " + hip_error(status));
+    if(stop != nullptr)
+    {
+        status = hipEventSynchronize(stop);
+        if(status != hipSuccess)
+            MIGRAPHX_THROW("Failed to sync event: " + hip_error(status));
+    }
 }
 
 void kernel::launch(hipStream_t stream,
                     std::size_t global,
                     std::size_t local,
-                    std::vector<void*> args) const
+                    std::vector<void*> args,
+                    hipEvent_t start,
+                    hipEvent_t stop) const
 {
     assert(impl != nullptr);
     void* kernargs   = args.data();
     std::size_t size = args.size() * sizeof(void*);
 
-    launch_kernel(impl->fun, stream, global, local, kernargs, size);
+    launch_kernel(impl->fun, stream, global, local, kernargs, size, start, stop);
 }
 
 void kernel::launch(hipStream_t stream,
                     std::size_t global,
                     std::size_t local,
-                    const std::vector<kernel_argument>& args) const
+                    const std::vector<kernel_argument>& args,
+                    hipEvent_t start,
+                    hipEvent_t stop) const
 {
     assert(impl != nullptr);
     std::vector<char> kernargs = pack_args(args);
     std::size_t size           = kernargs.size();
 
-    launch_kernel(impl->fun, stream, global, local, kernargs.data(), size);
+    launch_kernel(impl->fun, stream, global, local, kernargs.data(), size, start, stop);
 }
 } // namespace gpu
```
src/targets/gpu/kernels/include/migraphx/kernels/algorithm.hpp

```diff
@@ -163,7 +163,7 @@ constexpr Iterator1 search(Iterator1 first, Iterator1 last, Iterator2 s_first, I
             {
                 return last;
             }
-            if(!(*it == *s_it))
+            if(not(*it == *s_it))
             {
                 break;
             }
```
src/targets/gpu/kernels/include/migraphx/kernels/array.hpp

```diff
@@ -212,7 +212,7 @@ struct array
         return true;
     }
 
-    friend constexpr bool operator!=(const array& x, const array& y) { return !(x == y); }
+    friend constexpr bool operator!=(const array& x, const array& y) { return not(x == y); }
 
     // This uses the product order rather than lexical order
     friend constexpr bool operator<(const array& x, const array& y)
     {
```
src/targets/gpu/kernels/include/migraphx/kernels/integral_constant.hpp

```diff
@@ -73,10 +73,10 @@ MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(!=)
 MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(&)
 MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(^)
 MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(|)
-MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(&&)
-MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(||)
+MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(and)
+MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(or)
-MIGRAPHX_INTEGRAL_CONSTANT_UNARY_OP(!)
+MIGRAPHX_INTEGRAL_CONSTANT_UNARY_OP(not)
 MIGRAPHX_INTEGRAL_CONSTANT_UNARY_OP(~)
 MIGRAPHX_INTEGRAL_CONSTANT_UNARY_OP(+)
 MIGRAPHX_INTEGRAL_CONSTANT_UNARY_OP(-)
```
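Many hunks in this commit replace `!`, `&&`, and `||` with `not`, `and`, and `or`. These are standard ISO C++ alternative tokens with identical semantics, so the change is purely stylistic:

```cpp
// Alternative tokens are part of the C++ language itself (no header required):
static_assert((true and not false) == (true && !false), "same semantics");
```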
src/targets/gpu/lowering.cpp

```diff
@@ -341,7 +341,7 @@ struct miopen_apply
         catch(migraphx::exception&)
         {
             // In case no solver supports the default format, retry using the other format.
-            compile_quant_conv_with_format(!int8_x4_format);
+            compile_quant_conv_with_format(not int8_x4_format);
         }
 
         auto args = ins->inputs();
```
src/targets/gpu/mlir.cpp

```diff
@@ -78,7 +78,7 @@ struct mlir_handle
         friend bool operator==(ptr x, ptr y) { return x.get_value() == y.get_value(); }
 
-        friend bool operator!=(ptr x, ptr y) { return !(x == y); }
+        friend bool operator!=(ptr x, ptr y) { return not(x == y); }
 
         T obj{};
     };
@@ -503,7 +503,7 @@ struct mlir_program
         pp = problem_params{ins->get_operator(), to_shapes(ins->inputs()), ins->get_shape()};
         std::string tuned = get_tune_params();
-        if(!tuned.empty())
+        if(not tuned.empty())
             ops.add_attributes({{"perf_config", tuned}});
         // check if HW supports xdlops
         if(contains(get_xdlops_archs(), target_name))
```
src/targets/gpu/pack_int8_args.cpp

```diff
@@ -154,7 +154,7 @@ void pack_int8_args::apply(module& m) const
             bool transa = inputs[0]->get_shape().transposed();
             bool transb = inputs[1]->get_shape().transposed();
-            if(!transb)
+            if(not transb)
             {
                 auto packed_b = m.insert_instruction(
                     ins,
                     make_op("hip::allocate", {{"shape", to_value(inputs[1]->get_shape())}}));
```
src/targets/gpu/target.cpp

```diff
@@ -42,6 +42,7 @@
 #include <migraphx/register_target.hpp>
 #include <migraphx/replace_allocate.hpp>
 #include <migraphx/rewrite_batchnorm.hpp>
+#include <migraphx/rewrite_gelu.hpp>
 #include <migraphx/rewrite_pooling.hpp>
 #include <migraphx/rewrite_quantization.hpp>
 #include <migraphx/rewrite_rnn.hpp>
@@ -116,6 +117,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
         inline_module{},
         rewrite_pooling{},
         dead_code_elimination{},
+        rewrite_gelu{},
+        dead_code_elimination{},
         eliminate_common_subexpression{},
         dead_code_elimination{},
         simplify_algebra{},
@@ -134,8 +137,6 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
         lowering{&ctx, options.offload_copy},
         eliminate_contiguous{"gpu::contiguous"},
         dead_code_elimination{},
-        replace_allocate{gpu_allocation_model{}, options.offload_copy},
-        dead_code_elimination{},
         eliminate_concat{concat_gpu_optimization{}},
         dead_code_elimination{},
         pack_int8_args{},
@@ -144,6 +145,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
         dead_code_elimination{},
         fuse_ops{&ctx, options.fast_math},
         dead_code_elimination{},
+        replace_allocate{gpu_allocation_model{}, options.offload_copy},
+        dead_code_elimination{},
         compile_ops{&ctx},
         dead_code_elimination{},
         write_literals{&ctx},
```
src/tf/parse_conv.cpp

```diff
@@ -100,7 +100,7 @@ struct parse_conv : op_parser<parse_conv>
             {
                 MIGRAPHX_THROW("padding should have 4 values");
             }
-            if(padding[0] != padding[2] || padding[1] != padding[3])
+            if(padding[0] != padding[2] or padding[1] != padding[3])
             {
                 MIGRAPHX_THROW("migraphx does not support asymetric padding");
             }
```
src/tf/parse_depthwiseconv.cpp

```diff
@@ -90,7 +90,7 @@ struct parse_depthwiseconv : op_parser<parse_depthwiseconv>
             calculate_padding(0, pads, input_dims[2], op.stride[0], op.dilation[0], weight_h);
             calculate_padding(1, pads, input_dims[3], op.stride[1], op.dilation[1], weight_w);
 
-            if(pads[0] != pads[2] || pads[1] != pads[3])
+            if(pads[0] != pads[2] or pads[1] != pads[3])
             {
                 std::vector<int64_t> padding = {0, 0, pads[0], pads[1], 0, 0, pads[2], pads[3]};
                 l0 = info.add_instruction(migraphx::make_op("pad", {{"pads", padding}}), l0);
```
src/tf/parse_pooling.cpp

```diff
@@ -42,7 +42,7 @@ struct parse_pooling : op_parser<parse_pooling>
                               tf_parser::node_info info,
                               std::vector<instruction_ref> args) const
     {
-        if(!starts_with(opd.tf_name, "Max") && !starts_with(opd.tf_name, "Av"))
+        if(not starts_with(opd.tf_name, "Max") and not starts_with(opd.tf_name, "Av"))
         {
             MIGRAPHX_THROW("tf pooling mode must be Max or Average");
         }
```
src/tf/tf_parser.cpp

```diff
@@ -371,7 +371,7 @@ void tf_parser::parse_node(const std::string& name)
         {
             result = ops[node.op()](*this, {get_attributes(node), node.op(), mm}, args);
         }
-        assert(!result.empty());
+        assert(not result.empty());
         // First output has no ":" delimiter
         instructions[name] = result.front();
         for(size_t i = 1; i < result.size(); i++)
@@ -458,7 +458,7 @@ literal tf_parser::parse_tensor(const tensorflow::TensorProto& t) const
 {
     std::vector<size_t> dims = parse_dims(t.tensor_shape());
     size_t shape_size = std::accumulate(dims.begin(), dims.end(), 1, std::multiplies<size_t>());
-    if(!t.tensor_content().empty()) // has raw data
+    if(not t.tensor_content().empty()) // has raw data
     {
         const std::string& s = t.tensor_content();
         switch(t.dtype())
```