"vscode:/vscode.git/clone" did not exist on "9ba8b4801a7f811334c4f1c24dd1ebc721f9a4a9"
Commit 7892b274 authored by Paul's avatar Paul
Browse files

Remvoe stray if

parent 9ca6c1d8
......@@ -158,8 +158,6 @@ struct gemm_impl
{
beta = 0;
}
if(arg_type == rocblas_datatype_f16_r)
compute_type = rocblas_datatype_f32_r;
rocblas_gemm_flags flag = rocblas_gemm_flags_none;
#if ROCBLAS_VERSION_MAJOR < 3
......@@ -200,43 +198,43 @@ struct gemm_impl
{
output_type = rocblas_datatype_i32_r;
}
compute_type = output_type;
if(compute_fp32)
{
compute_type = output_type;
if(compute_fp32)
{
if(arg_type == rocblas_datatype_f16_r)
compute_type = rocblas_datatype_f32_r;
}
}
int8_flag = int8_x4_format ? rocblas_gemm_flags_pack_int8x4 : rocblas_gemm_flags_none;
int8_flag = int8_x4_format ? rocblas_gemm_flags_pack_int8x4 : rocblas_gemm_flags_none;
auto a_lens = input_shapes[0].lens();
auto b_lens = input_shapes[1].lens();
auto a_lens = input_shapes[0].lens();
auto b_lens = input_shapes[1].lens();
auto out_lens = output_shape.lens();
m = out_lens[dim_0];
n = out_lens[dim_1];
k = input_shapes[0].lens()[dim_1];
if(input_shapes[0].type() == shape::int8_type and (k % 4) != 0 and int8_x4_format)
{
MIGRAPHX_THROW("ROCBLAS_GEMM: k size of int8 type input must be multiple of 4!");
}
auto out_lens = output_shape.lens();
m = out_lens[dim_0];
n = out_lens[dim_1];
k = input_shapes[0].lens()[dim_1];
if(input_shapes[0].type() == shape::int8_type and (k % 4) != 0 and int8_x4_format)
{
MIGRAPHX_THROW("ROCBLAS_GEMM: k size of int8 type input must be multiple of 4!");
}
a_stride = get_batch_stride(input_shapes[0]);
b_stride = get_batch_stride(input_shapes[1]);
c_stride = get_batch_stride(input_shapes[2]);
d_stride = is_3inputs ? get_batch_stride(input_shapes[3]) : c_stride;
num_matrices = std::accumulate(out_lens.rbegin() + 2,
out_lens.rend(),
std::size_t{1},
std::multiplies<std::size_t>());
if(num_matrices == 1 or (num_matrices > 1 and b_stride == 0))
{
// If the batch dimension of B is broadcasted, then we can
// multiply m by the batch_size and use rocblas_gemm_ex
// instead of rocblas_gemm_strided_batched_ex.
m *= num_matrices;
strided_batched = false;
}
a_stride = get_batch_stride(input_shapes[0]);
b_stride = get_batch_stride(input_shapes[1]);
c_stride = get_batch_stride(input_shapes[2]);
d_stride = is_3inputs ? get_batch_stride(input_shapes[3]) : c_stride;
num_matrices = std::accumulate(out_lens.rbegin() + 2,
out_lens.rend(),
std::size_t{1},
std::multiplies<std::size_t>());
if(num_matrices == 1 or (num_matrices > 1 and b_stride == 0))
{
// If the batch dimension of B is broadcasted, then we can
// multiply m by the batch_size and use rocblas_gemm_ex
// instead of rocblas_gemm_strided_batched_ex.
m *= num_matrices;
strided_batched = false;
}
}
void run(context& ctx, const std::vector<argument>& input_args, int32_t solution_idx = 0) const
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment