Unverified Commit 380ce400 authored by Yuting Jiang's avatar Yuting Jiang Committed by GitHub
Browse files

Bug: Fix code incesure issue of integer overflow in cublas function (#290)

**Description**
Fix insecure issue of Multiplication result converted to larger type.

**Major Revision**
- Use a cast to ensure that the multiplication is done using the long long to avoid overflow.
parent 5f6ad0cd
...@@ -164,8 +164,9 @@ void gemmStridedBatchedEx(cublasHandle_t handle, int transa, int transb, int m, ...@@ -164,8 +164,9 @@ void gemmStridedBatchedEx(cublasHandle_t handle, int transa, int transb, int m,
} }
CUBLAS_SAFE_CALL(cublasGemmStridedBatchedEx(handle, (transa ? CUBLAS_OP_T : CUBLAS_OP_N), CUBLAS_SAFE_CALL(cublasGemmStridedBatchedEx(handle, (transa ? CUBLAS_OP_T : CUBLAS_OP_N),
(transb ? CUBLAS_OP_T : CUBLAS_OP_N), m, n, k, &alpha, a, matrix_type, (transb ? CUBLAS_OP_T : CUBLAS_OP_N), m, n, k, &alpha, a, matrix_type,
(transa ? k : m), m * k, b, matrix_type, (transb ? n : k), n * k, &beta, (transa ? k : m), static_cast<long long>(m) * k, b, matrix_type,
c, matrix_type, m, m * n, batchCount, compute_type, algo)); (transb ? n : k), static_cast<long long>(n) * k, &beta, c, matrix_type,
m, static_cast<long long>(m) * n, batchCount, compute_type, algo));
} }
/** /**
...@@ -187,7 +188,8 @@ void sgemmStridedBatched(cublasHandle_t handle, int transa, int transb, int m, i ...@@ -187,7 +188,8 @@ void sgemmStridedBatched(cublasHandle_t handle, int transa, int transb, int m, i
float beta = 1.0f; float beta = 1.0f;
CUBLAS_SAFE_CALL(cublasSgemmStridedBatched( CUBLAS_SAFE_CALL(cublasSgemmStridedBatched(
handle, (transa ? CUBLAS_OP_T : CUBLAS_OP_N), (transb ? CUBLAS_OP_T : CUBLAS_OP_N), m, n, k, &alpha, a, handle, (transa ? CUBLAS_OP_T : CUBLAS_OP_N), (transb ? CUBLAS_OP_T : CUBLAS_OP_N), m, n, k, &alpha, a,
(transa ? k : m), m * k, b, (transb ? n : k), n * k, &beta, c, m, m * n, batchCount)); (transa ? k : m), static_cast<long long>(m) * k, b, (transb ? n : k), static_cast<long long>(n) * k, &beta, c,
m, static_cast<long long>(m) * n, batchCount));
} }
/** /**
...@@ -210,5 +212,6 @@ void cgemm3mStridedBatched(cublasHandle_t handle, int transa, int transb, int m, ...@@ -210,5 +212,6 @@ void cgemm3mStridedBatched(cublasHandle_t handle, int transa, int transb, int m,
cuComplex beta = make_cuComplex(0.0f, 0.0f); cuComplex beta = make_cuComplex(0.0f, 0.0f);
CUBLAS_SAFE_CALL(cublasCgemm3mStridedBatched( CUBLAS_SAFE_CALL(cublasCgemm3mStridedBatched(
handle, (transa ? CUBLAS_OP_T : CUBLAS_OP_N), (transb ? CUBLAS_OP_T : CUBLAS_OP_N), m, n, k, &alpha, a, handle, (transa ? CUBLAS_OP_T : CUBLAS_OP_N), (transb ? CUBLAS_OP_T : CUBLAS_OP_N), m, n, k, &alpha, a,
(transa ? k : m), m * k, b, (transb ? n : k), n * k, &beta, c, m, m * n, batchCount)); (transa ? k : m), static_cast<long long>(m) * k, b, (transb ? n : k), static_cast<long long>(n) * k, &beta, c,
m, static_cast<long long>(m) * n, batchCount));
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment