Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
e69b4a33
"git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "d43cf0a34b4c25d3c2d472416b0e12c0d6d0a4a7"
Commit
e69b4a33
authored
Nov 02, 2023
by
Paul
Browse files
Fix merge conflicts:
parent
f3a8933c
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
16 additions
and
37 deletions
+16
-37
src/targets/gpu/gemm_impl.cpp
src/targets/gpu/gemm_impl.cpp
+15
-28
src/targets/gpu/include/migraphx/gpu/gemm.hpp
src/targets/gpu/include/migraphx/gpu/gemm.hpp
+1
-4
src/targets/gpu/include/migraphx/gpu/gemm_impl.hpp
src/targets/gpu/include/migraphx/gpu/gemm_impl.hpp
+0
-5
No files found.
src/targets/gpu/gemm_impl.cpp
View file @
e69b4a33
...
...
@@ -147,7 +147,6 @@ struct gemm_impl
const
std
::
vector
<
shape
>&
input_shapes
,
T
alpha_param
,
T
beta_param
,
bool
int8_x4_format
,
bool
compute_fp32_flag
)
:
alpha
(
alpha_param
),
beta
(
beta_param
),
...
...
@@ -200,10 +199,6 @@ struct gemm_impl
compute_type
=
rocblas_datatype_f32_r
;
}
#if ROCBLAS_VERSION_MAJOR < 3
int8_flag
=
int8_x4_format
?
rocblas_gemm_flags_pack_int8x4
:
rocblas_gemm_flags_none
;
#endif
auto
a_lens
=
input_shapes
[
0
].
lens
();
auto
b_lens
=
input_shapes
[
1
].
lens
();
...
...
@@ -211,10 +206,6 @@ struct gemm_impl
m
=
out_lens
[
dim_0
];
n
=
out_lens
[
dim_1
];
k
=
input_shapes
[
0
].
lens
()[
dim_1
];
if
(
input_shapes
[
0
].
type
()
==
shape
::
int8_type
and
(
k
%
4
)
!=
0
and
int8_x4_format
)
{
MIGRAPHX_THROW
(
"ROCBLAS_GEMM: k size of int8 type input must be multiple of 4!"
);
}
a_stride
=
get_batch_stride
(
input_shapes
[
0
]);
b_stride
=
get_batch_stride
(
input_shapes
[
1
]);
...
...
@@ -241,13 +232,13 @@ struct gemm_impl
common_args
,
rocblas_gemm_algo_standard
,
solution_idx
,
int8
_flag
);
gemm
_flag
s
);
}
else
{
auto
common_args
=
create_gemm_ex_args_common
(
ctx
,
input_args
);
rocblas_invoke
(
&
rocblas_gemm_ex
,
common_args
,
rocblas_gemm_algo_standard
,
solution_idx
,
int8
_flag
);
&
rocblas_gemm_ex
,
common_args
,
rocblas_gemm_algo_standard
,
solution_idx
,
gemm
_flag
s
);
}
}
...
...
@@ -408,7 +399,7 @@ struct gemm_impl
rocblas_invoke
(
&
rocblas_gemm_strided_batched_ex_get_solutions
,
common_args
,
rocblas_gemm_algo_solution_index
,
int8
_flag
,
gemm
_flag
s
,
nullptr
,
&
list_size
);
solution_indices
.
resize
(
list_size
);
...
...
@@ -417,7 +408,7 @@ struct gemm_impl
rocblas_invoke
(
&
rocblas_gemm_strided_batched_ex_get_solutions
,
common_sol_args
,
rocblas_gemm_algo_solution_index
,
int8
_flag
,
gemm
_flag
s
,
solution_indices
.
data
(),
&
list_size
);
}
...
...
@@ -427,7 +418,7 @@ struct gemm_impl
rocblas_invoke
(
&
rocblas_gemm_ex_get_solutions
,
common_args
,
rocblas_gemm_algo_solution_index
,
int8
_flag
,
gemm
_flag
s
,
nullptr
,
&
list_size
);
solution_indices
.
resize
(
list_size
);
...
...
@@ -436,7 +427,7 @@ struct gemm_impl
rocblas_invoke
(
&
rocblas_gemm_ex_get_solutions
,
common_sol_args
,
rocblas_gemm_algo_solution_index
,
int8
_flag
,
gemm
_flag
s
,
solution_indices
.
data
(),
&
list_size
);
}
...
...
@@ -489,7 +480,7 @@ struct gemm_impl
std
::
function
<
const
void
*
()
>
get_alpha
{};
std
::
function
<
const
void
*
()
>
get_beta
{};
flag_type
int8
_flag
=
0
;
rocblas_gemm_flags
gemm
_flag
s
=
rocblas_gemm_flags_none
;
rocblas_int
lda
=
0
;
rocblas_int
ldb
=
0
;
rocblas_int
ldc
=
0
;
...
...
@@ -511,7 +502,6 @@ void gemm_compute(context& ctx,
const
std
::
vector
<
argument
>&
args
,
float
alpha
,
float
beta
,
bool
int8_x4_format
,
bool
compute_fp32
,
int32_t
solution_idx
)
{
...
...
@@ -521,7 +511,7 @@ void gemm_compute(context& ctx,
std
::
back_inserter
(
input_shapes
),
[](
const
argument
&
x
)
{
return
x
.
get_shape
();
});
auto
gemm_item
=
gemm_impl
<
float
>
(
output_shape
,
input_shapes
,
alpha
,
beta
,
int8_x4_format
,
compute_fp32
);
gemm_impl
<
float
>
(
output_shape
,
input_shapes
,
alpha
,
beta
,
compute_fp32
);
gemm_item
.
run
(
ctx
,
args
,
solution_idx
);
}
...
...
@@ -530,7 +520,6 @@ void gemm_compute(context& ctx,
const
std
::
vector
<
argument
>&
args
,
int32_t
alpha
,
int32_t
beta
,
bool
int8_x4_format
,
bool
compute_fp32
,
int32_t
solution_idx
)
{
...
...
@@ -540,7 +529,7 @@ void gemm_compute(context& ctx,
std
::
back_inserter
(
input_shapes
),
[](
const
argument
&
x
)
{
return
x
.
get_shape
();
});
auto
gemm_item
=
gemm_impl
<
int32_t
>
(
output_shape
,
input_shapes
,
alpha
,
beta
,
int8_x4_format
,
compute_fp32
);
gemm_impl
<
int32_t
>
(
output_shape
,
input_shapes
,
alpha
,
beta
,
compute_fp32
);
gemm_item
.
run
(
ctx
,
args
,
solution_idx
);
}
...
...
@@ -553,7 +542,6 @@ int32_t gemm_finalize(context& ctx,
const
std
::
vector
<
shape
>&
input_shapes
,
float
alpha
,
float
beta
,
bool
int8_x4_format
,
bool
compute_fp32
,
int32_t
solution_idx
)
{
...
...
@@ -565,7 +553,7 @@ int32_t gemm_finalize(context& ctx,
if
(
solution_idx
==
0
)
{
auto
gemm_item
=
gemm_impl
<
float
>
(
output_shape
,
input_shapes
,
alpha
,
beta
,
int8_x4_format
,
compute_fp32
);
gemm_impl
<
float
>
(
output_shape
,
input_shapes
,
alpha
,
beta
,
compute_fp32
);
solution_idx
=
gemm_item
.
tune
(
ctx
,
input_shapes
);
}
else
...
...
@@ -573,13 +561,13 @@ int32_t gemm_finalize(context& ctx,
// If a tuned solution index is already given, don't tune again but validate
// in case the data was tuned with a different rocBLAS version
auto
gemm_item
=
gemm_impl
<
float
>
(
output_shape
,
input_shapes
,
alpha
,
beta
,
int8_x4_format
,
compute_fp32
);
gemm_impl
<
float
>
(
output_shape
,
input_shapes
,
alpha
,
beta
,
compute_fp32
);
solution_idx
=
gemm_item
.
validate
(
ctx
,
input_shapes
,
solution_idx
);
}
#else
// suppress compiler warnings
(
void
)
ctx
,
(
void
)
output_shape
,
(
void
)
input_shapes
;
(
void
)
alpha
,
(
void
)
beta
,
(
void
)
int8_x4_format
,
(
void
)
compute_fp32
;
(
void
)
alpha
,
(
void
)
beta
,
(
void
)
compute_fp32
;
#endif
return
solution_idx
;
}
...
...
@@ -593,7 +581,6 @@ int32_t gemm_finalize(context& ctx,
const
std
::
vector
<
shape
>&
input_shapes
,
int32_t
alpha
,
int32_t
beta
,
bool
int8_x4_format
,
bool
compute_fp32
,
int32_t
solution_idx
)
{
...
...
@@ -604,7 +591,7 @@ int32_t gemm_finalize(context& ctx,
if
(
solution_idx
==
0
)
{
auto
gemm_item
=
gemm_impl
<
int32_t
>
(
output_shape
,
input_shapes
,
alpha
,
beta
,
int8_x4_format
,
compute_fp32
);
output_shape
,
input_shapes
,
alpha
,
beta
,
compute_fp32
);
solution_idx
=
gemm_item
.
tune
(
ctx
,
input_shapes
);
}
else
...
...
@@ -612,13 +599,13 @@ int32_t gemm_finalize(context& ctx,
// If a tuned solution index is already given, don't tune again but validate
// in case the data was tuned with a different rocBLAS version
auto
gemm_item
=
gemm_impl
<
int32_t
>
(
output_shape
,
input_shapes
,
alpha
,
beta
,
int8_x4_format
,
compute_fp32
);
output_shape
,
input_shapes
,
alpha
,
beta
,
compute_fp32
);
solution_idx
=
gemm_item
.
validate
(
ctx
,
input_shapes
,
solution_idx
);
}
#else
// suppress compiler warnings
(
void
)
ctx
,
(
void
)
output_shape
,
(
void
)
input_shapes
;
(
void
)
alpha
,
(
void
)
beta
,
(
void
)
int8_x4_format
,
(
void
)
compute_fp32
;
(
void
)
alpha
,
(
void
)
beta
,
(
void
)
compute_fp32
;
#endif
return
solution_idx
;
}
...
...
src/targets/gpu/include/migraphx/gpu/gemm.hpp
View file @
e69b4a33
...
...
@@ -115,7 +115,7 @@ struct rocblas_gemm
if
(
this
->
name
()
==
"gpu::gemm"
)
{
gemm_compute
(
ctx
,
output_shape
,
args
,
alpha
,
beta
,
int8_x4_format
,
compute_fp32
,
solution_idx
);
ctx
,
output_shape
,
args
,
alpha
,
beta
,
compute_fp32
,
solution_idx
);
}
else
{
...
...
@@ -124,7 +124,6 @@ struct rocblas_gemm
args
,
int32_t
(
alpha
),
int32_t
(
beta
),
int8_x4_format
,
compute_fp32
,
solution_idx
);
}
...
...
@@ -148,7 +147,6 @@ struct rocblas_gemm
input_shapes
,
alpha
,
beta
,
int8_x4_format
,
compute_fp32
,
solution_idx
);
}
...
...
@@ -159,7 +157,6 @@ struct rocblas_gemm
input_shapes
,
int32_t
(
alpha
),
int32_t
(
beta
),
int8_x4_format
,
compute_fp32
,
solution_idx
);
}
...
...
src/targets/gpu/include/migraphx/gpu/gemm_impl.hpp
View file @
e69b4a33
...
...
@@ -60,7 +60,6 @@ using flag_type = int;
* @param args .
* @param alpha .
* @param beta .
* @param int8_x4_format .
* @param compute_fp32 .
*/
void
gemm_compute
(
context
&
ctx
,
...
...
@@ -68,7 +67,6 @@ void gemm_compute(context& ctx,
const
std
::
vector
<
argument
>&
args
,
float
alpha
,
float
beta
,
bool
int8_x4_format
,
bool
compute_fp32
,
int32_t
solution_idx
);
...
...
@@ -77,7 +75,6 @@ void gemm_compute(context& ctx,
const
std
::
vector
<
argument
>&
args
,
int32_t
alpha
,
int32_t
beta
,
bool
int8_x4_format
,
bool
compute_fp32
,
int32_t
solution_idx
);
...
...
@@ -86,7 +83,6 @@ int32_t gemm_finalize(context& ctx,
const
std
::
vector
<
shape
>&
input_shapes
,
float
alpha
,
float
beta
,
bool
int8_x4_format
,
bool
compute_fp32
);
int32_t
gemm_finalize
(
context
&
ctx
,
...
...
@@ -94,7 +90,6 @@ int32_t gemm_finalize(context& ctx,
const
std
::
vector
<
shape
>&
input_shapes
,
int32_t
alpha
,
int32_t
beta
,
bool
int8_x4_format
,
bool
compute_fp32
,
int32_t
solution_idx
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment