Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
e846ae72
Commit
e846ae72
authored
Oct 17, 2023
by
Paul
Browse files
Format
parent
dc71e23d
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
45 additions
and
43 deletions
+45
-43
src/targets/gpu/gemm_impl.cpp
src/targets/gpu/gemm_impl.cpp
+45
-43
No files found.
src/targets/gpu/gemm_impl.cpp
View file @
e846ae72
...
...
@@ -160,51 +160,51 @@ struct gemm_impl
}
if
(
arg_type
==
rocblas_datatype_f16_r
)
compute_type
=
rocblas_datatype_f32_r
;
rocblas_gemm_flags
flag
=
rocblas_gemm_flags_none
;
#if ROCBLAS_VERSION_MAJOR < 3
if
(
int8_x4_format
)
flag
=
rocblas_gemm_flags_pack_int8x4
;
#endif
// Create lambdas that will cast alpha, beta to the output shape's type
// and retain the values being pointed to
output_shape
.
visit_type
([
&
](
auto
as
)
{
auto
alpha_r
=
as
(
alpha
);
auto
beta_r
=
as
(
beta
);
if
(
compute_fp32
)
{
get_alpha
=
[
=
]
{
return
&
alpha
;
};
get_beta
=
[
=
]
{
return
&
beta
;
};
}
else
{
get_alpha
=
[
=
]
{
return
&
alpha_r
;
};
get_beta
=
[
=
]
{
return
&
beta_r
;
};
}
});
transa
=
is_transposed
(
input_shapes
[
0
]);
transb
=
is_transposed
(
input_shapes
[
1
]);
auto
n_dim
=
output_shape
.
lens
().
size
();
auto
dim_0
=
n_dim
-
2
;
auto
dim_1
=
n_dim
-
1
;
// Leading dimensions of matrices
lda
=
input_shapes
[
0
].
strides
()[
transa
?
dim_1
:
dim_0
];
ldb
=
input_shapes
[
1
].
strides
()[
transb
?
dim_1
:
dim_0
];
ldc
=
input_shapes
[
2
].
strides
()[
dim_0
];
ldd
=
is_3inputs
?
input_shapes
[
3
].
strides
()[
dim_0
]
:
ldc
;
arg_type
=
get_type
(
input_shapes
[
0
].
type
());
output_type
=
arg_type
;
if
(
output_type
==
rocblas_datatype_i8_r
)
// Create lambdas that will cast alpha, beta to the output shape's type
// and retain the values being pointed to
output_shape
.
visit_type
([
&
](
auto
as
)
{
auto
alpha_r
=
as
(
alpha
);
auto
beta_r
=
as
(
beta
);
if
(
compute_fp32
)
{
output_type
=
rocblas_datatype_i32_r
;
get_alpha
=
[
=
]
{
return
&
alpha
;
};
get_beta
=
[
=
]
{
return
&
beta
;
};
}
else
{
get_alpha
=
[
=
]
{
return
&
alpha_r
;
};
get_beta
=
[
=
]
{
return
&
beta_r
;
};
}
});
transa
=
is_transposed
(
input_shapes
[
0
]);
transb
=
is_transposed
(
input_shapes
[
1
]);
auto
n_dim
=
output_shape
.
lens
().
size
();
auto
dim_0
=
n_dim
-
2
;
auto
dim_1
=
n_dim
-
1
;
// Leading dimensions of matrices
lda
=
input_shapes
[
0
].
strides
()[
transa
?
dim_1
:
dim_0
];
ldb
=
input_shapes
[
1
].
strides
()[
transb
?
dim_1
:
dim_0
];
ldc
=
input_shapes
[
2
].
strides
()[
dim_0
];
ldd
=
is_3inputs
?
input_shapes
[
3
].
strides
()[
dim_0
]
:
ldc
;
arg_type
=
get_type
(
input_shapes
[
0
].
type
());
output_type
=
arg_type
;
if
(
output_type
==
rocblas_datatype_i8_r
)
{
output_type
=
rocblas_datatype_i32_r
;
}
compute_type
=
output_type
;
if
(
compute_fp32
)
{
if
(
arg_type
==
rocblas_datatype_f16_r
)
compute_type
=
rocblas_datatype_f32_r
;
if
(
arg_type
==
rocblas_datatype_f16_r
)
compute_type
=
rocblas_datatype_f32_r
;
}
int8_flag
=
int8_x4_format
?
rocblas_gemm_flags_pack_int8x4
:
rocblas_gemm_flags_none
;
...
...
@@ -218,22 +218,24 @@ struct gemm_impl
k
=
input_shapes
[
0
].
lens
()[
dim_1
];
if
(
input_shapes
[
0
].
type
()
==
shape
::
int8_type
and
(
k
%
4
)
!=
0
and
int8_x4_format
)
{
MIGRAPHX_THROW
(
"ROCBLAS_GEMM: k size of int8 type input must be multiple of 4!"
);
MIGRAPHX_THROW
(
"ROCBLAS_GEMM: k size of int8 type input must be multiple of 4!"
);
}
a_stride
=
get_batch_stride
(
input_shapes
[
0
]);
b_stride
=
get_batch_stride
(
input_shapes
[
1
]);
c_stride
=
get_batch_stride
(
input_shapes
[
2
]);
d_stride
=
is_3inputs
?
get_batch_stride
(
input_shapes
[
3
])
:
c_stride
;
num_matrices
=
std
::
accumulate
(
out_lens
.
rbegin
()
+
2
,
out_lens
.
rend
(),
std
::
size_t
{
1
},
std
::
multiplies
<
std
::
size_t
>
());
num_matrices
=
std
::
accumulate
(
out_lens
.
rbegin
()
+
2
,
out_lens
.
rend
(),
std
::
size_t
{
1
},
std
::
multiplies
<
std
::
size_t
>
());
if
(
num_matrices
==
1
or
(
num_matrices
>
1
and
b_stride
==
0
))
{
// If the batch dimension of B is broadcasted, then we can
// multiply m by the batch_size and use rocblas_gemm_ex
// instead of rocblas_gemm_strided_batched_ex.
m
*=
num_matrices
;
strided_batched
=
false
;
// If the batch dimension of B is broadcasted, then we can
// multiply m by the batch_size and use rocblas_gemm_ex
// instead of rocblas_gemm_strided_batched_ex.
m
*=
num_matrices
;
strided_batched
=
false
;
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment