Unverified Commit 380ce400 authored by Yuting Jiang's avatar Yuting Jiang Committed by GitHub
Browse files

Bug: Fix code incesure issue of integer overflow in cublas function (#290)

**Description**
Fix insecure issue of Multiplication result converted to larger type.

**Major Revision**
- Use a cast to ensure that the multiplication is done using the long long to avoid overflow.
parent 5f6ad0cd
......@@ -164,8 +164,9 @@ void gemmStridedBatchedEx(cublasHandle_t handle, int transa, int transb, int m,
}
CUBLAS_SAFE_CALL(cublasGemmStridedBatchedEx(handle, (transa ? CUBLAS_OP_T : CUBLAS_OP_N),
(transb ? CUBLAS_OP_T : CUBLAS_OP_N), m, n, k, &alpha, a, matrix_type,
(transa ? k : m), m * k, b, matrix_type, (transb ? n : k), n * k, &beta,
c, matrix_type, m, m * n, batchCount, compute_type, algo));
(transa ? k : m), static_cast<long long>(m) * k, b, matrix_type,
(transb ? n : k), static_cast<long long>(n) * k, &beta, c, matrix_type,
m, static_cast<long long>(m) * n, batchCount, compute_type, algo));
}
/**
......@@ -187,7 +188,8 @@ void sgemmStridedBatched(cublasHandle_t handle, int transa, int transb, int m, i
float beta = 1.0f;
CUBLAS_SAFE_CALL(cublasSgemmStridedBatched(
handle, (transa ? CUBLAS_OP_T : CUBLAS_OP_N), (transb ? CUBLAS_OP_T : CUBLAS_OP_N), m, n, k, &alpha, a,
(transa ? k : m), m * k, b, (transb ? n : k), n * k, &beta, c, m, m * n, batchCount));
(transa ? k : m), static_cast<long long>(m) * k, b, (transb ? n : k), static_cast<long long>(n) * k, &beta, c,
m, static_cast<long long>(m) * n, batchCount));
}
/**
......@@ -210,5 +212,6 @@ void cgemm3mStridedBatched(cublasHandle_t handle, int transa, int transb, int m,
cuComplex beta = make_cuComplex(0.0f, 0.0f);
CUBLAS_SAFE_CALL(cublasCgemm3mStridedBatched(
handle, (transa ? CUBLAS_OP_T : CUBLAS_OP_N), (transb ? CUBLAS_OP_T : CUBLAS_OP_N), m, n, k, &alpha, a,
(transa ? k : m), m * k, b, (transb ? n : k), n * k, &beta, c, m, m * n, batchCount));
(transa ? k : m), static_cast<long long>(m) * k, b, (transb ? n : k), static_cast<long long>(n) * k, &beta, c,
m, static_cast<long long>(m) * n, batchCount));
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment