Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
ab48f476
"vscode:/vscode.git/clone" did not exist on "4553478951d63d38abdcb37b287960ca19b380db"
Commit
ab48f476
authored
Jun 25, 2022
by
Paul
Browse files
Improve handling of broadcast
parent
75f378a3
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
109 additions
and
9 deletions
+109
-9
src/targets/gpu/gemm_impl.cpp
src/targets/gpu/gemm_impl.cpp
+11
-8
src/targets/gpu/include/migraphx/gpu/gemm.hpp
src/targets/gpu/include/migraphx/gpu/gemm.hpp
+2
-1
test/verify/gemm_add_broadcast1.cpp
test/verify/gemm_add_broadcast1.cpp
+48
-0
test/verify/gemm_add_broadcast2.cpp
test/verify/gemm_add_broadcast2.cpp
+48
-0
No files found.
src/targets/gpu/gemm_impl.cpp
View file @
ab48f476
...
@@ -97,6 +97,12 @@ void gemm_impl(context& ctx,
...
@@ -97,6 +97,12 @@ void gemm_impl(context& ctx,
bool
int8_x4_format
,
bool
int8_x4_format
,
bool
compute_fp32
)
bool
compute_fp32
)
{
{
const
bool
is_3inputs
=
(
args
.
size
()
==
4
);
if
(
!
is_3inputs
)
{
beta
=
0
;
}
bool
transa
=
is_transposed
(
args
[
0
].
get_shape
());
bool
transa
=
is_transposed
(
args
[
0
].
get_shape
());
bool
transb
=
is_transposed
(
args
[
1
].
get_shape
());
bool
transb
=
is_transposed
(
args
[
1
].
get_shape
());
auto
n_dim
=
output_shape
.
lens
().
size
();
auto
n_dim
=
output_shape
.
lens
().
size
();
...
@@ -105,12 +111,8 @@ void gemm_impl(context& ctx,
...
@@ -105,12 +111,8 @@ void gemm_impl(context& ctx,
rocblas_int
lda
=
args
[
0
].
get_shape
().
strides
()[
transa
?
dim_1
:
dim_0
];
rocblas_int
lda
=
args
[
0
].
get_shape
().
strides
()[
transa
?
dim_1
:
dim_0
];
rocblas_int
ldb
=
args
[
1
].
get_shape
().
strides
()[
transb
?
dim_1
:
dim_0
];
rocblas_int
ldb
=
args
[
1
].
get_shape
().
strides
()[
transb
?
dim_1
:
dim_0
];
rocblas_int
ldc
=
args
[
2
].
get_shape
().
strides
()[
dim_0
];
rocblas_int
ldc
=
args
[
2
].
get_shape
().
strides
()[
dim_0
];
rocblas_int
ldd
=
is_3inputs
?
args
[
3
].
get_shape
().
strides
()[
dim_0
]
:
ldc
;
bool
is_3inputs
=
(
args
.
size
()
==
4
);
if
(
!
is_3inputs
)
{
beta
=
0
;
}
rocblas_datatype
arg_type
=
get_type
(
args
[
0
].
get_shape
().
type
());
rocblas_datatype
arg_type
=
get_type
(
args
[
0
].
get_shape
().
type
());
auto
output_type
=
arg_type
;
auto
output_type
=
arg_type
;
if
(
output_type
==
rocblas_datatype_i8_r
)
if
(
output_type
==
rocblas_datatype_i8_r
)
...
@@ -186,7 +188,7 @@ void gemm_impl(context& ctx,
...
@@ -186,7 +188,7 @@ void gemm_impl(context& ctx,
ldc
,
ldc
,
is_3inputs
?
to_pointer
(
args
[
3
])
:
to_pointer
(
args
[
2
]),
is_3inputs
?
to_pointer
(
args
[
3
])
:
to_pointer
(
args
[
2
]),
output_type
,
output_type
,
ld
c
,
ld
d
,
compute_type
,
compute_type
,
rocblas_gemm_algo_standard
,
rocblas_gemm_algo_standard
,
0
,
0
,
...
@@ -197,6 +199,7 @@ void gemm_impl(context& ctx,
...
@@ -197,6 +199,7 @@ void gemm_impl(context& ctx,
auto
a_stride
=
get_batch_stride
(
args
[
0
]);
auto
a_stride
=
get_batch_stride
(
args
[
0
]);
auto
b_stride
=
get_batch_stride
(
args
[
1
]);
auto
b_stride
=
get_batch_stride
(
args
[
1
]);
auto
c_stride
=
get_batch_stride
(
args
[
2
]);
auto
c_stride
=
get_batch_stride
(
args
[
2
]);
auto
d_stride
=
is_3inputs
?
get_batch_stride
(
args
[
3
])
:
c_stride
;
rocblas_invoke
(
&
rocblas_gemm_strided_batched_ex
,
rocblas_invoke
(
&
rocblas_gemm_strided_batched_ex
,
ctx
.
get_stream
().
get_rocblas
(),
ctx
.
get_stream
().
get_rocblas
(),
transb
?
rocblas_operation_transpose
:
rocblas_operation_none
,
transb
?
rocblas_operation_transpose
:
rocblas_operation_none
,
...
@@ -220,8 +223,8 @@ void gemm_impl(context& ctx,
...
@@ -220,8 +223,8 @@ void gemm_impl(context& ctx,
c_stride
,
c_stride
,
is_3inputs
?
to_pointer
(
args
[
3
])
:
to_pointer
(
args
[
2
]),
is_3inputs
?
to_pointer
(
args
[
3
])
:
to_pointer
(
args
[
2
]),
output_type
,
output_type
,
ld
c
,
ld
d
,
c
_stride
,
d
_stride
,
num_matrices
,
num_matrices
,
compute_type
,
compute_type
,
rocblas_gemm_algo_standard
,
rocblas_gemm_algo_standard
,
...
...
src/targets/gpu/include/migraphx/gpu/gemm.hpp
View file @
ab48f476
...
@@ -74,13 +74,14 @@ struct rocblas_gemm
...
@@ -74,13 +74,14 @@ struct rocblas_gemm
{
{
std
::
vector
<
shape
>
in_shapes
(
inputs
);
std
::
vector
<
shape
>
in_shapes
(
inputs
);
in_shapes
.
pop_back
();
in_shapes
.
pop_back
();
check_shapes
{
in_shapes
,
*
this
}
.
not_broadcasted
()
;
check_shapes
{
in_shapes
,
*
this
};
blas_shape
(
inputs
[
0
]);
blas_shape
(
inputs
[
0
]);
blas_shape
(
inputs
[
1
]);
blas_shape
(
inputs
[
1
]);
// if gemm and add are fused
// if gemm and add are fused
if
(
in_shapes
.
size
()
>
2
)
if
(
in_shapes
.
size
()
>
2
)
{
{
auto
cmat_shape
=
in_shapes
.
back
();
auto
cmat_shape
=
in_shapes
.
back
();
check_shapes
{{
cmat_shape
},
*
this
}.
not_transposed
().
not_broadcasted
();
in_shapes
.
pop_back
();
in_shapes
.
pop_back
();
blas_shape
(
cmat_shape
);
blas_shape
(
cmat_shape
);
auto
op_out_shape
=
op
.
compute_shape
(
in_shapes
);
auto
op_out_shape
=
op
.
compute_shape
(
in_shapes
);
...
...
test/verify/gemm_add_broadcast1.cpp
0 → 100644
View file @
ab48f476
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/apply_alpha_beta.hpp>
struct
gemm_add_broadcast1
:
verify_program
<
gemm_add_broadcast1
>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
float_type
,
{
1
,
2
,
3
}};
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
float_type
,
{
1
,
3
,
4
}};
migraphx
::
shape
m3_shape
{
migraphx
::
shape
::
float_type
,
{
1
,
1
,
4
}};
auto
l1
=
mm
->
add_parameter
(
"1"
,
m1_shape
);
auto
l2
=
mm
->
add_parameter
(
"2"
,
m2_shape
);
auto
l3
=
mm
->
add_parameter
(
"3"
,
m3_shape
);
auto
l3_b
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
1
,
2
,
4
}}}),
l3
);
auto
dot
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"dot"
),
l1
,
l2
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
dot
,
l3_b
);
return
p
;
}
};
test/verify/gemm_add_broadcast2.cpp
0 → 100644
View file @
ab48f476
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/apply_alpha_beta.hpp>
struct
gemm_add_broadcast2
:
verify_program
<
gemm_add_broadcast2
>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
float_type
,
{
1
,
2
,
3
}};
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
float_type
,
{
1
,
3
,
4
}};
migraphx
::
shape
m3_shape
{
migraphx
::
shape
::
float_type
,
{
1
,
2
,
1
}};
auto
l1
=
mm
->
add_parameter
(
"1"
,
m1_shape
);
auto
l2
=
mm
->
add_parameter
(
"2"
,
m2_shape
);
auto
l3
=
mm
->
add_parameter
(
"3"
,
m3_shape
);
auto
l3_b
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
1
,
2
,
4
}}}),
l3
);
auto
dot
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"dot"
),
l1
,
l2
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
dot
,
l3_b
);
return
p
;
}
};
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment