Unverified Commit 9d250cdd authored by Yifan Xiong's avatar Yifan Xiong Committed by GitHub
Browse files

Benchmark - Fix matrix size overflow issue in cuBLASLt GEMM (#503)

Fix a matrix size overflow issue caused by implicitly casting from int to size_t.
parent 26373edb
...@@ -163,7 +163,7 @@ def run(self): ...@@ -163,7 +163,7 @@ def run(self):
'numpy>=1.19.2', 'numpy>=1.19.2',
'omegaconf==2.0.6', 'omegaconf==2.0.6',
'openpyxl>=3.0.7', 'openpyxl>=3.0.7',
'pandas>=1.1.5', 'pandas>=1.1.5, <2.0.0',
'pssh @ git+https://github.com/lilydjwg/pssh.git@v2.3.4', 'pssh @ git+https://github.com/lilydjwg/pssh.git@v2.3.4',
'pyyaml>=5.3', 'pyyaml>=5.3',
'requests>=2.27.1', 'requests>=2.27.1',
......
...@@ -88,20 +88,21 @@ template <typename T> cudaDataType_t get_datatype() { ...@@ -88,20 +88,21 @@ template <typename T> cudaDataType_t get_datatype() {
} }
template <typename Ta, typename Tb, typename Tout> template <typename Ta, typename Tb, typename Tout>
float timing_matmul_tn(int m, int n, int k, int batch, int warmup, int iter) { float timing_matmul_tn(size_t m, size_t n, size_t k, size_t batch, int warmup, int iter) {
// init matrix // init matrix
Ta *matrix_a = nullptr; Ta *matrix_a = nullptr;
Tb *matrix_b = nullptr; Tb *matrix_b = nullptr;
Tout *matrix_out = nullptr; Tout *matrix_out = nullptr;
cudaMalloc(&matrix_a, m * k * std::max(batch, 1) * sizeof(Ta)); batch = std::max<size_t>(batch, 1);
cudaMalloc(&matrix_b, k * n * std::max(batch, 1) * sizeof(Tb)); cudaMalloc(&matrix_a, m * k * batch * sizeof(Ta));
cudaMalloc(&matrix_out, m * n * std::max(batch, 1) * sizeof(Tout)); cudaMalloc(&matrix_b, k * n * batch * sizeof(Tb));
cudaMalloc(&matrix_out, m * n * batch * sizeof(Tout));
init_matrix<Ta><<<216, 1024>>>(matrix_a, 1.f, m * k * std::max(batch, 1)); init_matrix<Ta><<<216, 1024>>>(matrix_a, 1.f, m * k * batch);
init_matrix<Tb><<<216, 1024>>>(matrix_b, 2.f, k * n * std::max(batch, 1)); init_matrix<Tb><<<216, 1024>>>(matrix_b, 2.f, k * n * batch);
// init gemm // init gemm
int lda = k, ldb = k, ldd = m; size_t lda = k, ldb = k, ldd = m;
std::unique_ptr<cublasLtGemm> gemm = std::make_unique<cublasLtGemm>(); std::unique_ptr<cublasLtGemm> gemm = std::make_unique<cublasLtGemm>();
gemm->Init(); gemm->Init();
gemm->Setup(m, n, k, batch, lda, ldb, ldd, get_datatype<Ta>(), get_datatype<Tb>(), get_datatype<Tout>(), gemm->Setup(m, n, k, batch, lda, ldb, ldd, get_datatype<Ta>(), get_datatype<Tb>(), get_datatype<Tout>(),
......
...@@ -5,12 +5,12 @@ ...@@ -5,12 +5,12 @@
void cublasLtGemm::Init() { void cublasLtGemm::Init() {
cublasLtHandle_t handle; cublasLtHandle_t handle;
checkCublasStatus(cublasLtCreate(&handle)); CUBLAS_CHECK(cublasLtCreate(&handle));
handle_.reset(handle); handle_.reset(handle);
/* preference can be initialized without arguments */ /* preference can be initialized without arguments */
cublasLtMatmulPreference_t preference; cublasLtMatmulPreference_t preference;
checkCublasStatus(cublasLtMatmulPreferenceCreate(&preference)); CUBLAS_CHECK(cublasLtMatmulPreferenceCreate(&preference));
preference_.reset(preference); preference_.reset(preference);
} }
...@@ -24,32 +24,32 @@ void cublasLtGemm::Setup(int m, int n, int k, int batch, int lda, int ldb, int l ...@@ -24,32 +24,32 @@ void cublasLtGemm::Setup(int m, int n, int k, int batch, int lda, int ldb, int l
// force c_type // force c_type
cudaDataType_t c_type = d_type; cudaDataType_t c_type = d_type;
// Create matrix descriptors. // Create matrix descriptors.
checkCublasStatus( CUBLAS_CHECK(
cublasLtMatrixLayoutCreate(&a_desc, a_type, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda)); cublasLtMatrixLayoutCreate(&a_desc, a_type, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda));
checkCublasStatus( CUBLAS_CHECK(
cublasLtMatrixLayoutCreate(&b_desc, b_type, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb)); cublasLtMatrixLayoutCreate(&b_desc, b_type, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb));
checkCublasStatus(cublasLtMatrixLayoutCreate(&c_desc, c_type, m, n, ldd)); CUBLAS_CHECK(cublasLtMatrixLayoutCreate(&c_desc, c_type, m, n, ldd));
checkCublasStatus(cublasLtMatrixLayoutCreate(&d_desc, d_type, m, n, ldd)); CUBLAS_CHECK(cublasLtMatrixLayoutCreate(&d_desc, d_type, m, n, ldd));
// strided batch gemm // strided batch gemm
if (batch > 0) { if (batch > 0) {
int64_t stridea = m * k, strideb = k * n, stridec = m * n, strided = m * n; int64_t stridea = m * k, strideb = k * n, stridec = m * n, strided = m * n;
checkCublasStatus( CUBLAS_CHECK(
cublasLtMatrixLayoutSetAttribute(a_desc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch))); cublasLtMatrixLayoutSetAttribute(a_desc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch)));
checkCublasStatus(cublasLtMatrixLayoutSetAttribute(a_desc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, CUBLAS_CHECK(cublasLtMatrixLayoutSetAttribute(a_desc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stridea,
&stridea, sizeof(stridea))); sizeof(stridea)));
checkCublasStatus( CUBLAS_CHECK(
cublasLtMatrixLayoutSetAttribute(b_desc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch))); cublasLtMatrixLayoutSetAttribute(b_desc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch)));
checkCublasStatus(cublasLtMatrixLayoutSetAttribute(b_desc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, CUBLAS_CHECK(cublasLtMatrixLayoutSetAttribute(b_desc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &strideb,
&strideb, sizeof(strideb))); sizeof(strideb)));
checkCublasStatus( CUBLAS_CHECK(
cublasLtMatrixLayoutSetAttribute(c_desc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch))); cublasLtMatrixLayoutSetAttribute(c_desc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch)));
checkCublasStatus(cublasLtMatrixLayoutSetAttribute(c_desc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, CUBLAS_CHECK(cublasLtMatrixLayoutSetAttribute(c_desc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stridec,
&stridec, sizeof(stridec))); sizeof(stridec)));
checkCublasStatus( CUBLAS_CHECK(
cublasLtMatrixLayoutSetAttribute(d_desc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch))); cublasLtMatrixLayoutSetAttribute(d_desc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch)));
checkCublasStatus(cublasLtMatrixLayoutSetAttribute(d_desc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, CUBLAS_CHECK(cublasLtMatrixLayoutSetAttribute(d_desc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &strided,
&strided, sizeof(strided))); sizeof(strided)));
} }
a_desc_.reset(a_desc); a_desc_.reset(a_desc);
b_desc_.reset(b_desc); b_desc_.reset(b_desc);
...@@ -64,7 +64,7 @@ void cublasLtGemm::Setup(int m, int n, int k, int batch, int lda, int ldb, int l ...@@ -64,7 +64,7 @@ void cublasLtGemm::Setup(int m, int n, int k, int batch, int lda, int ldb, int l
gemm_compute_type = CUBLAS_COMPUTE_64F; gemm_compute_type = CUBLAS_COMPUTE_64F;
cublasLtMatmulDesc_t op_desc = nullptr; cublasLtMatmulDesc_t op_desc = nullptr;
checkCublasStatus(cublasLtMatmulDescCreate(&op_desc, gemm_compute_type, CUDA_R_32F)); CUBLAS_CHECK(cublasLtMatmulDescCreate(&op_desc, gemm_compute_type, CUDA_R_32F));
op_desc_.reset(op_desc); op_desc_.reset(op_desc);
if (a_type == CUDA_R_8F_E5M2 || b_type == CUDA_R_8F_E5M2 || a_type == CUDA_R_8F_E4M3 || b_type == CUDA_R_8F_E4M3) { if (a_type == CUDA_R_8F_E5M2 || b_type == CUDA_R_8F_E5M2 || a_type == CUDA_R_8F_E4M3 || b_type == CUDA_R_8F_E4M3) {
...@@ -73,33 +73,31 @@ void cublasLtGemm::Setup(int m, int n, int k, int batch, int lda, int ldb, int l ...@@ -73,33 +73,31 @@ void cublasLtGemm::Setup(int m, int n, int k, int batch, int lda, int ldb, int l
cublasLtMatmulDescSetAttribute(op_desc, CUBLASLT_MATMUL_DESC_FAST_ACCUM, &fastAccuMode, sizeof(fastAccuMode)); cublasLtMatmulDescSetAttribute(op_desc, CUBLASLT_MATMUL_DESC_FAST_ACCUM, &fastAccuMode, sizeof(fastAccuMode));
} }
checkCublasStatus( CUBLAS_CHECK(cublasLtMatmulDescSetAttribute(op_desc_.get(), CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa)));
cublasLtMatmulDescSetAttribute(op_desc_.get(), CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa))); CUBLAS_CHECK(cublasLtMatmulDescSetAttribute(op_desc_.get(), CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transb)));
checkCublasStatus(
cublasLtMatmulDescSetAttribute(op_desc_.get(), CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transb)));
if (a_scale_inverse != nullptr) { if (a_scale_inverse != nullptr) {
checkCublasStatus(cublasLtMatmulDescSetAttribute(op_desc_.get(), CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, CUBLAS_CHECK(cublasLtMatmulDescSetAttribute(op_desc_.get(), CUBLASLT_MATMUL_DESC_A_SCALE_POINTER,
&a_scale_inverse, sizeof(a_scale_inverse))); &a_scale_inverse, sizeof(a_scale_inverse)));
} }
if (b_scale_inverse != nullptr) { if (b_scale_inverse != nullptr) {
checkCublasStatus(cublasLtMatmulDescSetAttribute(op_desc_.get(), CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, CUBLAS_CHECK(cublasLtMatmulDescSetAttribute(op_desc_.get(), CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,
&b_scale_inverse, sizeof(b_scale_inverse))); &b_scale_inverse, sizeof(b_scale_inverse)));
} }
checkCublasStatus( CUBLAS_CHECK(
cublasLtMatmulDescSetAttribute(op_desc_.get(), CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue))); cublasLtMatmulDescSetAttribute(op_desc_.get(), CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue)));
} }
size_t cublasLtGemm::GetAlgorithm(int max_algorithm_count, size_t max_workspace_size) { size_t cublasLtGemm::GetAlgorithm(int max_algorithm_count, size_t max_workspace_size) {
checkCublasStatus(cublasLtMatmulPreferenceSetAttribute(preference_.get(), CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, CUBLAS_CHECK(cublasLtMatmulPreferenceSetAttribute(preference_.get(), CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
&max_workspace_size, sizeof(max_workspace_size))); &max_workspace_size, sizeof(max_workspace_size)));
int found_algorithm_count = 0; int found_algorithm_count = 0;
std::vector<cublasLtMatmulHeuristicResult_t> results(max_algorithm_count); std::vector<cublasLtMatmulHeuristicResult_t> results(max_algorithm_count);
// Though we query all of possible algorithm, we will use the first later // Though we query all of possible algorithm, we will use the first later
checkCublasStatus(cublasLtMatmulAlgoGetHeuristic(handle_.get(), op_desc_.get(), a_desc_.get(), b_desc_.get(), CUBLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(handle_.get(), op_desc_.get(), a_desc_.get(), b_desc_.get(),
c_desc_.get(), d_desc_.get(), preference_.get(), c_desc_.get(), d_desc_.get(), preference_.get(), max_algorithm_count,
max_algorithm_count, results.data(), &found_algorithm_count)); results.data(), &found_algorithm_count));
if (found_algorithm_count == 0) { if (found_algorithm_count == 0) {
throw std::runtime_error("Unable to find any suitable algorithms"); throw std::runtime_error("Unable to find any suitable algorithms");
} }
...@@ -111,13 +109,13 @@ size_t cublasLtGemm::GetAlgorithm(int max_algorithm_count, size_t max_workspace_ ...@@ -111,13 +109,13 @@ size_t cublasLtGemm::GetAlgorithm(int max_algorithm_count, size_t max_workspace_
void cublasLtGemm::Execute(void *matrix_a, void *matrix_b, void *matrix_c, void *matrix_d, float alpha, float beta, void cublasLtGemm::Execute(void *matrix_a, void *matrix_b, void *matrix_c, void *matrix_d, float alpha, float beta,
void *workspace, size_t workspace_size, cudaStream_t stream) { void *workspace, size_t workspace_size, cudaStream_t stream) {
checkCublasStatus(cublasLtMatmul(handle_.get(), op_desc_.get(), static_cast<const void *>(&alpha), /* alpha */ CUBLAS_CHECK(cublasLtMatmul(handle_.get(), op_desc_.get(), static_cast<const void *>(&alpha), /* alpha */
matrix_a, /* A */ matrix_a, /* A */
a_desc_.get(), matrix_b, /* B */ a_desc_.get(), matrix_b, /* B */
b_desc_.get(), static_cast<const void *>(&beta), /* beta */ b_desc_.get(), static_cast<const void *>(&beta), /* beta */
matrix_c, /* C */ matrix_c, /* C */
c_desc_.get(), matrix_d, /* D */ c_desc_.get(), matrix_d, /* D */
d_desc_.get(), &heuristic_results_.front().algo, /* algo */ d_desc_.get(), &heuristic_results_.front().algo, /* algo */
workspace, /* workspace */ workspace, /* workspace */
workspace_size, stream)); /* stream */ workspace_size, stream)); /* stream */
} }
...@@ -10,12 +10,14 @@ ...@@ -10,12 +10,14 @@
#include <cublasLt.h> #include <cublasLt.h>
inline void checkCublasStatus(cublasStatus_t status) { #define CUBLAS_CHECK(func) \
if (status != CUBLAS_STATUS_SUCCESS) { do { \
printf("cuBLAS API failed with status %s\n", cublasGetStatusString(status)); cublasStatus_t status = func; \
throw std::logic_error("cuBLAS API failed"); if (status != CUBLAS_STATUS_SUCCESS) { \
} printf("cuBLAS call %s failed at %s:%d '%s'\n", #func, __FILE__, __LINE__, cublasGetStatusString(status)); \
} exit(EXIT_FAILURE); \
} \
} while (0)
class cublasLtGemm { class cublasLtGemm {
public: public:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment