Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
cf8ccba4
Commit
cf8ccba4
authored
Jul 07, 2022
by
Paul
Browse files
Merge branch 'bert-opt2' into bert-opt3
parents
9747cc44
bd70cd8d
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
137 additions
and
11 deletions
+137
-11
src/include/migraphx/op/unsqueeze.hpp
src/include/migraphx/op/unsqueeze.hpp
+3
-0
src/simplify_reshapes.cpp
src/simplify_reshapes.cpp
+1
-1
src/targets/gpu/fuse_ops.cpp
src/targets/gpu/fuse_ops.cpp
+60
-0
src/targets/gpu/gemm_impl.cpp
src/targets/gpu/gemm_impl.cpp
+14
-0
src/targets/gpu/include/migraphx/gpu/gemm.hpp
src/targets/gpu/include/migraphx/gpu/gemm.hpp
+11
-7
src/targets/gpu/target.cpp
src/targets/gpu/target.cpp
+2
-2
test/op_shape_test.cpp
test/op_shape_test.cpp
+14
-1
test/simplify_reshapes_test.cpp
test/simplify_reshapes_test.cpp
+32
-0
No files found.
src/include/migraphx/op/unsqueeze.hpp
View file @
cf8ccba4
...
@@ -74,6 +74,9 @@ struct unsqueeze
...
@@ -74,6 +74,9 @@ struct unsqueeze
MIGRAPHX_THROW
(
"UNSQUEEZE: Input must be a scalar"
);
MIGRAPHX_THROW
(
"UNSQUEEZE: Input must be a scalar"
);
}
}
if
(
steps
.
size
()
>
axes
.
size
())
MIGRAPHX_THROW
(
"UNSQUEEZE: Steps provided with no axis"
);
std
::
size_t
new_size
=
old_lens
.
size
()
+
axes
.
size
();
std
::
size_t
new_size
=
old_lens
.
size
()
+
axes
.
size
();
std
::
vector
<
std
::
size_t
>
new_lens
(
new_size
);
std
::
vector
<
std
::
size_t
>
new_lens
(
new_size
);
...
...
src/simplify_reshapes.cpp
View file @
cf8ccba4
...
@@ -275,7 +275,7 @@ struct find_concat_transpose
...
@@ -275,7 +275,7 @@ struct find_concat_transpose
{
{
auto
matcher
()
const
auto
matcher
()
const
{
{
return
match
::
name
(
"concat"
)(
match
::
all_of
[
match
::
inputs
()](
match
::
transpose
_shape
(
)));
return
match
::
name
(
"concat"
)(
match
::
all_of
[
match
::
inputs
()](
match
::
name
(
"
transpose
"
)));
}
}
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
mr
)
const
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
mr
)
const
...
...
src/targets/gpu/fuse_ops.cpp
View file @
cf8ccba4
...
@@ -48,6 +48,7 @@
...
@@ -48,6 +48,7 @@
#include <migraphx/instruction.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/register_op.hpp>
#include <migraphx/register_op.hpp>
#include <migraphx/array.hpp>
#include <migraphx/array.hpp>
#include <migraphx/permutation.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/op/clip.hpp>
#include <migraphx/op/clip.hpp>
#include <migraphx/op/contiguous.hpp>
#include <migraphx/op/contiguous.hpp>
...
@@ -1063,6 +1064,64 @@ struct find_gemm_pointwise
...
@@ -1063,6 +1064,64 @@ struct find_gemm_pointwise
}
}
};
};
struct
find_contiguous_tranpose_gemm
{
auto
matcher
()
const
{
return
match
::
name
(
"gpu::contiguous"
)(
match
::
arg
(
0
)(
match
::
name
(
"transpose"
)(
match
::
arg
(
0
)(
match
::
name
(
"gpu::gemm"
)(
match
::
used_once
()).
bind
(
"gemm"
)))
.
bind
(
"transpose"
)));
}
template
<
class
Vector
>
static
bool
is_swapped
(
const
Vector
&
perm
,
std
::
size_t
i
,
std
::
size_t
j
)
{
if
(
i
>=
perm
.
size
()
or
j
>=
perm
.
size
())
return
false
;
auto
perm2
=
perm
;
std
::
iota
(
perm2
.
begin
(),
perm2
.
end
(),
0
);
std
::
swap
(
perm2
[
i
],
perm2
[
j
]);
return
perm2
==
perm
;
}
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
auto
ins
=
r
.
result
;
auto
gemm
=
r
.
instructions
[
"gemm"
];
auto
alloc
=
gemm
->
inputs
().
back
();
auto
transpose
=
r
.
instructions
[
"transpose"
];
auto
perm
=
transpose
->
get_operator
().
to_value
()[
"permutation"
].
to_vector
<
int64_t
>
();
auto
iperm
=
invert_permutation
(
perm
);
if
(
perm
.
size
()
<
3
)
return
;
if
(
not
is_swapped
(
perm
,
perm
.
size
()
-
3
,
perm
.
size
()
-
2
))
return
;
auto
lens
=
gemm
->
get_shape
().
lens
();
if
(
lens
.
size
()
>
3
and
not
std
::
all_of
(
lens
.
begin
(),
lens
.
end
()
-
3
,
[](
auto
i
)
{
return
i
==
1
;
}))
return
;
auto
gemmv
=
gemm
->
get_operator
().
to_value
();
gemmv
[
"trans_batch"
]
=
1
;
auto
s
=
shape
{
alloc
->
get_shape
().
type
(),
reorder_dims
(
alloc
->
get_shape
().
lens
(),
iperm
)};
auto
new_alloc
=
m
.
insert_instruction
(
gemm
,
make_op
(
"allocate"
,
{{
"shape"
,
to_value
(
s
)}}));
auto
alloc_transpose
=
m
.
insert_instruction
(
gemm
,
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
new_alloc
);
auto
inputs
=
gemm
->
inputs
();
inputs
.
back
()
=
alloc_transpose
;
auto
new_gemm
=
m
.
insert_instruction
(
gemm
,
make_op
(
"gpu::gemm"
,
gemmv
),
inputs
);
auto
gemm_transpoe
=
m
.
insert_instruction
(
gemm
,
transpose
->
get_operator
(),
new_gemm
);
m
.
replace_instruction
(
ins
,
gemm_transpoe
);
}
};
struct
find_commutative_broadcast
struct
find_commutative_broadcast
{
{
auto
matcher
()
const
auto
matcher
()
const
...
@@ -1164,6 +1223,7 @@ void fuse_ops::apply(module& m) const
...
@@ -1164,6 +1223,7 @@ void fuse_ops::apply(module& m) const
find_gemm_add
{},
find_gemm_add
{},
find_layernorm_pointwise
{},
find_layernorm_pointwise
{},
find_gemm_pointwise
{},
find_gemm_pointwise
{},
find_contiguous_tranpose_gemm
{},
find_commutative_broadcast
{});
find_commutative_broadcast
{});
match
::
find_matches
(
m
,
find_contiguous
{});
match
::
find_matches
(
m
,
find_contiguous
{});
}
}
...
...
src/targets/gpu/gemm_impl.cpp
View file @
cf8ccba4
...
@@ -24,6 +24,7 @@
...
@@ -24,6 +24,7 @@
#include <rocblas.h>
#include <rocblas.h>
#include <migraphx/gpu/gemm_impl.hpp>
#include <migraphx/gpu/gemm_impl.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/permutation.hpp>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
@@ -67,6 +68,19 @@ void blas_shape(const shape& s)
...
@@ -67,6 +68,19 @@ void blas_shape(const shape& s)
MIGRAPHX_THROW
(
"GPU_GEMM: Batch dimension is not collapsible"
);
MIGRAPHX_THROW
(
"GPU_GEMM: Batch dimension is not collapsible"
);
}
}
shape
transpose_batch
(
const
shape
&
s
,
unsigned
trans_batch
)
{
if
(
trans_batch
==
0
)
return
s
;
if
(
s
.
lens
().
size
()
<
3
)
return
s
;
auto
batch
=
s
.
lens
().
size
()
-
3
;
std
::
vector
<
int64_t
>
perm
(
s
.
lens
().
size
());
std
::
iota
(
perm
.
begin
(),
perm
.
end
(),
0
);
std
::
swap
(
perm
[
batch
],
perm
[
batch
+
trans_batch
]);
return
shape
::
from_permutation
(
s
.
type
(),
s
.
lens
(),
perm
);
}
template
<
class
R
,
class
...
Ts
,
class
...
Us
>
template
<
class
R
,
class
...
Ts
,
class
...
Us
>
R
rocblas_invoke
(
R
(
*
f
)(
Ts
...),
Us
...
xs
)
R
rocblas_invoke
(
R
(
*
f
)(
Ts
...),
Us
...
xs
)
{
{
...
...
src/targets/gpu/include/migraphx/gpu/gemm.hpp
View file @
cf8ccba4
...
@@ -42,15 +42,17 @@ namespace gpu {
...
@@ -42,15 +42,17 @@ namespace gpu {
struct
context
;
struct
context
;
void
blas_shape
(
const
shape
&
s
);
void
blas_shape
(
const
shape
&
s
);
shape
transpose_batch
(
const
shape
&
s
,
unsigned
trans_batch
);
template
<
class
Op
>
template
<
class
Op
>
struct
rocblas_gemm
struct
rocblas_gemm
{
{
Op
op
;
Op
op
;
float
alpha
=
1
;
float
alpha
=
1
;
float
beta
=
0
;
float
beta
=
0
;
bool
int8_x4_format
=
true
;
bool
int8_x4_format
=
true
;
bool
compute_fp32
=
false
;
bool
compute_fp32
=
false
;
unsigned
trans_batch
=
0
;
template
<
class
Self
,
class
F
>
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
static
auto
reflect
(
Self
&
self
,
F
f
)
...
@@ -58,7 +60,9 @@ struct rocblas_gemm
...
@@ -58,7 +60,9 @@ struct rocblas_gemm
return
pack_join
(
migraphx
::
reflect
(
self
.
op
,
f
),
return
pack_join
(
migraphx
::
reflect
(
self
.
op
,
f
),
pack
(
f
(
self
.
alpha
,
"alpha"
),
pack
(
f
(
self
.
alpha
,
"alpha"
),
f
(
self
.
beta
,
"beta"
),
f
(
self
.
beta
,
"beta"
),
f
(
self
.
int8_x4_format
,
"int8_x4_format"
)));
f
(
self
.
int8_x4_format
,
"int8_x4_format"
),
f
(
self
.
compute_fp32
,
"compute_fp32"
),
f
(
self
.
trans_batch
,
"trans_batch"
)));
}
}
std
::
string
name
()
const
std
::
string
name
()
const
...
@@ -98,10 +102,10 @@ struct rocblas_gemm
...
@@ -98,10 +102,10 @@ struct rocblas_gemm
to_string
(
cmat_shape
.
type
())
+
to_string
(
cmat_shape
.
type
())
+
", it must be: "
+
to_string
(
op_out_shape
.
type
()));
", it must be: "
+
to_string
(
op_out_shape
.
type
()));
}
}
return
op_out_shape
;
return
transpose_batch
(
op_out_shape
,
trans_batch
)
;
}
}
return
op
.
compute_shape
(
in_shapes
);
return
transpose_batch
(
op
.
compute_shape
(
in_shapes
)
,
trans_batch
)
;
}
}
argument
argument
...
...
src/targets/gpu/target.cpp
View file @
cf8ccba4
...
@@ -133,8 +133,6 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
...
@@ -133,8 +133,6 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
lowering
{
&
ctx
,
options
.
offload_copy
},
lowering
{
&
ctx
,
options
.
offload_copy
},
eliminate_contiguous
{
"gpu::contiguous"
},
eliminate_contiguous
{
"gpu::contiguous"
},
dead_code_elimination
{},
dead_code_elimination
{},
replace_allocate
{
gpu_allocation_model
{},
options
.
offload_copy
},
dead_code_elimination
{},
eliminate_concat
{
concat_gpu_optimization
{}},
eliminate_concat
{
concat_gpu_optimization
{}},
dead_code_elimination
{},
dead_code_elimination
{},
pack_int8_args
{},
pack_int8_args
{},
...
@@ -143,6 +141,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
...
@@ -143,6 +141,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
dead_code_elimination
{},
dead_code_elimination
{},
fuse_ops
{
&
ctx
,
options
.
fast_math
},
fuse_ops
{
&
ctx
,
options
.
fast_math
},
dead_code_elimination
{},
dead_code_elimination
{},
replace_allocate
{
gpu_allocation_model
{},
options
.
offload_copy
},
dead_code_elimination
{},
compile_ops
{
&
ctx
},
compile_ops
{
&
ctx
},
dead_code_elimination
{},
dead_code_elimination
{},
write_literals
{
&
ctx
},
write_literals
{
&
ctx
},
...
...
test/op_shape_test.cpp
View file @
cf8ccba4
...
@@ -1551,7 +1551,7 @@ TEST_CASE(test_unsqueeze_step_non_divisable)
...
@@ -1551,7 +1551,7 @@ TEST_CASE(test_unsqueeze_step_non_divisable)
throws_shape
(
migraphx
::
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
2
}},
{
"steps"
,
{
2
}}}),
s1
);
throws_shape
(
migraphx
::
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
2
}},
{
"steps"
,
{
2
}}}),
s1
);
}
}
TEST_CASE
(
test_unsqueeze_step_
non_
zero
)
TEST_CASE
(
test_unsqueeze_step_zero
)
{
{
migraphx
::
shape
s1
{
migraphx
::
shape
::
float_type
,
{
4
,
5
,
12
}};
migraphx
::
shape
s1
{
migraphx
::
shape
::
float_type
,
{
4
,
5
,
12
}};
throws_shape
(
migraphx
::
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
2
}},
{
"steps"
,
{
0
}}}),
s1
);
throws_shape
(
migraphx
::
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
2
}},
{
"steps"
,
{
0
}}}),
s1
);
...
@@ -1563,6 +1563,12 @@ TEST_CASE(test_unsqueeze_step_at_end)
...
@@ -1563,6 +1563,12 @@ TEST_CASE(test_unsqueeze_step_at_end)
throws_shape
(
migraphx
::
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
3
}},
{
"steps"
,
{
2
}}}),
s1
);
throws_shape
(
migraphx
::
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
3
}},
{
"steps"
,
{
2
}}}),
s1
);
}
}
TEST_CASE
(
test_unsqueeze_mismatch_step_axis
)
{
migraphx
::
shape
s1
{
migraphx
::
shape
::
float_type
,
{
4
,
5
,
12
}};
throws_shape
(
migraphx
::
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
2
}},
{
"steps"
,
{
2
,
3
}}}),
s1
);
}
TEST_CASE
(
test_unsqueeze_negative_axis
)
TEST_CASE
(
test_unsqueeze_negative_axis
)
{
{
migraphx
::
shape
s1
{
migraphx
::
shape
::
float_type
,
{
4
,
5
,
3
}};
migraphx
::
shape
s1
{
migraphx
::
shape
::
float_type
,
{
4
,
5
,
3
}};
...
@@ -1659,6 +1665,13 @@ TEST_CASE(test_unsqueeze_multiple_axes_4)
...
@@ -1659,6 +1665,13 @@ TEST_CASE(test_unsqueeze_multiple_axes_4)
expect_shape
(
s2
,
migraphx
::
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
5
,
4
,
2
}}}),
s1
);
expect_shape
(
s2
,
migraphx
::
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
5
,
4
,
2
}}}),
s1
);
}
}
TEST_CASE
(
test_unsqueeze_multiple_axes_step
)
{
migraphx
::
shape
s1
{
migraphx
::
shape
::
float_type
,
{
3
,
4
,
10
}};
migraphx
::
shape
s2
{
migraphx
::
shape
::
float_type
,
{
3
,
4
,
2
,
5
,
1
,
1
}};
expect_shape
(
s2
,
migraphx
::
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
2
,
4
,
5
}},
{
"steps"
,
{
2
}}}),
s1
);
}
TEST_CASE
(
transpose_shape
)
TEST_CASE
(
transpose_shape
)
{
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
2
,
2
}};
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
2
,
2
}};
...
...
test/simplify_reshapes_test.cpp
View file @
cf8ccba4
...
@@ -1141,6 +1141,38 @@ TEST_CASE(transpose_contiguous_reshape_binary_broadcast)
...
@@ -1141,6 +1141,38 @@ TEST_CASE(transpose_contiguous_reshape_binary_broadcast)
EXPECT
(
m1
==
m2
);
EXPECT
(
m1
==
m2
);
}
}
TEST_CASE
(
transpose_unsqueeze_concat
)
{
migraphx
::
module
m1
;
{
auto
l0
=
m1
.
add_parameter
(
"0"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
2
,
1
,
1
}});
auto
lt0
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
0
,
2
,
3
,
1
}}}),
l0
);
auto
l1
=
m1
.
add_parameter
(
"1"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
2
,
1
,
1
}});
auto
lt1
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
0
,
2
,
3
,
1
}}}),
l1
);
auto
l2
=
m1
.
add_parameter
(
"2"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
2
,
1
,
1
}});
auto
lt2
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
0
,
2
,
3
,
1
}}}),
l2
);
std
::
vector
<
migraphx
::
instruction_ref
>
args
{
lt0
,
lt1
,
lt2
};
std
::
vector
<
migraphx
::
instruction_ref
>
unsqueezed_args
;
int64_t
axis
=
3
;
std
::
transform
(
args
.
begin
(),
args
.
end
(),
std
::
back_inserter
(
unsqueezed_args
),
[
&
](
migraphx
::
instruction_ref
arg
)
{
return
m1
.
add_instruction
(
migraphx
::
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
axis
}}}),
arg
);
});
m1
.
add_instruction
(
migraphx
::
make_op
(
"concat"
,
{{
"axis"
,
axis
}}),
unsqueezed_args
);
}
// TODO: This could be simplified to a single transpose after concat
migraphx
::
module
m2
=
m1
;
run_pass
(
m1
);
EXPECT
(
m1
==
m2
);
}
TEST_CASE
(
transpose_slice
)
TEST_CASE
(
transpose_slice
)
{
{
migraphx
::
module
m1
;
migraphx
::
module
m1
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment