Commit e846ae72 authored by Paul
Browse files

Format

parent dc71e23d
...@@ -160,51 +160,51 @@ struct gemm_impl ...@@ -160,51 +160,51 @@ struct gemm_impl
} }
if(arg_type == rocblas_datatype_f16_r) if(arg_type == rocblas_datatype_f16_r)
compute_type = rocblas_datatype_f32_r; compute_type = rocblas_datatype_f32_r;
rocblas_gemm_flags flag = rocblas_gemm_flags_none; rocblas_gemm_flags flag = rocblas_gemm_flags_none;
#if ROCBLAS_VERSION_MAJOR < 3 #if ROCBLAS_VERSION_MAJOR < 3
if(int8_x4_format) if(int8_x4_format)
flag = rocblas_gemm_flags_pack_int8x4; flag = rocblas_gemm_flags_pack_int8x4;
#endif #endif
// Create lambdas that will cast alpha, beta to the output shape's type // Create lambdas that will cast alpha, beta to the output shape's type
// and retain the values being pointed to // and retain the values being pointed to
output_shape.visit_type([&](auto as) { output_shape.visit_type([&](auto as) {
auto alpha_r = as(alpha); auto alpha_r = as(alpha);
auto beta_r = as(beta); auto beta_r = as(beta);
if(compute_fp32) if(compute_fp32)
{
get_alpha = [=] { return &alpha; };
get_beta = [=] { return &beta; };
}
else
{
get_alpha = [=] { return &alpha_r; };
get_beta = [=] { return &beta_r; };
}
});
transa = is_transposed(input_shapes[0]);
transb = is_transposed(input_shapes[1]);
auto n_dim = output_shape.lens().size();
auto dim_0 = n_dim - 2;
auto dim_1 = n_dim - 1;
// Leading dimensions of matrices
lda = input_shapes[0].strides()[transa ? dim_1 : dim_0];
ldb = input_shapes[1].strides()[transb ? dim_1 : dim_0];
ldc = input_shapes[2].strides()[dim_0];
ldd = is_3inputs ? input_shapes[3].strides()[dim_0] : ldc;
arg_type = get_type(input_shapes[0].type());
output_type = arg_type;
if(output_type == rocblas_datatype_i8_r)
{ {
output_type = rocblas_datatype_i32_r; get_alpha = [=] { return &alpha; };
get_beta = [=] { return &beta; };
}
else
{
get_alpha = [=] { return &alpha_r; };
get_beta = [=] { return &beta_r; };
}
});
transa = is_transposed(input_shapes[0]);
transb = is_transposed(input_shapes[1]);
auto n_dim = output_shape.lens().size();
auto dim_0 = n_dim - 2;
auto dim_1 = n_dim - 1;
// Leading dimensions of matrices
lda = input_shapes[0].strides()[transa ? dim_1 : dim_0];
ldb = input_shapes[1].strides()[transb ? dim_1 : dim_0];
ldc = input_shapes[2].strides()[dim_0];
ldd = is_3inputs ? input_shapes[3].strides()[dim_0] : ldc;
arg_type = get_type(input_shapes[0].type());
output_type = arg_type;
if(output_type == rocblas_datatype_i8_r)
{
output_type = rocblas_datatype_i32_r;
} }
compute_type = output_type; compute_type = output_type;
if(compute_fp32) if(compute_fp32)
{ {
if(arg_type == rocblas_datatype_f16_r) if(arg_type == rocblas_datatype_f16_r)
compute_type = rocblas_datatype_f32_r; compute_type = rocblas_datatype_f32_r;
} }
int8_flag = int8_x4_format ? rocblas_gemm_flags_pack_int8x4 : rocblas_gemm_flags_none; int8_flag = int8_x4_format ? rocblas_gemm_flags_pack_int8x4 : rocblas_gemm_flags_none;
...@@ -218,22 +218,24 @@ struct gemm_impl ...@@ -218,22 +218,24 @@ struct gemm_impl
k = input_shapes[0].lens()[dim_1]; k = input_shapes[0].lens()[dim_1];
if(input_shapes[0].type() == shape::int8_type and (k % 4) != 0 and int8_x4_format) if(input_shapes[0].type() == shape::int8_type and (k % 4) != 0 and int8_x4_format)
{ {
MIGRAPHX_THROW("ROCBLAS_GEMM: k size of int8 type input must be multiple of 4!"); MIGRAPHX_THROW("ROCBLAS_GEMM: k size of int8 type input must be multiple of 4!");
} }
a_stride = get_batch_stride(input_shapes[0]); a_stride = get_batch_stride(input_shapes[0]);
b_stride = get_batch_stride(input_shapes[1]); b_stride = get_batch_stride(input_shapes[1]);
c_stride = get_batch_stride(input_shapes[2]); c_stride = get_batch_stride(input_shapes[2]);
d_stride = is_3inputs ? get_batch_stride(input_shapes[3]) : c_stride; d_stride = is_3inputs ? get_batch_stride(input_shapes[3]) : c_stride;
num_matrices = std::accumulate( num_matrices = std::accumulate(out_lens.rbegin() + 2,
out_lens.rbegin() + 2, out_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>()); out_lens.rend(),
std::size_t{1},
std::multiplies<std::size_t>());
if(num_matrices == 1 or (num_matrices > 1 and b_stride == 0)) if(num_matrices == 1 or (num_matrices > 1 and b_stride == 0))
{ {
// If the batch dimension of B is broadcasted, then we can // If the batch dimension of B is broadcasted, then we can
// multiply m by the batch_size and use rocblas_gemm_ex // multiply m by the batch_size and use rocblas_gemm_ex
// instead of rocblas_gemm_strided_batched_ex. // instead of rocblas_gemm_strided_batched_ex.
m *= num_matrices; m *= num_matrices;
strided_batched = false; strided_batched = false;
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment