Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
7892b274
"vscode:/vscode.git/clone" did not exist on "9ba8b4801a7f811334c4f1c24dd1ebc721f9a4a9"
Commit
7892b274
authored
Oct 17, 2023
by
Paul
Browse files
Remvoe stray if
parent
9ca6c1d8
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
31 additions
and
33 deletions
+31
-33
src/targets/gpu/gemm_impl.cpp
src/targets/gpu/gemm_impl.cpp
+31
-33
No files found.
src/targets/gpu/gemm_impl.cpp
View file @
7892b274
...
...
@@ -158,8 +158,6 @@ struct gemm_impl
{
beta
=
0
;
}
if
(
arg_type
==
rocblas_datatype_f16_r
)
compute_type
=
rocblas_datatype_f32_r
;
rocblas_gemm_flags
flag
=
rocblas_gemm_flags_none
;
#if ROCBLAS_VERSION_MAJOR < 3
...
...
@@ -200,43 +198,43 @@ struct gemm_impl
{
output_type
=
rocblas_datatype_i32_r
;
}
compute_type
=
output_type
;
if
(
compute_fp32
)
{
compute_type
=
output_type
;
if
(
compute_fp32
)
{
if
(
arg_type
==
rocblas_datatype_f16_r
)
compute_type
=
rocblas_datatype_f32_r
;
}
}
int8_flag
=
int8_x4_format
?
rocblas_gemm_flags_pack_int8x4
:
rocblas_gemm_flags_none
;
int8_flag
=
int8_x4_format
?
rocblas_gemm_flags_pack_int8x4
:
rocblas_gemm_flags_none
;
auto
a_lens
=
input_shapes
[
0
].
lens
();
auto
b_lens
=
input_shapes
[
1
].
lens
();
auto
a_lens
=
input_shapes
[
0
].
lens
();
auto
b_lens
=
input_shapes
[
1
].
lens
();
auto
out_lens
=
output_shape
.
lens
();
m
=
out_lens
[
dim_0
];
n
=
out_lens
[
dim_1
];
k
=
input_shapes
[
0
].
lens
()[
dim_1
];
if
(
input_shapes
[
0
].
type
()
==
shape
::
int8_type
and
(
k
%
4
)
!=
0
and
int8_x4_format
)
{
MIGRAPHX_THROW
(
"ROCBLAS_GEMM: k size of int8 type input must be multiple of 4!"
);
}
auto
out_lens
=
output_shape
.
lens
();
m
=
out_lens
[
dim_0
];
n
=
out_lens
[
dim_1
];
k
=
input_shapes
[
0
].
lens
()[
dim_1
];
if
(
input_shapes
[
0
].
type
()
==
shape
::
int8_type
and
(
k
%
4
)
!=
0
and
int8_x4_format
)
{
MIGRAPHX_THROW
(
"ROCBLAS_GEMM: k size of int8 type input must be multiple of 4!"
);
}
a_stride
=
get_batch_stride
(
input_shapes
[
0
]);
b_stride
=
get_batch_stride
(
input_shapes
[
1
]);
c_stride
=
get_batch_stride
(
input_shapes
[
2
]);
d_stride
=
is_3inputs
?
get_batch_stride
(
input_shapes
[
3
])
:
c_stride
;
num_matrices
=
std
::
accumulate
(
out_lens
.
rbegin
()
+
2
,
out_lens
.
rend
(),
std
::
size_t
{
1
},
std
::
multiplies
<
std
::
size_t
>
());
if
(
num_matrices
==
1
or
(
num_matrices
>
1
and
b_stride
==
0
))
{
// If the batch dimension of B is broadcasted, then we can
// multiply m by the batch_size and use rocblas_gemm_ex
// instead of rocblas_gemm_strided_batched_ex.
m
*=
num_matrices
;
strided_batched
=
false
;
}
a_stride
=
get_batch_stride
(
input_shapes
[
0
]);
b_stride
=
get_batch_stride
(
input_shapes
[
1
]);
c_stride
=
get_batch_stride
(
input_shapes
[
2
]);
d_stride
=
is_3inputs
?
get_batch_stride
(
input_shapes
[
3
])
:
c_stride
;
num_matrices
=
std
::
accumulate
(
out_lens
.
rbegin
()
+
2
,
out_lens
.
rend
(),
std
::
size_t
{
1
},
std
::
multiplies
<
std
::
size_t
>
());
if
(
num_matrices
==
1
or
(
num_matrices
>
1
and
b_stride
==
0
))
{
// If the batch dimension of B is broadcasted, then we can
// multiply m by the batch_size and use rocblas_gemm_ex
// instead of rocblas_gemm_strided_batched_ex.
m
*=
num_matrices
;
strided_batched
=
false
;
}
}
void
run
(
context
&
ctx
,
const
std
::
vector
<
argument
>&
input_args
,
int32_t
solution_idx
=
0
)
const
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment