Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
636e2847
Unverified
Commit
636e2847
authored
Aug 26, 2022
by
Chris Austen
Committed by
GitHub
Aug 26, 2022
Browse files
Merge branch 'develop' into fastsoftmax
parents
dca963c0
8752875a
Changes
12
Show whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
394 additions
and
61 deletions
+394
-61
src/include/migraphx/matcher.hpp
src/include/migraphx/matcher.hpp
+6
-2
src/include/migraphx/target_assignments.hpp
src/include/migraphx/target_assignments.hpp
+1
-0
src/pad_calc.cpp
src/pad_calc.cpp
+1
-1
src/simplify_algebra.cpp
src/simplify_algebra.cpp
+55
-19
src/targets/gpu/fuse_ops.cpp
src/targets/gpu/fuse_ops.cpp
+138
-18
src/targets/gpu/gemm_impl.cpp
src/targets/gpu/gemm_impl.cpp
+25
-8
src/targets/gpu/include/migraphx/gpu/gemm.hpp
src/targets/gpu/include/migraphx/gpu/gemm.hpp
+13
-8
src/targets/gpu/target.cpp
src/targets/gpu/target.cpp
+2
-2
test/include/basic_ops.hpp
test/include/basic_ops.hpp
+3
-2
test/simplify_algebra_test.cpp
test/simplify_algebra_test.cpp
+52
-1
test/verify/gemm_add_broadcast1.cpp
test/verify/gemm_add_broadcast1.cpp
+49
-0
test/verify/gemm_add_broadcast2.cpp
test/verify/gemm_add_broadcast2.cpp
+49
-0
No files found.
src/include/migraphx/matcher.hpp
View file @
636e2847
...
...
@@ -564,6 +564,11 @@ MIGRAPHX_BASIC_MATCHER(is_unused, const matcher_context& ctx, instruction_ref in
return
nullopt
;
}
MIGRAPHX_PRED_MATCHER
(
broadcast
,
instruction_ref
ins
)
{
return
contains
({
"broadcast"
,
"multibroadcast"
},
ins
->
name
());
}
template
<
class
...
Ms
>
auto
skip
(
Ms
...
ms
)
{
...
...
@@ -813,8 +818,7 @@ inline auto has_attribute(const std::string& name)
template
<
class
...
Ms
>
auto
pointwise
(
Ms
...
ms
)
{
return
match
::
has_attribute
(
"pointwise"
)(
match
::
any_of
(
match
::
nargs
(
1
),
match
::
nargs
(
2
)),
ms
...);
return
match
::
has_attribute
(
"pointwise"
)(
ms
...);
}
}
// namespace match
...
...
src/include/migraphx/target_assignments.hpp
View file @
636e2847
...
...
@@ -25,6 +25,7 @@
#define MIGRAPHX_GUARD_MIGRAPHX_ASSIGNMENT_HPP
#include <unordered_map>
#include <string>
#include <migraphx/instruction_ref.hpp>
...
...
src/pad_calc.cpp
View file @
636e2847
...
...
@@ -60,7 +60,7 @@ std::vector<std::size_t> calc_dyn_auto_pad(std::vector<std::size_t> tensor_lens,
{
std
::
vector
<
std
::
size_t
>
padding
;
padding
.
resize
(
2
*
k_lens
.
size
());
for
(
size_t
i
=
0
;
i
<
padding
.
size
()
/
2
;
i
++
)
for
(
std
::
size_t
i
=
0
;
i
<
padding
.
size
()
/
2
;
i
++
)
{
std
::
ptrdiff_t
input_dim
=
tensor_lens
[
i
];
std
::
ptrdiff_t
stride
=
strides
[
i
];
...
...
src/simplify_algebra.cpp
View file @
636e2847
...
...
@@ -208,6 +208,42 @@ struct find_mul_add
}
};
struct
find_dot_add
{
auto
matcher
()
const
{
return
match
::
name
(
"dot"
)(
match
::
either_arg
(
0
,
1
)(
match
::
name
(
"add"
)(
match
::
either_arg
(
0
,
1
)(
match
::
any
().
bind
(
"x"
),
match
::
any_of
(
match
::
is_constant
()).
bind
(
"b"
)),
match
::
none_of
(
match
::
args
(
match
::
is_constant
(),
match
::
is_constant
())),
match
::
used_once
()),
match
::
is_constant
().
bind
(
"a"
)));
}
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
auto
ins
=
r
.
result
;
auto
a_ins
=
r
.
instructions
[
"a"
];
auto
b_ins
=
r
.
instructions
[
"b"
];
auto
x_ins
=
r
.
instructions
[
"x"
];
assert
(
x_ins
!=
b_ins
);
const
bool
flipped
=
a_ins
==
ins
->
inputs
().
back
();
auto
insert_dot
=
[
&
](
auto
x
,
auto
y
)
{
if
(
flipped
)
return
m
.
insert_instruction
(
ins
,
make_op
(
"dot"
),
y
,
x
);
else
return
m
.
insert_instruction
(
ins
,
make_op
(
"dot"
),
x
,
y
);
};
auto
ax_ins
=
insert_dot
(
a_ins
,
x_ins
);
auto
ab_ins
=
insert_dot
(
a_ins
,
b_ins
);
m
.
replace_instruction
(
ins
,
make_op
(
"add"
),
ax_ins
,
ab_ins
);
}
};
struct
find_add_lit_broadcast
{
auto
matcher
()
const
...
...
@@ -267,28 +303,26 @@ struct find_double_add_lit_broadcast
struct
find_inner_broadcast
{
auto
matcher
()
const
{
return
pointwise
(
match
::
nargs
(
2
),
match
::
args
(
match
::
name
(
"broadcast"
).
bind
(
"x"
),
match
::
name
(
"broadcast"
).
bind
(
"y"
)));
}
auto
matcher
()
const
{
return
pointwise
(
match
::
all_of
[
match
::
inputs
()](
match
::
broadcast
()));
}
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
auto
ins
=
r
.
result
;
auto
x_ins
=
r
.
instructions
[
"x"
];
auto
y_ins
=
r
.
instructions
[
"y"
];
auto
xbroadcast
=
any_cast
<
op
::
broadcast
>
(
x_ins
->
get_operator
());
auto
ybroadcast
=
any_cast
<
op
::
broadcast
>
(
y_ins
->
get_operator
());
if
(
xbroadcast
.
axis
!=
ybroadcast
.
axis
)
auto
broadcasts
=
ins
->
inputs
();
if
(
broadcasts
.
empty
())
return
;
std
::
vector
<
instruction_ref
>
inputs
;
std
::
transform
(
broadcasts
.
begin
(),
broadcasts
.
end
(),
std
::
back_inserter
(
inputs
),
[](
auto
i
)
{
return
i
->
inputs
().
front
();
});
if
(
std
::
any_of
(
inputs
.
begin
(),
inputs
.
end
(),
[
&
](
auto
i
)
{
return
i
->
get_shape
()
!=
inputs
.
front
()
->
get_shape
();
}))
return
;
auto
op
=
m
.
insert_instruction
(
ins
,
ins
->
get_operator
(),
x_ins
->
inputs
().
front
(),
y_ins
->
inputs
().
front
());
m
.
replace_instruction
(
ins
,
xbroadcast
,
op
);
auto
op
=
m
.
insert_instruction
(
ins
,
ins
->
get_operator
(),
inputs
);
m
.
replace_instruction
(
ins
,
broadcasts
.
front
()
->
get_operator
(),
op
);
}
};
...
...
@@ -416,8 +450,9 @@ struct find_splits
{
auto
matcher
()
const
{
return
match
::
any
(
match
::
any_of
[
match
::
outputs
()](
match
::
name
(
"slice"
)(
match
::
any_of
[
match
::
outputs
()](
match
::
pointwise
(),
reduction
()))));
return
match
::
any
(
match
::
any_of
[
match
::
outputs
()](
match
::
name
(
"slice"
)(
match
::
any_of
[
match
::
outputs
()](
match
::
pointwise
(
match
::
any_of
(
match
::
nargs
(
1
),
match
::
nargs
(
2
))),
reduction
()))));
}
static
bool
is_dependent
(
const
module
&
m
,
instruction_ref
ins1
,
instruction_ref
ins2
)
...
...
@@ -1048,6 +1083,7 @@ void simplify_algebra::apply(module& m) const
find_mul_conv
{},
find_mul_slice_conv
{},
find_mul_add
{},
find_dot_add
{},
find_div_const
{},
find_sub_const
{},
find_rsqrt
{},
...
...
src/targets/gpu/fuse_ops.cpp
View file @
636e2847
...
...
@@ -48,8 +48,10 @@
#include <migraphx/instruction.hpp>
#include <migraphx/register_op.hpp>
#include <migraphx/array.hpp>
#include <migraphx/permutation.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/op/clip.hpp>
#include <migraphx/op/contiguous.hpp>
#include <cmath>
#include <set>
...
...
@@ -279,6 +281,11 @@ MIGRAPHX_REGISTER_OP(hip_layernorm)
struct
hip_triadd_layernorm
:
ternary_device
<
hip_triadd_layernorm
,
&
device
::
triadd_layernorm
>
{
shape
compute_shape
(
const
std
::
vector
<
shape
>&
inputs
)
const
{
check_shapes
{
inputs
,
*
this
}.
has
(
4
).
standard
();
return
inputs
[
0
];
}
// Empty finalize to skip dimension reduction
void
finalize
(
context
&
,
const
shape
&
,
const
std
::
vector
<
shape
>&
)
{}
};
...
...
@@ -943,28 +950,70 @@ struct find_gemm_add
}
};
auto
pointwise_name
(
const
std
::
string
&
s
)
{
return
precompile_name
(
"pointwise"
)(
match
::
make_basic_pred_matcher
([
=
](
auto
ins
)
{
module_ref
pm
=
ins
->
module_inputs
().
front
();
auto
n
=
std
::
count_if
(
pm
->
begin
(),
pm
->
end
(),
[
&
](
auto
&
i
)
{
return
i
.
name
()
==
s
;
});
if
(
n
!=
1
)
return
false
;
return
std
::
all_of
(
pm
->
begin
(),
pm
->
end
(),
[
&
](
auto
&
i
)
{
return
starts_with
(
i
.
name
(),
"@"
)
or
i
.
name
()
==
s
;
});
}));
}
struct
find_gemm_pointwise
{
auto
matcher
()
const
{
return
p
ointwise_name
(
"add
"
)(
return
p
recompile_name
(
"pointwise
"
)(
match
::
nargs
(
3
),
match
::
all_of
[
match
::
inputs
()](
match
::
standard_shape
()),
match
::
either_arg
(
0
,
1
)(
match
::
used_once
().
bind
(
"c"
),
match
::
name
(
"gpu::gemm"
)(
match
::
nargs
(
3
)).
bind
(
"gemm"
)));
match
::
either_arg
(
0
,
1
)(
match
::
any_of
(
match
::
standard_shape
(),
match
::
is_constant
()).
bind
(
"c"
),
match
::
name
(
"gpu::gemm"
)(
match
::
nargs
(
3
),
match
::
used_once
()).
bind
(
"gemm"
)));
}
// TODO: Move to matcher.hpp
static
auto
match_param
(
const
std
::
string
&
name
)
{
return
match
::
make_basic_pred_matcher
([
=
](
auto
ins
)
{
if
(
ins
->
name
()
!=
"@param"
)
return
false
;
auto
p
=
any_cast
<
builtin
::
param
>
(
ins
->
get_operator
());
return
p
.
parameter
==
name
;
});
}
template
<
class
M
>
static
auto
match_mul_const
(
M
m
,
const
std
::
string
&
var
)
{
return
match
::
name
(
"mul"
)(
match
::
either_arg
(
0
,
1
)(
match
::
name
(
"@literal"
).
bind
(
var
),
m
))
.
bind
(
var
+
"_mul"
);
}
static
auto
match_add
(
const
std
::
string
&
input
,
const
std
::
string
&
output
)
{
auto
param
=
match
::
name
(
"@param"
);
auto
add
=
match
::
name
(
"add"
)(
match
::
args
(
param
,
param
));
auto
inner_mul
=
match
::
any_of
(
match_mul_const
(
match_param
(
input
),
"alpha"
),
match_mul_const
(
match_param
(
output
),
"beta"
));
auto
mul_add
=
match
::
name
(
"add"
)(
match
::
either_arg
(
0
,
1
)(
inner_mul
,
param
));
auto
add_mul
=
match_mul_const
(
add
,
"gamma"
);
return
match
::
name
(
"@return"
)(
match
::
args
(
match
::
any_of
(
add
,
mul_add
,
add_mul
)));
}
static
float
get_float
(
instruction_ref
ins
)
{
return
ins
->
get_literal
().
at
<
float
>
();
}
template
<
class
Gemm
>
static
bool
update_gemm
(
Gemm
&
gemm
,
module_ref
pm
,
unsigned
input
)
{
auto
names
=
pm
->
get_parameter_names
();
if
(
names
.
size
()
!=
2
)
return
false
;
std
::
sort
(
names
.
begin
(),
names
.
end
());
unsigned
output
=
input
==
0
?
1
:
0
;
auto
mr
=
match
::
match_instruction
(
*
pm
,
std
::
prev
(
pm
->
end
()),
match_add
(
names
[
input
],
names
[
output
]));
if
(
mr
.
result
==
pm
->
end
())
return
false
;
if
(
contains
(
mr
.
instructions
,
"alpha_mul"
))
gemm
.
alpha
*=
get_float
(
mr
.
instructions
[
"alpha"
]);
else
if
(
contains
(
mr
.
instructions
,
"beta_mul"
))
gemm
.
beta
*=
get_float
(
mr
.
instructions
[
"beta"
]);
else
if
(
contains
(
mr
.
instructions
,
"gamma_mul"
))
{
gemm
.
alpha
*=
get_float
(
mr
.
instructions
[
"gamma"
]);
gemm
.
beta
*=
get_float
(
mr
.
instructions
[
"gamma"
]);
}
return
true
;
}
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
...
...
@@ -978,6 +1027,19 @@ struct find_gemm_pointwise
// Already fused gemm
if
(
not
float_equal
(
gemm
.
beta
,
0
))
return
;
gemm
.
beta
=
1
;
if
(
not
update_gemm
(
gemm
,
ins
->
module_inputs
().
front
(),
ins
->
inputs
().
front
()
==
gemm_ins
?
0
:
1
))
return
;
// const-fold input if not standard shape since rocblas can't handle it
if
(
not
c_ins
->
get_shape
().
standard
())
{
auto
c
=
op
::
contiguous
{};
auto
l
=
c
.
compute
(
c
.
compute_shape
({
c_ins
->
get_shape
()}),
{
c_ins
->
eval
()});
c_ins
=
m
.
add_literal
(
l
.
get_shape
(),
l
.
data
());
}
auto
inputs
=
gemm_ins
->
inputs
();
inputs
.
pop_back
();
...
...
@@ -985,11 +1047,68 @@ struct find_gemm_pointwise
inputs
.
push_back
(
c_ins
);
inputs
.
push_back
(
ins
->
inputs
().
back
());
gemm
.
beta
=
1
;
m
.
replace_instruction
(
ins
,
gemm
,
inputs
);
}
};
struct
find_contiguous_tranpose_gemm
{
auto
matcher
()
const
{
return
match
::
name
(
"gpu::contiguous"
)(
match
::
arg
(
0
)(
match
::
name
(
"transpose"
)(
match
::
arg
(
0
)(
match
::
name
(
"gpu::gemm"
)(
match
::
used_once
()).
bind
(
"gemm"
)))
.
bind
(
"transpose"
)));
}
template
<
class
Vector
>
static
bool
is_swapped
(
const
Vector
&
perm
,
std
::
size_t
i
,
std
::
size_t
j
)
{
if
(
i
>=
perm
.
size
()
or
j
>=
perm
.
size
())
return
false
;
auto
perm2
=
perm
;
std
::
iota
(
perm2
.
begin
(),
perm2
.
end
(),
0
);
std
::
swap
(
perm2
[
i
],
perm2
[
j
]);
return
perm2
==
perm
;
}
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
auto
ins
=
r
.
result
;
auto
gemm
=
r
.
instructions
[
"gemm"
];
auto
alloc
=
gemm
->
inputs
().
back
();
auto
transpose
=
r
.
instructions
[
"transpose"
];
auto
perm
=
transpose
->
get_operator
().
to_value
()[
"permutation"
].
to_vector
<
int64_t
>
();
auto
iperm
=
invert_permutation
(
perm
);
if
(
perm
.
size
()
<
3
)
return
;
if
(
not
is_swapped
(
perm
,
perm
.
size
()
-
3
,
perm
.
size
()
-
2
))
return
;
auto
lens
=
gemm
->
get_shape
().
lens
();
if
(
lens
.
size
()
>
3
and
not
std
::
all_of
(
lens
.
begin
(),
lens
.
end
()
-
3
,
[](
auto
i
)
{
return
i
==
1
;
}))
return
;
auto
gemmv
=
gemm
->
get_operator
().
to_value
();
gemmv
[
"trans_batch"
]
=
1
;
auto
s
=
shape
{
alloc
->
get_shape
().
type
(),
reorder_dims
(
alloc
->
get_shape
().
lens
(),
iperm
)};
auto
new_alloc
=
m
.
insert_instruction
(
gemm
,
make_op
(
"allocate"
,
{{
"shape"
,
to_value
(
s
)}}));
auto
alloc_transpose
=
m
.
insert_instruction
(
gemm
,
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
new_alloc
);
auto
inputs
=
gemm
->
inputs
();
inputs
.
back
()
=
alloc_transpose
;
auto
new_gemm
=
m
.
insert_instruction
(
gemm
,
make_op
(
"gpu::gemm"
,
gemmv
),
inputs
);
auto
gemm_transpoe
=
m
.
insert_instruction
(
gemm
,
transpose
->
get_operator
(),
new_gemm
);
m
.
replace_instruction
(
ins
,
gemm_transpoe
);
}
};
struct
find_commutative_broadcast
{
auto
matcher
()
const
...
...
@@ -1091,6 +1210,7 @@ void fuse_ops::apply(module& m) const
find_gemm_add
{},
find_layernorm_pointwise
{},
find_gemm_pointwise
{},
find_contiguous_tranpose_gemm
{},
find_commutative_broadcast
{});
match
::
find_matches
(
m
,
find_contiguous
{});
}
...
...
src/targets/gpu/gemm_impl.cpp
View file @
636e2847
...
...
@@ -24,6 +24,7 @@
#include <rocblas.h>
#include <migraphx/gpu/gemm_impl.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/permutation.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
...
@@ -67,6 +68,19 @@ void blas_shape(const shape& s)
MIGRAPHX_THROW
(
"GPU_GEMM: Batch dimension is not collapsible"
);
}
shape
transpose_batch
(
const
shape
&
s
,
unsigned
trans_batch
)
{
if
(
trans_batch
==
0
)
return
s
;
if
(
s
.
lens
().
size
()
<
3
)
return
s
;
auto
batch
=
s
.
lens
().
size
()
-
3
;
std
::
vector
<
int64_t
>
perm
(
s
.
lens
().
size
());
std
::
iota
(
perm
.
begin
(),
perm
.
end
(),
0
);
std
::
swap
(
perm
[
batch
],
perm
[
batch
+
trans_batch
]);
return
shape
::
from_permutation
(
s
.
type
(),
s
.
lens
(),
perm
);
}
template
<
class
R
,
class
...
Ts
,
class
...
Us
>
R
rocblas_invoke
(
R
(
*
f
)(
Ts
...),
Us
...
xs
)
{
...
...
@@ -97,6 +111,12 @@ void gemm_impl(context& ctx,
bool
int8_x4_format
,
bool
compute_fp32
)
{
const
bool
is_3inputs
=
(
args
.
size
()
==
4
);
if
(
!
is_3inputs
)
{
beta
=
0
;
}
bool
transa
=
is_transposed
(
args
[
0
].
get_shape
());
bool
transb
=
is_transposed
(
args
[
1
].
get_shape
());
auto
n_dim
=
output_shape
.
lens
().
size
();
...
...
@@ -105,12 +125,8 @@ void gemm_impl(context& ctx,
rocblas_int
lda
=
args
[
0
].
get_shape
().
strides
()[
transa
?
dim_1
:
dim_0
];
rocblas_int
ldb
=
args
[
1
].
get_shape
().
strides
()[
transb
?
dim_1
:
dim_0
];
rocblas_int
ldc
=
args
[
2
].
get_shape
().
strides
()[
dim_0
];
rocblas_int
ldd
=
is_3inputs
?
args
[
3
].
get_shape
().
strides
()[
dim_0
]
:
ldc
;
bool
is_3inputs
=
(
args
.
size
()
==
4
);
if
(
!
is_3inputs
)
{
beta
=
0
;
}
rocblas_datatype
arg_type
=
get_type
(
args
[
0
].
get_shape
().
type
());
auto
output_type
=
arg_type
;
if
(
output_type
==
rocblas_datatype_i8_r
)
...
...
@@ -186,7 +202,7 @@ void gemm_impl(context& ctx,
ldc
,
is_3inputs
?
to_pointer
(
args
[
3
])
:
to_pointer
(
args
[
2
]),
output_type
,
ld
c
,
ld
d
,
compute_type
,
rocblas_gemm_algo_standard
,
0
,
...
...
@@ -197,6 +213,7 @@ void gemm_impl(context& ctx,
auto
a_stride
=
get_batch_stride
(
args
[
0
]);
auto
b_stride
=
get_batch_stride
(
args
[
1
]);
auto
c_stride
=
get_batch_stride
(
args
[
2
]);
auto
d_stride
=
is_3inputs
?
get_batch_stride
(
args
[
3
])
:
c_stride
;
rocblas_invoke
(
&
rocblas_gemm_strided_batched_ex
,
ctx
.
get_stream
().
get_rocblas
(),
transb
?
rocblas_operation_transpose
:
rocblas_operation_none
,
...
...
@@ -220,8 +237,8 @@ void gemm_impl(context& ctx,
c_stride
,
is_3inputs
?
to_pointer
(
args
[
3
])
:
to_pointer
(
args
[
2
]),
output_type
,
ld
c
,
c
_stride
,
ld
d
,
d
_stride
,
num_matrices
,
compute_type
,
rocblas_gemm_algo_standard
,
...
...
src/targets/gpu/include/migraphx/gpu/gemm.hpp
View file @
636e2847
...
...
@@ -42,6 +42,7 @@ namespace gpu {
struct
context
;
void
blas_shape
(
const
shape
&
s
);
shape
transpose_batch
(
const
shape
&
s
,
unsigned
trans_batch
);
template
<
class
Op
>
struct
rocblas_gemm
...
...
@@ -51,6 +52,7 @@ struct rocblas_gemm
float
beta
=
0
;
bool
int8_x4_format
=
true
;
bool
compute_fp32
=
false
;
unsigned
trans_batch
=
0
;
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
...
...
@@ -58,7 +60,9 @@ struct rocblas_gemm
return
pack_join
(
migraphx
::
reflect
(
self
.
op
,
f
),
pack
(
f
(
self
.
alpha
,
"alpha"
),
f
(
self
.
beta
,
"beta"
),
f
(
self
.
int8_x4_format
,
"int8_x4_format"
)));
f
(
self
.
int8_x4_format
,
"int8_x4_format"
),
f
(
self
.
compute_fp32
,
"compute_fp32"
),
f
(
self
.
trans_batch
,
"trans_batch"
)));
}
std
::
string
name
()
const
...
...
@@ -74,13 +78,14 @@ struct rocblas_gemm
{
std
::
vector
<
shape
>
in_shapes
(
inputs
);
in_shapes
.
pop_back
();
check_shapes
{
in_shapes
,
*
this
}.
not_broadcasted
(
);
check_shapes
{
in_shapes
,
*
this
}.
has
(
2
,
3
);
blas_shape
(
inputs
[
0
]);
blas_shape
(
inputs
[
1
]);
// if gemm and add are fused
if
(
in_shapes
.
size
()
>
2
)
{
auto
cmat_shape
=
in_shapes
.
back
();
check_shapes
{{
cmat_shape
},
*
this
}.
not_transposed
().
not_broadcasted
();
in_shapes
.
pop_back
();
blas_shape
(
cmat_shape
);
auto
op_out_shape
=
op
.
compute_shape
(
in_shapes
);
...
...
@@ -97,10 +102,10 @@ struct rocblas_gemm
to_string
(
cmat_shape
.
type
())
+
", it must be: "
+
to_string
(
op_out_shape
.
type
()));
}
return
op_out_shape
;
return
transpose_batch
(
op_out_shape
,
trans_batch
)
;
}
return
op
.
compute_shape
(
in_shapes
);
return
transpose_batch
(
op
.
compute_shape
(
in_shapes
)
,
trans_batch
)
;
}
argument
...
...
src/targets/gpu/target.cpp
View file @
636e2847
...
...
@@ -134,8 +134,6 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
lowering
{
&
ctx
,
options
.
offload_copy
},
eliminate_contiguous
{
"gpu::contiguous"
},
dead_code_elimination
{},
replace_allocate
{
gpu_allocation_model
{},
options
.
offload_copy
},
dead_code_elimination
{},
eliminate_concat
{
concat_gpu_optimization
{}},
dead_code_elimination
{},
pack_int8_args
{},
...
...
@@ -144,6 +142,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
dead_code_elimination
{},
fuse_ops
{
&
ctx
,
options
.
fast_math
},
dead_code_elimination
{},
replace_allocate
{
gpu_allocation_model
{},
options
.
offload_copy
},
dead_code_elimination
{},
compile_ops
{
&
ctx
},
dead_code_elimination
{},
write_literals
{
&
ctx
},
...
...
test/include/basic_ops.hpp
View file @
636e2847
...
...
@@ -186,9 +186,10 @@ struct nop
migraphx
::
shape
compute_shape
(
const
std
::
vector
<
migraphx
::
shape
>&
)
const
{
return
{};
}
};
inline
migraphx
::
literal
get_2x2
()
inline
migraphx
::
literal
get_2x2
(
int
base
=
0
)
{
return
migraphx
::
literal
{{
migraphx
::
shape
::
float_type
,
{
2
,
2
}},
{
1
,
2
,
3
,
4
}};
return
migraphx
::
literal
{{
migraphx
::
shape
::
float_type
,
{
2
,
2
}},
{
base
+
1
,
base
+
2
,
base
+
3
,
base
+
4
}};
}
inline
migraphx
::
literal
get_2x2_transposed
()
...
...
test/simplify_algebra_test.cpp
View file @
636e2847
...
...
@@ -358,7 +358,33 @@ TEST_CASE(simplify_mul_add)
EXPECT
(
m1
==
m2
);
}
TEST_CASE
(
simplify_inner_broadcast
)
TEST_CASE
(
simplify_dot_add
)
{
migraphx
::
module
m1
;
{
auto
x
=
m1
.
add_parameter
(
"x"
,
{
migraphx
::
shape
::
float_type
,
{
2
,
2
}});
auto
one
=
m1
.
add_literal
(
get_2x2
());
auto
two
=
m1
.
add_literal
(
get_2x2
(
1
));
auto
sum
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
one
,
x
);
auto
dot
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"dot"
),
sum
,
two
);
m1
.
add_instruction
(
pass_op
{},
dot
);
}
run_pass
(
m1
);
migraphx
::
module
m2
;
{
auto
x
=
m2
.
add_parameter
(
"x"
,
{
migraphx
::
shape
::
float_type
,
{
2
,
2
}});
auto
one
=
m2
.
add_literal
(
get_2x2
());
auto
two
=
m2
.
add_literal
(
get_2x2
(
1
));
auto
dot1
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"dot"
),
x
,
two
);
auto
dot2
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"dot"
),
one
,
two
);
auto
sum
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
dot1
,
dot2
);
m2
.
add_instruction
(
pass_op
{},
sum
);
}
EXPECT
(
m1
==
m2
);
}
TEST_CASE
(
simplify_inner_broadcast1
)
{
auto
b
=
migraphx
::
op
::
broadcast
{
1
,
{
2
,
1
,
4
,
5
}};
migraphx
::
module
m1
;
...
...
@@ -383,6 +409,31 @@ TEST_CASE(simplify_inner_broadcast)
EXPECT
(
m1
==
m2
);
}
TEST_CASE
(
simplify_inner_broadcast2
)
{
auto
b
=
migraphx
::
op
::
multibroadcast
{{
2
,
1
,
4
,
5
}};
migraphx
::
module
m1
;
{
auto
x
=
m1
.
add_parameter
(
"x"
,
{
migraphx
::
shape
::
int32_type
,
{
1
,
1
,
1
,
1
}});
auto
y
=
m1
.
add_parameter
(
"y"
,
{
migraphx
::
shape
::
int32_type
,
{
1
,
1
,
1
,
1
}});
auto
xb
=
m1
.
add_instruction
(
b
,
x
);
auto
yb
=
m1
.
add_instruction
(
b
,
y
);
auto
sum
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
xb
,
yb
);
m1
.
add_instruction
(
pass_op
{},
sum
);
}
run_pass
(
m1
);
migraphx
::
module
m2
;
{
auto
x
=
m2
.
add_parameter
(
"x"
,
{
migraphx
::
shape
::
int32_type
,
{
1
,
1
,
1
,
1
}});
auto
y
=
m2
.
add_parameter
(
"y"
,
{
migraphx
::
shape
::
int32_type
,
{
1
,
1
,
1
,
1
}});
auto
sum
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
x
,
y
);
auto
sumb
=
m2
.
add_instruction
(
b
,
sum
);
m2
.
add_instruction
(
pass_op
{},
sumb
);
}
EXPECT
(
m1
==
m2
);
}
TEST_CASE
(
simplify_add_conv1
)
{
migraphx
::
module
m
;
...
...
test/verify/gemm_add_broadcast1.cpp
0 → 100644
View file @
636e2847
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/apply_alpha_beta.hpp>
struct
gemm_add_broadcast1
:
verify_program
<
gemm_add_broadcast1
>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
float_type
,
{
1
,
2
,
3
}};
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
float_type
,
{
1
,
3
,
4
}};
migraphx
::
shape
m3_shape
{
migraphx
::
shape
::
float_type
,
{
1
,
1
,
4
}};
auto
l1
=
mm
->
add_parameter
(
"1"
,
m1_shape
);
auto
l2
=
mm
->
add_parameter
(
"2"
,
m2_shape
);
auto
l3
=
mm
->
add_parameter
(
"3"
,
m3_shape
);
auto
l3_b
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
1
,
2
,
4
}}}),
l3
);
auto
dot
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"dot"
),
l1
,
l2
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
dot
,
l3_b
);
return
p
;
}
};
test/verify/gemm_add_broadcast2.cpp
0 → 100644
View file @
636e2847
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/apply_alpha_beta.hpp>
struct
gemm_add_broadcast2
:
verify_program
<
gemm_add_broadcast2
>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
float_type
,
{
1
,
2
,
3
}};
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
float_type
,
{
1
,
3
,
4
}};
migraphx
::
shape
m3_shape
{
migraphx
::
shape
::
float_type
,
{
1
,
2
,
1
}};
auto
l1
=
mm
->
add_parameter
(
"1"
,
m1_shape
);
auto
l2
=
mm
->
add_parameter
(
"2"
,
m2_shape
);
auto
l3
=
mm
->
add_parameter
(
"3"
,
m3_shape
);
auto
l3_b
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
1
,
2
,
4
}}}),
l3
);
auto
dot
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"dot"
),
l1
,
l2
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
dot
,
l3_b
);
return
p
;
}
};
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment