"git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "d43cf0a34b4c25d3c2d472416b0e12c0d6d0a4a7"
Commit e69b4a33 authored by Paul
Browse files

Fix merge conflicts:

parent f3a8933c
......@@ -147,7 +147,6 @@ struct gemm_impl
const std::vector<shape>& input_shapes,
T alpha_param,
T beta_param,
bool int8_x4_format,
bool compute_fp32_flag)
: alpha(alpha_param),
beta(beta_param),
......@@ -200,10 +199,6 @@ struct gemm_impl
compute_type = rocblas_datatype_f32_r;
}
#if ROCBLAS_VERSION_MAJOR < 3
int8_flag = int8_x4_format ? rocblas_gemm_flags_pack_int8x4 : rocblas_gemm_flags_none;
#endif
auto a_lens = input_shapes[0].lens();
auto b_lens = input_shapes[1].lens();
......@@ -211,10 +206,6 @@ struct gemm_impl
m = out_lens[dim_0];
n = out_lens[dim_1];
k = input_shapes[0].lens()[dim_1];
if(input_shapes[0].type() == shape::int8_type and (k % 4) != 0 and int8_x4_format)
{
MIGRAPHX_THROW("ROCBLAS_GEMM: k size of int8 type input must be multiple of 4!");
}
a_stride = get_batch_stride(input_shapes[0]);
b_stride = get_batch_stride(input_shapes[1]);
......@@ -241,13 +232,13 @@ struct gemm_impl
common_args,
rocblas_gemm_algo_standard,
solution_idx,
int8_flag);
gemm_flags);
}
else
{
auto common_args = create_gemm_ex_args_common(ctx, input_args);
rocblas_invoke(
&rocblas_gemm_ex, common_args, rocblas_gemm_algo_standard, solution_idx, int8_flag);
&rocblas_gemm_ex, common_args, rocblas_gemm_algo_standard, solution_idx, gemm_flags);
}
}
......@@ -408,7 +399,7 @@ struct gemm_impl
rocblas_invoke(&rocblas_gemm_strided_batched_ex_get_solutions,
common_args,
rocblas_gemm_algo_solution_index,
int8_flag,
gemm_flags,
nullptr,
&list_size);
solution_indices.resize(list_size);
......@@ -417,7 +408,7 @@ struct gemm_impl
rocblas_invoke(&rocblas_gemm_strided_batched_ex_get_solutions,
common_sol_args,
rocblas_gemm_algo_solution_index,
int8_flag,
gemm_flags,
solution_indices.data(),
&list_size);
}
......@@ -427,7 +418,7 @@ struct gemm_impl
rocblas_invoke(&rocblas_gemm_ex_get_solutions,
common_args,
rocblas_gemm_algo_solution_index,
int8_flag,
gemm_flags,
nullptr,
&list_size);
solution_indices.resize(list_size);
......@@ -436,7 +427,7 @@ struct gemm_impl
rocblas_invoke(&rocblas_gemm_ex_get_solutions,
common_sol_args,
rocblas_gemm_algo_solution_index,
int8_flag,
gemm_flags,
solution_indices.data(),
&list_size);
}
......@@ -489,7 +480,7 @@ struct gemm_impl
std::function<const void*()> get_alpha{};
std::function<const void*()> get_beta{};
flag_type int8_flag = 0;
rocblas_gemm_flags gemm_flags = rocblas_gemm_flags_none;
rocblas_int lda = 0;
rocblas_int ldb = 0;
rocblas_int ldc = 0;
......@@ -511,7 +502,6 @@ void gemm_compute(context& ctx,
const std::vector<argument>& args,
float alpha,
float beta,
bool int8_x4_format,
bool compute_fp32,
int32_t solution_idx)
{
......@@ -521,7 +511,7 @@ void gemm_compute(context& ctx,
std::back_inserter(input_shapes),
[](const argument& x) { return x.get_shape(); });
auto gemm_item =
gemm_impl<float>(output_shape, input_shapes, alpha, beta, int8_x4_format, compute_fp32);
gemm_impl<float>(output_shape, input_shapes, alpha, beta, compute_fp32);
gemm_item.run(ctx, args, solution_idx);
}
......@@ -530,7 +520,6 @@ void gemm_compute(context& ctx,
const std::vector<argument>& args,
int32_t alpha,
int32_t beta,
bool int8_x4_format,
bool compute_fp32,
int32_t solution_idx)
{
......@@ -540,7 +529,7 @@ void gemm_compute(context& ctx,
std::back_inserter(input_shapes),
[](const argument& x) { return x.get_shape(); });
auto gemm_item =
gemm_impl<int32_t>(output_shape, input_shapes, alpha, beta, int8_x4_format, compute_fp32);
gemm_impl<int32_t>(output_shape, input_shapes, alpha, beta, compute_fp32);
gemm_item.run(ctx, args, solution_idx);
}
......@@ -553,7 +542,6 @@ int32_t gemm_finalize(context& ctx,
const std::vector<shape>& input_shapes,
float alpha,
float beta,
bool int8_x4_format,
bool compute_fp32,
int32_t solution_idx)
{
......@@ -565,7 +553,7 @@ int32_t gemm_finalize(context& ctx,
if(solution_idx == 0)
{
auto gemm_item =
gemm_impl<float>(output_shape, input_shapes, alpha, beta, int8_x4_format, compute_fp32);
gemm_impl<float>(output_shape, input_shapes, alpha, beta, compute_fp32);
solution_idx = gemm_item.tune(ctx, input_shapes);
}
else
......@@ -573,13 +561,13 @@ int32_t gemm_finalize(context& ctx,
// If a tuned solution index is already given, don't tune again but validate
// in case the data was tuned with a different rocBLAS version
auto gemm_item =
gemm_impl<float>(output_shape, input_shapes, alpha, beta, int8_x4_format, compute_fp32);
gemm_impl<float>(output_shape, input_shapes, alpha, beta, compute_fp32);
solution_idx = gemm_item.validate(ctx, input_shapes, solution_idx);
}
#else
// suppress compiler warnings
(void)ctx, (void)output_shape, (void)input_shapes;
(void)alpha, (void)beta, (void)int8_x4_format, (void)compute_fp32;
(void)alpha, (void)beta, (void)compute_fp32;
#endif
return solution_idx;
}
......@@ -593,7 +581,6 @@ int32_t gemm_finalize(context& ctx,
const std::vector<shape>& input_shapes,
int32_t alpha,
int32_t beta,
bool int8_x4_format,
bool compute_fp32,
int32_t solution_idx)
{
......@@ -604,7 +591,7 @@ int32_t gemm_finalize(context& ctx,
if(solution_idx == 0)
{
auto gemm_item = gemm_impl<int32_t>(
output_shape, input_shapes, alpha, beta, int8_x4_format, compute_fp32);
output_shape, input_shapes, alpha, beta, compute_fp32);
solution_idx = gemm_item.tune(ctx, input_shapes);
}
else
......@@ -612,13 +599,13 @@ int32_t gemm_finalize(context& ctx,
// If a tuned solution index is already given, don't tune again but validate
// in case the data was tuned with a different rocBLAS version
auto gemm_item = gemm_impl<int32_t>(
output_shape, input_shapes, alpha, beta, int8_x4_format, compute_fp32);
output_shape, input_shapes, alpha, beta, compute_fp32);
solution_idx = gemm_item.validate(ctx, input_shapes, solution_idx);
}
#else
// suppress compiler warnings
(void)ctx, (void)output_shape, (void)input_shapes;
(void)alpha, (void)beta, (void)int8_x4_format, (void)compute_fp32;
(void)alpha, (void)beta, (void)compute_fp32;
#endif
return solution_idx;
}
......
......@@ -115,7 +115,7 @@ struct rocblas_gemm
if(this->name() == "gpu::gemm")
{
gemm_compute(
ctx, output_shape, args, alpha, beta, int8_x4_format, compute_fp32, solution_idx);
ctx, output_shape, args, alpha, beta, compute_fp32, solution_idx);
}
else
{
......@@ -124,7 +124,6 @@ struct rocblas_gemm
args,
int32_t(alpha),
int32_t(beta),
int8_x4_format,
compute_fp32,
solution_idx);
}
......@@ -148,7 +147,6 @@ struct rocblas_gemm
input_shapes,
alpha,
beta,
int8_x4_format,
compute_fp32,
solution_idx);
}
......@@ -159,7 +157,6 @@ struct rocblas_gemm
input_shapes,
int32_t(alpha),
int32_t(beta),
int8_x4_format,
compute_fp32,
solution_idx);
}
......
......@@ -60,7 +60,6 @@ using flag_type = int;
* @param args .
* @param alpha .
* @param beta .
* @param int8_x4_format .
* @param compute_fp32 .
*/
void gemm_compute(context& ctx,
......@@ -68,7 +67,6 @@ void gemm_compute(context& ctx,
const std::vector<argument>& args,
float alpha,
float beta,
bool int8_x4_format,
bool compute_fp32,
int32_t solution_idx);
......@@ -77,7 +75,6 @@ void gemm_compute(context& ctx,
const std::vector<argument>& args,
int32_t alpha,
int32_t beta,
bool int8_x4_format,
bool compute_fp32,
int32_t solution_idx);
......@@ -86,7 +83,6 @@ int32_t gemm_finalize(context& ctx,
const std::vector<shape>& input_shapes,
float alpha,
float beta,
bool int8_x4_format,
bool compute_fp32);
int32_t gemm_finalize(context& ctx,
......@@ -94,7 +90,6 @@ int32_t gemm_finalize(context& ctx,
const std::vector<shape>& input_shapes,
int32_t alpha,
int32_t beta,
bool int8_x4_format,
bool compute_fp32,
int32_t solution_idx);
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment