Commit e69b4a33 authored by Paul
Browse files

Fix merge conflicts:

parent f3a8933c
...@@ -147,7 +147,6 @@ struct gemm_impl ...@@ -147,7 +147,6 @@ struct gemm_impl
const std::vector<shape>& input_shapes, const std::vector<shape>& input_shapes,
T alpha_param, T alpha_param,
T beta_param, T beta_param,
bool int8_x4_format,
bool compute_fp32_flag) bool compute_fp32_flag)
: alpha(alpha_param), : alpha(alpha_param),
beta(beta_param), beta(beta_param),
...@@ -200,10 +199,6 @@ struct gemm_impl ...@@ -200,10 +199,6 @@ struct gemm_impl
compute_type = rocblas_datatype_f32_r; compute_type = rocblas_datatype_f32_r;
} }
#if ROCBLAS_VERSION_MAJOR < 3
int8_flag = int8_x4_format ? rocblas_gemm_flags_pack_int8x4 : rocblas_gemm_flags_none;
#endif
auto a_lens = input_shapes[0].lens(); auto a_lens = input_shapes[0].lens();
auto b_lens = input_shapes[1].lens(); auto b_lens = input_shapes[1].lens();
...@@ -211,10 +206,6 @@ struct gemm_impl ...@@ -211,10 +206,6 @@ struct gemm_impl
m = out_lens[dim_0]; m = out_lens[dim_0];
n = out_lens[dim_1]; n = out_lens[dim_1];
k = input_shapes[0].lens()[dim_1]; k = input_shapes[0].lens()[dim_1];
if(input_shapes[0].type() == shape::int8_type and (k % 4) != 0 and int8_x4_format)
{
MIGRAPHX_THROW("ROCBLAS_GEMM: k size of int8 type input must be multiple of 4!");
}
a_stride = get_batch_stride(input_shapes[0]); a_stride = get_batch_stride(input_shapes[0]);
b_stride = get_batch_stride(input_shapes[1]); b_stride = get_batch_stride(input_shapes[1]);
...@@ -241,13 +232,13 @@ struct gemm_impl ...@@ -241,13 +232,13 @@ struct gemm_impl
common_args, common_args,
rocblas_gemm_algo_standard, rocblas_gemm_algo_standard,
solution_idx, solution_idx,
int8_flag); gemm_flags);
} }
else else
{ {
auto common_args = create_gemm_ex_args_common(ctx, input_args); auto common_args = create_gemm_ex_args_common(ctx, input_args);
rocblas_invoke( rocblas_invoke(
&rocblas_gemm_ex, common_args, rocblas_gemm_algo_standard, solution_idx, int8_flag); &rocblas_gemm_ex, common_args, rocblas_gemm_algo_standard, solution_idx, gemm_flags);
} }
} }
...@@ -408,7 +399,7 @@ struct gemm_impl ...@@ -408,7 +399,7 @@ struct gemm_impl
rocblas_invoke(&rocblas_gemm_strided_batched_ex_get_solutions, rocblas_invoke(&rocblas_gemm_strided_batched_ex_get_solutions,
common_args, common_args,
rocblas_gemm_algo_solution_index, rocblas_gemm_algo_solution_index,
int8_flag, gemm_flags,
nullptr, nullptr,
&list_size); &list_size);
solution_indices.resize(list_size); solution_indices.resize(list_size);
...@@ -417,7 +408,7 @@ struct gemm_impl ...@@ -417,7 +408,7 @@ struct gemm_impl
rocblas_invoke(&rocblas_gemm_strided_batched_ex_get_solutions, rocblas_invoke(&rocblas_gemm_strided_batched_ex_get_solutions,
common_sol_args, common_sol_args,
rocblas_gemm_algo_solution_index, rocblas_gemm_algo_solution_index,
int8_flag, gemm_flags,
solution_indices.data(), solution_indices.data(),
&list_size); &list_size);
} }
...@@ -427,7 +418,7 @@ struct gemm_impl ...@@ -427,7 +418,7 @@ struct gemm_impl
rocblas_invoke(&rocblas_gemm_ex_get_solutions, rocblas_invoke(&rocblas_gemm_ex_get_solutions,
common_args, common_args,
rocblas_gemm_algo_solution_index, rocblas_gemm_algo_solution_index,
int8_flag, gemm_flags,
nullptr, nullptr,
&list_size); &list_size);
solution_indices.resize(list_size); solution_indices.resize(list_size);
...@@ -436,7 +427,7 @@ struct gemm_impl ...@@ -436,7 +427,7 @@ struct gemm_impl
rocblas_invoke(&rocblas_gemm_ex_get_solutions, rocblas_invoke(&rocblas_gemm_ex_get_solutions,
common_sol_args, common_sol_args,
rocblas_gemm_algo_solution_index, rocblas_gemm_algo_solution_index,
int8_flag, gemm_flags,
solution_indices.data(), solution_indices.data(),
&list_size); &list_size);
} }
...@@ -489,7 +480,7 @@ struct gemm_impl ...@@ -489,7 +480,7 @@ struct gemm_impl
std::function<const void*()> get_alpha{}; std::function<const void*()> get_alpha{};
std::function<const void*()> get_beta{}; std::function<const void*()> get_beta{};
flag_type int8_flag = 0; rocblas_gemm_flags gemm_flags = rocblas_gemm_flags_none;
rocblas_int lda = 0; rocblas_int lda = 0;
rocblas_int ldb = 0; rocblas_int ldb = 0;
rocblas_int ldc = 0; rocblas_int ldc = 0;
...@@ -511,7 +502,6 @@ void gemm_compute(context& ctx, ...@@ -511,7 +502,6 @@ void gemm_compute(context& ctx,
const std::vector<argument>& args, const std::vector<argument>& args,
float alpha, float alpha,
float beta, float beta,
bool int8_x4_format,
bool compute_fp32, bool compute_fp32,
int32_t solution_idx) int32_t solution_idx)
{ {
...@@ -521,7 +511,7 @@ void gemm_compute(context& ctx, ...@@ -521,7 +511,7 @@ void gemm_compute(context& ctx,
std::back_inserter(input_shapes), std::back_inserter(input_shapes),
[](const argument& x) { return x.get_shape(); }); [](const argument& x) { return x.get_shape(); });
auto gemm_item = auto gemm_item =
gemm_impl<float>(output_shape, input_shapes, alpha, beta, int8_x4_format, compute_fp32); gemm_impl<float>(output_shape, input_shapes, alpha, beta, compute_fp32);
gemm_item.run(ctx, args, solution_idx); gemm_item.run(ctx, args, solution_idx);
} }
...@@ -530,7 +520,6 @@ void gemm_compute(context& ctx, ...@@ -530,7 +520,6 @@ void gemm_compute(context& ctx,
const std::vector<argument>& args, const std::vector<argument>& args,
int32_t alpha, int32_t alpha,
int32_t beta, int32_t beta,
bool int8_x4_format,
bool compute_fp32, bool compute_fp32,
int32_t solution_idx) int32_t solution_idx)
{ {
...@@ -540,7 +529,7 @@ void gemm_compute(context& ctx, ...@@ -540,7 +529,7 @@ void gemm_compute(context& ctx,
std::back_inserter(input_shapes), std::back_inserter(input_shapes),
[](const argument& x) { return x.get_shape(); }); [](const argument& x) { return x.get_shape(); });
auto gemm_item = auto gemm_item =
gemm_impl<int32_t>(output_shape, input_shapes, alpha, beta, int8_x4_format, compute_fp32); gemm_impl<int32_t>(output_shape, input_shapes, alpha, beta, compute_fp32);
gemm_item.run(ctx, args, solution_idx); gemm_item.run(ctx, args, solution_idx);
} }
...@@ -553,7 +542,6 @@ int32_t gemm_finalize(context& ctx, ...@@ -553,7 +542,6 @@ int32_t gemm_finalize(context& ctx,
const std::vector<shape>& input_shapes, const std::vector<shape>& input_shapes,
float alpha, float alpha,
float beta, float beta,
bool int8_x4_format,
bool compute_fp32, bool compute_fp32,
int32_t solution_idx) int32_t solution_idx)
{ {
...@@ -565,7 +553,7 @@ int32_t gemm_finalize(context& ctx, ...@@ -565,7 +553,7 @@ int32_t gemm_finalize(context& ctx,
if(solution_idx == 0) if(solution_idx == 0)
{ {
auto gemm_item = auto gemm_item =
gemm_impl<float>(output_shape, input_shapes, alpha, beta, int8_x4_format, compute_fp32); gemm_impl<float>(output_shape, input_shapes, alpha, beta, compute_fp32);
solution_idx = gemm_item.tune(ctx, input_shapes); solution_idx = gemm_item.tune(ctx, input_shapes);
} }
else else
...@@ -573,13 +561,13 @@ int32_t gemm_finalize(context& ctx, ...@@ -573,13 +561,13 @@ int32_t gemm_finalize(context& ctx,
// If a tuned solution index is already given, don't tune again but validate // If a tuned solution index is already given, don't tune again but validate
// in case the data was tuned with a different rocBLAS version // in case the data was tuned with a different rocBLAS version
auto gemm_item = auto gemm_item =
gemm_impl<float>(output_shape, input_shapes, alpha, beta, int8_x4_format, compute_fp32); gemm_impl<float>(output_shape, input_shapes, alpha, beta, compute_fp32);
solution_idx = gemm_item.validate(ctx, input_shapes, solution_idx); solution_idx = gemm_item.validate(ctx, input_shapes, solution_idx);
} }
#else #else
// suppress compiler warnings // suppress compiler warnings
(void)ctx, (void)output_shape, (void)input_shapes; (void)ctx, (void)output_shape, (void)input_shapes;
(void)alpha, (void)beta, (void)int8_x4_format, (void)compute_fp32; (void)alpha, (void)beta, (void)compute_fp32;
#endif #endif
return solution_idx; return solution_idx;
} }
...@@ -593,7 +581,6 @@ int32_t gemm_finalize(context& ctx, ...@@ -593,7 +581,6 @@ int32_t gemm_finalize(context& ctx,
const std::vector<shape>& input_shapes, const std::vector<shape>& input_shapes,
int32_t alpha, int32_t alpha,
int32_t beta, int32_t beta,
bool int8_x4_format,
bool compute_fp32, bool compute_fp32,
int32_t solution_idx) int32_t solution_idx)
{ {
...@@ -604,7 +591,7 @@ int32_t gemm_finalize(context& ctx, ...@@ -604,7 +591,7 @@ int32_t gemm_finalize(context& ctx,
if(solution_idx == 0) if(solution_idx == 0)
{ {
auto gemm_item = gemm_impl<int32_t>( auto gemm_item = gemm_impl<int32_t>(
output_shape, input_shapes, alpha, beta, int8_x4_format, compute_fp32); output_shape, input_shapes, alpha, beta, compute_fp32);
solution_idx = gemm_item.tune(ctx, input_shapes); solution_idx = gemm_item.tune(ctx, input_shapes);
} }
else else
...@@ -612,13 +599,13 @@ int32_t gemm_finalize(context& ctx, ...@@ -612,13 +599,13 @@ int32_t gemm_finalize(context& ctx,
// If a tuned solution index is already given, don't tune again but validate // If a tuned solution index is already given, don't tune again but validate
// in case the data was tuned with a different rocBLAS version // in case the data was tuned with a different rocBLAS version
auto gemm_item = gemm_impl<int32_t>( auto gemm_item = gemm_impl<int32_t>(
output_shape, input_shapes, alpha, beta, int8_x4_format, compute_fp32); output_shape, input_shapes, alpha, beta, compute_fp32);
solution_idx = gemm_item.validate(ctx, input_shapes, solution_idx); solution_idx = gemm_item.validate(ctx, input_shapes, solution_idx);
} }
#else #else
// suppress compiler warnings // suppress compiler warnings
(void)ctx, (void)output_shape, (void)input_shapes; (void)ctx, (void)output_shape, (void)input_shapes;
(void)alpha, (void)beta, (void)int8_x4_format, (void)compute_fp32; (void)alpha, (void)beta, (void)compute_fp32;
#endif #endif
return solution_idx; return solution_idx;
} }
......
...@@ -115,7 +115,7 @@ struct rocblas_gemm ...@@ -115,7 +115,7 @@ struct rocblas_gemm
if(this->name() == "gpu::gemm") if(this->name() == "gpu::gemm")
{ {
gemm_compute( gemm_compute(
ctx, output_shape, args, alpha, beta, int8_x4_format, compute_fp32, solution_idx); ctx, output_shape, args, alpha, beta, compute_fp32, solution_idx);
} }
else else
{ {
...@@ -124,7 +124,6 @@ struct rocblas_gemm ...@@ -124,7 +124,6 @@ struct rocblas_gemm
args, args,
int32_t(alpha), int32_t(alpha),
int32_t(beta), int32_t(beta),
int8_x4_format,
compute_fp32, compute_fp32,
solution_idx); solution_idx);
} }
...@@ -148,7 +147,6 @@ struct rocblas_gemm ...@@ -148,7 +147,6 @@ struct rocblas_gemm
input_shapes, input_shapes,
alpha, alpha,
beta, beta,
int8_x4_format,
compute_fp32, compute_fp32,
solution_idx); solution_idx);
} }
...@@ -159,7 +157,6 @@ struct rocblas_gemm ...@@ -159,7 +157,6 @@ struct rocblas_gemm
input_shapes, input_shapes,
int32_t(alpha), int32_t(alpha),
int32_t(beta), int32_t(beta),
int8_x4_format,
compute_fp32, compute_fp32,
solution_idx); solution_idx);
} }
......
...@@ -60,7 +60,6 @@ using flag_type = int; ...@@ -60,7 +60,6 @@ using flag_type = int;
* @param args . * @param args .
* @param alpha . * @param alpha .
* @param beta . * @param beta .
* @param int8_x4_format .
* @param compute_fp32 . * @param compute_fp32 .
*/ */
void gemm_compute(context& ctx, void gemm_compute(context& ctx,
...@@ -68,7 +67,6 @@ void gemm_compute(context& ctx, ...@@ -68,7 +67,6 @@ void gemm_compute(context& ctx,
const std::vector<argument>& args, const std::vector<argument>& args,
float alpha, float alpha,
float beta, float beta,
bool int8_x4_format,
bool compute_fp32, bool compute_fp32,
int32_t solution_idx); int32_t solution_idx);
...@@ -77,7 +75,6 @@ void gemm_compute(context& ctx, ...@@ -77,7 +75,6 @@ void gemm_compute(context& ctx,
const std::vector<argument>& args, const std::vector<argument>& args,
int32_t alpha, int32_t alpha,
int32_t beta, int32_t beta,
bool int8_x4_format,
bool compute_fp32, bool compute_fp32,
int32_t solution_idx); int32_t solution_idx);
...@@ -86,7 +83,6 @@ int32_t gemm_finalize(context& ctx, ...@@ -86,7 +83,6 @@ int32_t gemm_finalize(context& ctx,
const std::vector<shape>& input_shapes, const std::vector<shape>& input_shapes,
float alpha, float alpha,
float beta, float beta,
bool int8_x4_format,
bool compute_fp32); bool compute_fp32);
int32_t gemm_finalize(context& ctx, int32_t gemm_finalize(context& ctx,
...@@ -94,7 +90,6 @@ int32_t gemm_finalize(context& ctx, ...@@ -94,7 +90,6 @@ int32_t gemm_finalize(context& ctx,
const std::vector<shape>& input_shapes, const std::vector<shape>& input_shapes,
int32_t alpha, int32_t alpha,
int32_t beta, int32_t beta,
bool int8_x4_format,
bool compute_fp32, bool compute_fp32,
int32_t solution_idx); int32_t solution_idx);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment