Commit ee33e2e7 authored by zhouxiang's avatar zhouxiang
Browse files

support dtk23.10

parent e432dbb0
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple from typing import Tuple
__dcu_version__ = '0.0.13' __dcu_version__ = '0.0.13+gite432dbb.abi0.dtk2310.torch1.13'
__version__ = '0.0.13' __version__ = '0.0.13'
short_version = __version__ short_version = __version__
......
...@@ -12,6 +12,6 @@ setuptools ...@@ -12,6 +12,6 @@ setuptools
shortuuid shortuuid
tiktoken tiktoken
torch torch
transformers>=4.33.0 transformers>=4.33.2
tritonclient[all] tritonclient[all]
uvicorn uvicorn
...@@ -37,14 +37,14 @@ __forceinline__ __device__ float copysignf_pos(float a, float b) ...@@ -37,14 +37,14 @@ __forceinline__ __device__ float copysignf_pos(float a, float b)
__inline__ __device__ float tanh_opt(float x) __inline__ __device__ float tanh_opt(float x)
{ {
#if (__CUDA_ARCH__ >= 750 && CUDART_VERSION >= 11000) // #if (__CUDA_ARCH__ >= 750 && CUDART_VERSION >= 11000)
float r; // float r;
asm("tanh.approx.f32 %0,%1; \n\t" : "=f"(r) : "f"(x)); // asm("tanh.approx.f32 %0,%1; \n\t" : "=f"(r) : "f"(x));
return r; // return r;
#else // #else
const float exp_val = -1.f * fabs(2 * x); const float exp_val = -1.f * fabs(2 * x);
return copysignf_pos((1.0f - __expf(exp_val)) / (__expf(exp_val) + 1.0f), x); return copysignf_pos((1.0f - __expf(exp_val)) / (__expf(exp_val) + 1.0f), x);
#endif // #endif
} }
template<typename T> template<typename T>
......
...@@ -7,11 +7,11 @@ ...@@ -7,11 +7,11 @@
namespace turbomind { namespace turbomind {
#if (__CUDACC_VER_MAJOR__ >= 11) && (__CUDACC_VER_MINOR__ >= 4) // #if (__CUDACC_VER_MAJOR__ >= 11) && (__CUDACC_VER_MINOR__ >= 4)
#define L2_CACHEHINT(size) ".L2::" #size "B" // #define L2_CACHEHINT(size) ".L2::" #size "B"
#else // #else
#define L2_CACHEHINT(size) #define L2_CACHEHINT(size)
#endif // #endif
template<typename T> template<typename T>
__inline__ __device__ void cp_async_cg_A(uint32_t smem_int_ptr, const T* __restrict__ src, bool mask) __inline__ __device__ void cp_async_cg_A(uint32_t smem_int_ptr, const T* __restrict__ src, bool mask)
......
...@@ -61,12 +61,12 @@ __inline__ __device__ uint transpose_m8n8_b16_movmatrix(uint a) ...@@ -61,12 +61,12 @@ __inline__ __device__ uint transpose_m8n8_b16_movmatrix(uint a)
__inline__ __device__ uint transpose_m8n8_b16(uint a, int lane_id) __inline__ __device__ uint transpose_m8n8_b16(uint a, int lane_id)
{ {
#if (__CUDACC_VER_MAJOR__ >= 11) && (__CUDACC_VER_MINOR__ >= 8) // #if (__CUDACC_VER_MAJOR__ >= 11) && (__CUDACC_VER_MINOR__ >= 8)
(void)lane_id; // (void)lane_id;
return transpose_m8n8_b16_movmatrix(a); // return transpose_m8n8_b16_movmatrix(a);
#else // #else
return transpose_m8n8_b16_warp_shuffle(a, lane_id); return transpose_m8n8_b16_warp_shuffle(a, lane_id);
#endif // #endif
} }
namespace ops { namespace ops {
......
...@@ -16,11 +16,11 @@ ...@@ -16,11 +16,11 @@
#pragma once #pragma once
#include <array> #include <array>
#include <assert.h> #include <assert.h>
#if ((__CUDACC_VER_MAJOR__ > 11) || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0)) // #if ((__CUDACC_VER_MAJOR__ > 11) || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0))
#include <cooperative_groups/reduce.h> // #include <cooperative_groups/reduce.h>
#else // #else
#include <cooperative_groups.h> #include <cooperative_groups.h>
#endif // #endif
#include "src/turbomind/utils/cuda_bf16_wrapper.h" #include "src/turbomind/utils/cuda_bf16_wrapper.h"
#include "src/turbomind/utils/cuda_type_utils.cuh" #include "src/turbomind/utils/cuda_type_utils.cuh"
#include <cuda_fp16.h> #include <cuda_fp16.h>
...@@ -244,15 +244,15 @@ __inline__ __device__ void cgBlockReduceSumElements(float* element_list, float* ...@@ -244,15 +244,15 @@ __inline__ __device__ void cgBlockReduceSumElements(float* element_list, float*
const int tid = cta.thread_rank(); const int tid = cta.thread_rank();
const int blockz = blockDim.x; const int blockz = blockDim.x;
for (int i = 0; i < NUM; i++) { for (int i = 0; i < NUM; i++) {
#if ((__CUDACC_VER_MAJOR__ > 11) || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0)) // #if ((__CUDACC_VER_MAJOR__ > 11) || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0))
cgBlockReduceSumElements_shm[i * blockz + tid] = cg::reduce(tile, element_list[i], cg::plus<float>()); // cgBlockReduceSumElements_shm[i * blockz + tid] = cg::reduce(tile, element_list[i], cg::plus<float>());
#else // #else
// TODO Add implementation here // TODO Add implementation here
if (threadIdx.x == 0 && blockIdx.x == 0) { if (threadIdx.x == 0 && blockIdx.x == 0) {
printf("[ERROR] Not support cgBlockReduceSumElements when CUDA < 11 \n"); printf("[ERROR] Not support cgBlockReduceSumElements when CUDA < 11 \n");
assert(false); assert(false);
} }
#endif // #endif
} }
cg::sync(cta); cg::sync(cta);
if (tid == 0) { if (tid == 0) {
......
...@@ -77,11 +77,11 @@ if (BUILD_MULTI_GPU) ...@@ -77,11 +77,11 @@ if (BUILD_MULTI_GPU)
target_link_libraries(nccl_utils PUBLIC ${NCCL_LIBRARIES} logger) target_link_libraries(nccl_utils PUBLIC ${NCCL_LIBRARIES} logger)
endif() endif()
add_library(cublasINT8MMWrapper STATIC cublasINT8MMWrapper.cc) # add_library(cublasINT8MMWrapper STATIC cublasINT8MMWrapper.cc)
#set_property(TARGET cublasINT8MMWrapper PROPERTY POSITION_INDEPENDENT_CODE ON) #set_property(TARGET cublasINT8MMWrapper PROPERTY POSITION_INDEPENDENT_CODE ON)
#set_property(TARGET cublasINT8MMWrapper PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) #set_property(TARGET cublasINT8MMWrapper PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
#target_link_libraries(cublasINT8MMWrapper PUBLIC cublasLt cudart curand cublasAlgoMap cublasMMWrapper cuda_utils logger) #target_link_libraries(cublasINT8MMWrapper PUBLIC cublasLt cudart curand cublasAlgoMap cublasMMWrapper cuda_utils logger)
target_link_libraries(cublasINT8MMWrapper PUBLIC cudart curand cublasAlgoMap cublasMMWrapper cuda_utils logger) # target_link_libraries(cublasINT8MMWrapper PUBLIC cudart curand cublasAlgoMap cublasMMWrapper cuda_utils logger)
if(ENABLE_FP8) if(ENABLE_FP8)
add_library(cublasFP8MMWrapper STATIC cublasFP8MMWrapper.cu) add_library(cublasFP8MMWrapper STATIC cublasFP8MMWrapper.cu)
...@@ -108,7 +108,7 @@ if (SPARSITY_SUPPORT) ...@@ -108,7 +108,7 @@ if (SPARSITY_SUPPORT)
target_link_libraries(gemm PUBLIC cusparse -lcusparseLt) target_link_libraries(gemm PUBLIC cusparse -lcusparseLt)
endif() endif()
add_library(cuda_fp8_utils STATIC cuda_fp8_utils.cu) # add_library(cuda_fp8_utils STATIC cuda_fp8_utils.cu)
#set_property(TARGET cuda_fp8_utils PROPERTY POSITION_INDEPENDENT_CODE ON) #set_property(TARGET cuda_fp8_utils PROPERTY POSITION_INDEPENDENT_CODE ON)
#set_property(TARGET cuda_fp8_utils PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) #set_property(TARGET cuda_fp8_utils PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
......
...@@ -44,9 +44,9 @@ ...@@ -44,9 +44,9 @@
#include "src/turbomind/utils/logger.h" #include "src/turbomind/utils/logger.h"
#if defined(CUDART_VERSION) && CUDART_VERSION < 11020 // #if defined(CUDART_VERSION) && CUDART_VERSION < 11020
#define CUDA_MEMORY_POOL_DISABLED #define CUDA_MEMORY_POOL_DISABLED
#endif // #endif
namespace turbomind { namespace turbomind {
......
...@@ -237,10 +237,10 @@ void cublasFP8MMWrapper::Gemm(__nv_bfloat16* res, ...@@ -237,10 +237,10 @@ void cublasFP8MMWrapper::Gemm(__nv_bfloat16* res,
cublasLtMatmulAlgoConfigSetAttribute( cublasLtMatmulAlgoConfigSetAttribute(
&algo, CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, &(info.reductionScheme), sizeof(info.reductionScheme)); &algo, CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, &(info.reductionScheme), sizeof(info.reductionScheme));
#if (CUDART_VERSION >= 11000) // #if (CUDART_VERSION >= 11000)
cublasLtMatmulAlgoConfigSetAttribute( // cublasLtMatmulAlgoConfigSetAttribute(
&algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &(info.stages), sizeof(info.stages)); // &algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &(info.stages), sizeof(info.stages));
#endif // #endif
#if (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH >= 3) #if (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH >= 3)
cublasLtMatmulAlgoConfigSetAttribute( cublasLtMatmulAlgoConfigSetAttribute(
...@@ -462,10 +462,10 @@ void cublasFP8MMWrapper::Gemm(__nv_fp8_e4m3* res, ...@@ -462,10 +462,10 @@ void cublasFP8MMWrapper::Gemm(__nv_fp8_e4m3* res,
cublasLtMatmulAlgoConfigSetAttribute( cublasLtMatmulAlgoConfigSetAttribute(
&algo, CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, &(info.reductionScheme), sizeof(info.reductionScheme)); &algo, CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, &(info.reductionScheme), sizeof(info.reductionScheme));
#if (CUDART_VERSION >= 11000) // #if (CUDART_VERSION >= 11000)
cublasLtMatmulAlgoConfigSetAttribute( // cublasLtMatmulAlgoConfigSetAttribute(
&algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &(info.stages), sizeof(info.stages)); // &algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &(info.stages), sizeof(info.stages));
#endif // #endif
#if (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH >= 3) #if (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH >= 3)
cublasLtMatmulAlgoConfigSetAttribute( cublasLtMatmulAlgoConfigSetAttribute(
......
...@@ -94,11 +94,11 @@ void cublasINT8MMWrapper::Gemm(int* res, ...@@ -94,11 +94,11 @@ void cublasINT8MMWrapper::Gemm(int* res,
{ {
mu_->lock(); mu_->lock();
cublasOperation_t opTranspose = CUBLAS_OP_T; cublasOperation_t opTranspose = CUBLAS_OP_T;
#if (CUDART_VERSION >= 11000) // #if (CUDART_VERSION >= 11000)
cublasComputeType_t computeType = CUBLAS_COMPUTE_32I; // cublasComputeType_t computeType = CUBLAS_COMPUTE_32I;
#else // #else
cudaDataType_t computeType = CUDA_R_32I; cudaDataType_t computeType = CUDA_R_32I;
#endif // #endif
cublasLtMatmulDesc_t matmulDesc; cublasLtMatmulDesc_t matmulDesc;
cublasLtMatrixLayout_t AtransformDesc = NULL; cublasLtMatrixLayout_t AtransformDesc = NULL;
cublasLtMatrixLayout_t BtransformDesc = NULL; cublasLtMatrixLayout_t BtransformDesc = NULL;
...@@ -106,16 +106,16 @@ void cublasINT8MMWrapper::Gemm(int* res, ...@@ -106,16 +106,16 @@ void cublasINT8MMWrapper::Gemm(int* res,
cublasLtOrder_t order_COL32 = CUBLASLT_ORDER_COL32; cublasLtOrder_t order_COL32 = CUBLASLT_ORDER_COL32;
cublasLtOrder_t order_matrixB; cublasLtOrder_t order_matrixB;
#if (CUDART_VERSION >= 11000) // #if (CUDART_VERSION >= 11000)
if (use_ORDER_COL32_2R_4R4_) { // if (use_ORDER_COL32_2R_4R4_) {
order_matrixB = CUBLASLT_ORDER_COL32_2R_4R4; // order_matrixB = CUBLASLT_ORDER_COL32_2R_4R4;
} // }
else { // else {
// order_matrixB = CUBLASLT_ORDER_COL4_4R2_8C;
// }
// #else
order_matrixB = CUBLASLT_ORDER_COL4_4R2_8C; order_matrixB = CUBLASLT_ORDER_COL4_4R2_8C;
} // #endif
#else
order_matrixB = CUBLASLT_ORDER_COL4_4R2_8C;
#endif
int ldaTransform = 32 * m; int ldaTransform = 32 * m;
int ldbTransform; int ldbTransform;
...@@ -128,11 +128,11 @@ void cublasINT8MMWrapper::Gemm(int* res, ...@@ -128,11 +128,11 @@ void cublasINT8MMWrapper::Gemm(int* res,
int ldcTransform = 32 * m; int ldcTransform = 32 * m;
// create matmulDesc // create matmulDesc
#if (CUDART_VERSION >= 11000) // #if (CUDART_VERSION >= 11000)
cublasLtMatmulDescCreate(&matmulDesc, computeType, CUDA_R_32I); // cublasLtMatmulDescCreate(&matmulDesc, computeType, CUDA_R_32I);
#else // #else
cublasLtMatmulDescCreate(&matmulDesc, computeType); cublasLtMatmulDescCreate(&matmulDesc, computeType);
#endif // #endif
cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_TRANSB, &opTranspose, sizeof(cublasOperation_t)); cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_TRANSB, &opTranspose, sizeof(cublasOperation_t));
cublasLtMatrixLayoutCreate(&AtransformDesc, CUDA_R_8I, m, k, ldaTransform); cublasLtMatrixLayoutCreate(&AtransformDesc, CUDA_R_8I, m, k, ldaTransform);
cublasLtMatrixLayoutSetAttribute(AtransformDesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order_COL32, sizeof(order_COL32)); cublasLtMatrixLayoutSetAttribute(AtransformDesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order_COL32, sizeof(order_COL32));
...@@ -187,10 +187,10 @@ void cublasINT8MMWrapper::Gemm(int* res, ...@@ -187,10 +187,10 @@ void cublasINT8MMWrapper::Gemm(int* res,
&algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &(tmp_info.swizzle), sizeof(tmp_info.swizzle)); &algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &(tmp_info.swizzle), sizeof(tmp_info.swizzle));
cublasLtMatmulAlgoConfigSetAttribute( cublasLtMatmulAlgoConfigSetAttribute(
&algo, CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, &(tmp_info.reductionScheme), sizeof(int)); &algo, CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, &(tmp_info.reductionScheme), sizeof(int));
#if (CUDART_VERSION >= 11000) // #if (CUDART_VERSION >= 11000)
cublasLtMatmulAlgoConfigSetAttribute( // cublasLtMatmulAlgoConfigSetAttribute(
&algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &(tmp_info.stages), sizeof(tmp_info.stages)); // &algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &(tmp_info.stages), sizeof(tmp_info.stages));
#endif // #endif
} }
else { else {
findAlgo = 1; findAlgo = 1;
...@@ -215,16 +215,16 @@ void cublasINT8MMWrapper::Gemm(int* res, ...@@ -215,16 +215,16 @@ void cublasINT8MMWrapper::Gemm(int* res,
cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &(swizzle), sizeof(swizzle)); cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &(swizzle), sizeof(swizzle));
cublasLtMatmulAlgoConfigSetAttribute( cublasLtMatmulAlgoConfigSetAttribute(
&algo, CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, &(reductionScheme), sizeof(int)); &algo, CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, &(reductionScheme), sizeof(int));
#if (CUDART_VERSION >= 11000) // #if (CUDART_VERSION >= 11000)
int stages; // int stages;
if (use_ORDER_COL32_2R_4R4_) { // if (use_ORDER_COL32_2R_4R4_) {
stages = 15; // stages = 15;
} // }
else { // else {
stages = 13; // stages = 13;
} // }
cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &(stages), sizeof(stages)); // cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &(stages), sizeof(stages));
#endif // #endif
} }
cublasLtMatmul(cublaslt_handle_, cublasLtMatmul(cublaslt_handle_,
...@@ -273,11 +273,11 @@ void cublasINT8MMWrapper::Gemm(int8_t* res, ...@@ -273,11 +273,11 @@ void cublasINT8MMWrapper::Gemm(int8_t* res,
// int8 gemm does not support CUBLAS_POINTER_MODE_DEVICE // int8 gemm does not support CUBLAS_POINTER_MODE_DEVICE
// cublasLtPointerMode_t pointerMode = CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_ZERO; // cublasLtPointerMode_t pointerMode = CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_ZERO;
cudaDataType_t scaleType = CUDA_R_32F; cudaDataType_t scaleType = CUDA_R_32F;
#if (CUDART_VERSION >= 11000) // #if (CUDART_VERSION >= 11000)
cublasComputeType_t computeType = CUBLAS_COMPUTE_32I; // cublasComputeType_t computeType = CUBLAS_COMPUTE_32I;
#else // #else
cudaDataType_t computeType = CUDA_R_32I; cudaDataType_t computeType = CUDA_R_32I;
#endif // #endif
cublasLtMatmulDesc_t matmulDesc; cublasLtMatmulDesc_t matmulDesc;
cublasLtMatrixLayout_t AtransformDesc = NULL; cublasLtMatrixLayout_t AtransformDesc = NULL;
cublasLtMatrixLayout_t BtransformDesc = NULL; cublasLtMatrixLayout_t BtransformDesc = NULL;
...@@ -285,16 +285,16 @@ void cublasINT8MMWrapper::Gemm(int8_t* res, ...@@ -285,16 +285,16 @@ void cublasINT8MMWrapper::Gemm(int8_t* res,
cublasLtOrder_t order_COL32 = CUBLASLT_ORDER_COL32; cublasLtOrder_t order_COL32 = CUBLASLT_ORDER_COL32;
cublasLtOrder_t order_matrixB; cublasLtOrder_t order_matrixB;
#if (CUDART_VERSION >= 11000) // #if (CUDART_VERSION >= 11000)
if (use_ORDER_COL32_2R_4R4_) { // if (use_ORDER_COL32_2R_4R4_) {
order_matrixB = CUBLASLT_ORDER_COL32_2R_4R4; // order_matrixB = CUBLASLT_ORDER_COL32_2R_4R4;
} // }
else { // else {
// order_matrixB = CUBLASLT_ORDER_COL4_4R2_8C;
// }
// #else
order_matrixB = CUBLASLT_ORDER_COL4_4R2_8C; order_matrixB = CUBLASLT_ORDER_COL4_4R2_8C;
} // #endif
#else
order_matrixB = CUBLASLT_ORDER_COL4_4R2_8C;
#endif
int ldaTransform = 32 * m; int ldaTransform = 32 * m;
...@@ -309,11 +309,11 @@ void cublasINT8MMWrapper::Gemm(int8_t* res, ...@@ -309,11 +309,11 @@ void cublasINT8MMWrapper::Gemm(int8_t* res,
int ldcTransform = 32 * m; int ldcTransform = 32 * m;
// create matmulDesc // create matmulDesc
#if (CUDART_VERSION >= 11000) // #if (CUDART_VERSION >= 11000)
cublasLtMatmulDescCreate(&matmulDesc, computeType, scaleType); // cublasLtMatmulDescCreate(&matmulDesc, computeType, scaleType);
#else // #else
cublasLtMatmulDescCreate(&matmulDesc, computeType); cublasLtMatmulDescCreate(&matmulDesc, computeType);
#endif // #endif
cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_TRANSB, &opTranspose, sizeof(cublasOperation_t)); cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_TRANSB, &opTranspose, sizeof(cublasOperation_t));
cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_SCALE_TYPE, &scaleType, sizeof(scaleType)); cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_SCALE_TYPE, &scaleType, sizeof(scaleType));
// cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_POINTER_MODE, &pointerMode, // cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_POINTER_MODE, &pointerMode,
...@@ -367,10 +367,10 @@ void cublasINT8MMWrapper::Gemm(int8_t* res, ...@@ -367,10 +367,10 @@ void cublasINT8MMWrapper::Gemm(int8_t* res,
&algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &(tmp_info.swizzle), sizeof(tmp_info.swizzle)); &algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &(tmp_info.swizzle), sizeof(tmp_info.swizzle));
cublasLtMatmulAlgoConfigSetAttribute( cublasLtMatmulAlgoConfigSetAttribute(
&algo, CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, &(tmp_info.reductionScheme), sizeof(int)); &algo, CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, &(tmp_info.reductionScheme), sizeof(int));
#if (CUDART_VERSION >= 11000) // #if (CUDART_VERSION >= 11000)
cublasLtMatmulAlgoConfigSetAttribute( // cublasLtMatmulAlgoConfigSetAttribute(
&algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &(tmp_info.stages), sizeof(tmp_info.stages)); // &algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &(tmp_info.stages), sizeof(tmp_info.stages));
#endif // #endif
} }
else { else {
findAlgo = 1; findAlgo = 1;
...@@ -395,16 +395,16 @@ void cublasINT8MMWrapper::Gemm(int8_t* res, ...@@ -395,16 +395,16 @@ void cublasINT8MMWrapper::Gemm(int8_t* res,
cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &(swizzle), sizeof(swizzle)); cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &(swizzle), sizeof(swizzle));
cublasLtMatmulAlgoConfigSetAttribute( cublasLtMatmulAlgoConfigSetAttribute(
&algo, CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, &(reductionScheme), sizeof(int)); &algo, CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, &(reductionScheme), sizeof(int));
#if (CUDART_VERSION >= 11000) // #if (CUDART_VERSION >= 11000)
int stages; // int stages;
if (use_ORDER_COL32_2R_4R4_) { // if (use_ORDER_COL32_2R_4R4_) {
stages = 15; // stages = 15;
} // }
else { // else {
stages = 13; // stages = 13;
} // }
cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &(stages), sizeof(stages)); // cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &(stages), sizeof(stages));
#endif // #endif
} }
float beta = 0.0f; float beta = 0.0f;
......
...@@ -192,118 +192,119 @@ void cublasMMWrapper::Gemm(cublasOperation_t transa, ...@@ -192,118 +192,119 @@ void cublasMMWrapper::Gemm(cublasOperation_t transa,
} }
} }
// if (using_cublasLt) {
if (0) {
cublasLtMatmulDesc_t operationDesc = NULL;
cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL;
cudaDataType_t scaleType;
#if (CUDART_VERSION >= 11000)
cublasComputeType_t computeType;
#else
cudaDataType_t computeType;
#endif
if (is_fp16_computeType) {
#if (CUDART_VERSION >= 11000)
computeType = CUBLAS_COMPUTE_16F;
#else
computeType = CUDA_R_16F;
#endif
scaleType = CUDA_R_16F;
}
else {
#if (CUDART_VERSION >= 11000)
computeType = CUBLAS_COMPUTE_32F;
#else
computeType = CUDA_R_32F;
#endif
scaleType = CUDA_R_32F;
}
// --------------------------------------
// Create descriptors for the original matrices
cublasLtMatrixLayoutCreate(&Adesc, Atype_, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda);
cublasLtMatrixLayoutCreate(&Bdesc, Btype_, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb);
cublasLtMatrixLayoutCreate(&Cdesc, Ctype_, m, n, ldc);
#if (CUDART_VERSION >= 11000)
cublasLtMatmulDescCreate(&operationDesc, computeType, scaleType);
#else
cublasLtMatmulDescCreate(&operationDesc, computeType);
#endif
cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(cublasOperation_t));
cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(cublasOperation_t));
cublasLtMatmulAlgo_t algo; // if (using_cublasLt) {
void* workSpace = cublas_workspace_; // if (0) {
int workspaceSize = cublas_workspace_ == NULL ? 0 : CUBLAS_WORKSPACE_SIZE; // cublasLtMatmulDesc_t operationDesc = NULL;
if (findAlgo) { // cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL;
if (info.workspaceSize > workspaceSize) { // cudaDataType_t scaleType;
findAlgo = 0; // #if (CUDART_VERSION >= 11000)
} // cublasComputeType_t computeType;
else { // #else
cublasLtMatmulAlgoInit( // cudaDataType_t computeType;
cublaslt_handle_, computeType, scaleType, Atype_, Btype_, Ctype_, Ctype_, info.algoId, &algo); // #endif
cublasLtMatmulAlgoConfigSetAttribute(
&algo, CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION, &(info.customOption), sizeof(info.customOption)); // if (is_fp16_computeType) {
cublasLtMatmulAlgoConfigSetAttribute( // #if (CUDART_VERSION >= 11000)
&algo, CUBLASLT_ALGO_CONFIG_TILE_ID, &(info.tile), sizeof(info.tile)); // computeType = CUBLAS_COMPUTE_16F;
cublasLtMatmulAlgoConfigSetAttribute( // #else
&algo, CUBLASLT_ALGO_CONFIG_SPLITK_NUM, &(info.splitK_val), sizeof(info.splitK_val)); // computeType = CUDA_R_16F;
cublasLtMatmulAlgoConfigSetAttribute( // #endif
&algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &(info.swizzle), sizeof(info.swizzle)); // scaleType = CUDA_R_16F;
cublasLtMatmulAlgoConfigSetAttribute(&algo, // }
CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, // else {
&(info.reductionScheme), // #if (CUDART_VERSION >= 11000)
sizeof(info.reductionScheme)); // computeType = CUBLAS_COMPUTE_32F;
// #else
#if (CUDART_VERSION >= 11000) // computeType = CUDA_R_32F;
cublasLtMatmulAlgoConfigSetAttribute( // #endif
&algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &(info.stages), sizeof(info.stages)); // scaleType = CUDA_R_32F;
#endif // }
#if (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH >= 3) // // --------------------------------------
cublasLtMatmulAlgoConfigSetAttribute( // // Create descriptors for the original matrices
&algo, CUBLASLT_ALGO_CONFIG_INNER_SHAPE_ID, &(info.inner_shapeId), sizeof(info.inner_shapeId)); // cublasLtMatrixLayoutCreate(&Adesc, Atype_, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda);
cublasLtMatmulAlgoConfigSetAttribute(&algo, // cublasLtMatrixLayoutCreate(&Bdesc, Btype_, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb);
CUBLASLT_ALGO_CONFIG_CLUSTER_SHAPE_ID, // cublasLtMatrixLayoutCreate(&Cdesc, Ctype_, m, n, ldc);
&(info.cluster_shapeId), // #if (CUDART_VERSION >= 11000)
sizeof(info.cluster_shapeId)); // cublasLtMatmulDescCreate(&operationDesc, computeType, scaleType);
#elif (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH < 3) // #else
cublasLtMatmulAlgoConfigSetAttribute( // cublasLtMatmulDescCreate(&operationDesc, computeType);
&algo, CUBLASLT_ALGO_CONFIG_MMA_SHAPE_ID, &(info.mma_shapeId), sizeof(info.mma_shapeId)); // #endif
cublasLtMatmulAlgoConfigSetAttribute(
&algo, CUBLASLT_ALGO_CONFIG_CGA_SHAPE_ID, &(info.cga_shapeId), sizeof(info.cga_shapeId)); // cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(cublasOperation_t));
cublasLtMatmulAlgoConfigSetAttribute( // cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(cublasOperation_t));
&algo, CUBLASLT_ALGO_CONFIG_SCHEDULING_MODE, &(info.sche_mode), sizeof(info.sche_mode));
#endif // cublasLtMatmulAlgo_t algo;
} // void* workSpace = cublas_workspace_;
} // int workspaceSize = cublas_workspace_ == NULL ? 0 : CUBLAS_WORKSPACE_SIZE;
// if (findAlgo) {
// cublasLtMatmul(cublaslt_handle_, // if (info.workspaceSize > workspaceSize) {
// operationDesc, // findAlgo = 0;
// alpha, // }
// A, // else {
// Adesc, // cublasLtMatmulAlgoInit(
// B, // cublaslt_handle_, computeType, scaleType, Atype_, Btype_, Ctype_, Ctype_, info.algoId, &algo);
// Bdesc, // cublasLtMatmulAlgoConfigSetAttribute(
// beta, // &algo, CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION, &(info.customOption), sizeof(info.customOption));
// C, // cublasLtMatmulAlgoConfigSetAttribute(
// Cdesc, // &algo, CUBLASLT_ALGO_CONFIG_TILE_ID, &(info.tile), sizeof(info.tile));
// C, // cublasLtMatmulAlgoConfigSetAttribute(
// Cdesc, // &algo, CUBLASLT_ALGO_CONFIG_SPLITK_NUM, &(info.splitK_val), sizeof(info.splitK_val));
// (findAlgo == 1 ? (&algo) : NULL), // cublasLtMatmulAlgoConfigSetAttribute(
// workSpace, // &algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &(info.swizzle), sizeof(info.swizzle));
// workspaceSize, // cublasLtMatmulAlgoConfigSetAttribute(&algo,
// stream_); // CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME,
// &(info.reductionScheme),
cublasLtMatmulDescDestroy(operationDesc); // sizeof(info.reductionScheme));
cublasLtMatrixLayoutDestroy(Adesc);
cublasLtMatrixLayoutDestroy(Bdesc); // // #if (CUDART_VERSION >= 11000)
cublasLtMatrixLayoutDestroy(Cdesc); // // cublasLtMatmulAlgoConfigSetAttribute(
sync_check_cuda_error(); // // &algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &(info.stages), sizeof(info.stages));
} // // #endif
else {
// #if (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH >= 3)
// cublasLtMatmulAlgoConfigSetAttribute(
// &algo, CUBLASLT_ALGO_CONFIG_INNER_SHAPE_ID, &(info.inner_shapeId), sizeof(info.inner_shapeId));
// cublasLtMatmulAlgoConfigSetAttribute(&algo,
// CUBLASLT_ALGO_CONFIG_CLUSTER_SHAPE_ID,
// &(info.cluster_shapeId),
// sizeof(info.cluster_shapeId));
// #elif (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH < 3)
// cublasLtMatmulAlgoConfigSetAttribute(
// &algo, CUBLASLT_ALGO_CONFIG_MMA_SHAPE_ID, &(info.mma_shapeId), sizeof(info.mma_shapeId));
// cublasLtMatmulAlgoConfigSetAttribute(
// &algo, CUBLASLT_ALGO_CONFIG_CGA_SHAPE_ID, &(info.cga_shapeId), sizeof(info.cga_shapeId));
// cublasLtMatmulAlgoConfigSetAttribute(
// &algo, CUBLASLT_ALGO_CONFIG_SCHEDULING_MODE, &(info.sche_mode), sizeof(info.sche_mode));
// #endif
// }
// }
// // cublasLtMatmul(cublaslt_handle_,
// // operationDesc,
// // alpha,
// // A,
// // Adesc,
// // B,
// // Bdesc,
// // beta,
// // C,
// // Cdesc,
// // C,
// // Cdesc,
// // (findAlgo == 1 ? (&algo) : NULL),
// // workSpace,
// // workspaceSize,
// // stream_);
// cublasLtMatmulDescDestroy(operationDesc);
// cublasLtMatrixLayoutDestroy(Adesc);
// cublasLtMatrixLayoutDestroy(Bdesc);
// cublasLtMatrixLayoutDestroy(Cdesc);
// sync_check_cuda_error();
// }
// else {
int cublasAlgo = info.algoId; int cublasAlgo = info.algoId;
check_cuda_error(cublasGemmEx(cublas_handle_, check_cuda_error(cublasGemmEx(cublas_handle_,
transa, transa,
...@@ -325,7 +326,7 @@ void cublasMMWrapper::Gemm(cublasOperation_t transa, ...@@ -325,7 +326,7 @@ void cublasMMWrapper::Gemm(cublasOperation_t transa,
computeType_, computeType_,
static_cast<cublasGemmAlgo_t>(cublasAlgo))); static_cast<cublasGemmAlgo_t>(cublasAlgo)));
sync_check_cuda_error(); sync_check_cuda_error();
} // }
mu_->unlock(); mu_->unlock();
} }
...@@ -382,81 +383,81 @@ CublasDataType cublasMMWrapper::getCublasDataType(cudaDataType_t data_type) ...@@ -382,81 +383,81 @@ CublasDataType cublasMMWrapper::getCublasDataType(cudaDataType_t data_type)
return FLOAT_DATATYPE; return FLOAT_DATATYPE;
} }
#if (CUDART_VERSION >= 11000) // #if (CUDART_VERSION >= 11000)
// input, weight, output are row-major // // input, weight, output are row-major
// only works for cublas 11.x // // only works for cublas 11.x
void cublasMMWrapper::Gemm(cublasOperation_t transa, // void cublasMMWrapper::Gemm(cublasOperation_t transa,
cublasOperation_t transb, // cublasOperation_t transb,
const int m, // const int m,
const int n, // const int n,
const int k, // const int k,
const void* A, // const void* A,
const int lda, // const int lda,
const void* B, // const void* B,
const int ldb, // const int ldb,
const void* bias, // const void* bias,
void* C, // void* C,
const int ldc) // const int ldc)
{ // {
TM_LOG_DEBUG(__PRETTY_FUNCTION__); // TM_LOG_DEBUG(__PRETTY_FUNCTION__);
cudaDataType_t Atype, Btype, Ctype; // cudaDataType_t Atype, Btype, Ctype;
cublasComputeType_t computeType; // cublasComputeType_t computeType;
cudaDataType_t scaleType; // cudaDataType_t scaleType;
float alpha_float = 1.0f; // float alpha_float = 1.0f;
float beta_float = 0.0f; // float beta_float = 0.0f;
half alpha_half = half(1.0f); // half alpha_half = half(1.0f);
half beta_half = half(0.0f); // half beta_half = half(0.0f);
void * alpha, *beta; // void * alpha, *beta;
// int is_fp16_computeType = computeType_ == CUDA_R_16F ? 1 : 0; // // int is_fp16_computeType = computeType_ == CUDA_R_16F ? 1 : 0;
if (Atype_ == CUDA_R_32F) { // if (Atype_ == CUDA_R_32F) {
computeType = CUBLAS_COMPUTE_32F_FAST_TF32; // computeType = CUBLAS_COMPUTE_32F_FAST_TF32;
Atype = CUDA_R_32F; // Atype = CUDA_R_32F;
Btype = CUDA_R_32F; // Btype = CUDA_R_32F;
Ctype = CUDA_R_32F; // Ctype = CUDA_R_32F;
scaleType = CUDA_R_32F; // scaleType = CUDA_R_32F;
alpha = &alpha_float; // alpha = &alpha_float;
beta = &beta_float; // beta = &beta_float;
} // }
else if (Atype_ == CUDA_R_16BF) { // else if (Atype_ == CUDA_R_16BF) {
computeType = CUBLAS_COMPUTE_32F_FAST_TF32; // computeType = CUBLAS_COMPUTE_32F_FAST_TF32;
Atype = CUDA_R_16BF; // Atype = CUDA_R_16BF;
Btype = CUDA_R_16BF; // Btype = CUDA_R_16BF;
Ctype = CUDA_R_16BF; // Ctype = CUDA_R_16BF;
scaleType = CUDA_R_32F; // scaleType = CUDA_R_32F;
alpha = &alpha_float; // alpha = &alpha_float;
beta = &beta_float; // beta = &beta_float;
} // }
else { // else {
computeType = CUBLAS_COMPUTE_16F; // computeType = CUBLAS_COMPUTE_16F;
Atype = CUDA_R_16F; // Atype = CUDA_R_16F;
Btype = CUDA_R_16F; // Btype = CUDA_R_16F;
Ctype = CUDA_R_16F; // Ctype = CUDA_R_16F;
scaleType = CUDA_R_16F; // scaleType = CUDA_R_16F;
alpha = &alpha_half; // alpha = &alpha_half;
beta = &beta_half; // beta = &beta_half;
} // }
cublasLtMatmulDesc_t operationDesc = NULL; // cublasLtMatmulDesc_t operationDesc = NULL;
cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL; // cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL;
cublasLtEpilogue_t epi = CUBLASLT_EPILOGUE_BIAS; // cublasLtEpilogue_t epi = CUBLASLT_EPILOGUE_BIAS;
cublasLtMatrixLayoutCreate(&Adesc, Atype, (transa == CUBLAS_OP_N) ? m : k, (transa == CUBLAS_OP_N) ? k : m, lda); // cublasLtMatrixLayoutCreate(&Adesc, Atype, (transa == CUBLAS_OP_N) ? m : k, (transa == CUBLAS_OP_N) ? k : m, lda);
cublasLtMatrixLayoutCreate(&Bdesc, Btype, (transb == CUBLAS_OP_N) ? k : n, (transb == CUBLAS_OP_N) ? n : k, ldb); // cublasLtMatrixLayoutCreate(&Bdesc, Btype, (transb == CUBLAS_OP_N) ? k : n, (transb == CUBLAS_OP_N) ? n : k, ldb);
cublasLtMatrixLayoutCreate(&Cdesc, Ctype, m, n, ldc); // cublasLtMatrixLayoutCreate(&Cdesc, Ctype, m, n, ldc);
cublasLtMatmulDescCreate(&operationDesc, computeType, scaleType); // cublasLtMatmulDescCreate(&operationDesc, computeType, scaleType);
cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(cublasOperation_t)); // cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(cublasOperation_t));
cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(cublasOperation_t)); // cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(cublasOperation_t));
cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epi, sizeof(cublasLtEpilogue_t)); // cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epi, sizeof(cublasLtEpilogue_t));
cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(const void*)); // cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(const void*));
// check_cuda_error(cublasLtMatmul( // // check_cuda_error(cublasLtMatmul(
// cublaslt_handle_, operationDesc, alpha, A, Adesc, B, Bdesc, beta, C, Cdesc, C, Cdesc, NULL, NULL, 0, stream_)); // // cublaslt_handle_, operationDesc, alpha, A, Adesc, B, Bdesc, beta, C, Cdesc, C, Cdesc, NULL, NULL, 0, stream_));
cublasLtMatrixLayoutDestroy(Adesc); // cublasLtMatrixLayoutDestroy(Adesc);
cublasLtMatrixLayoutDestroy(Bdesc); // cublasLtMatrixLayoutDestroy(Bdesc);
cublasLtMatrixLayoutDestroy(Cdesc); // cublasLtMatrixLayoutDestroy(Cdesc);
cublasLtMatmulDescDestroy(operationDesc); // cublasLtMatmulDescDestroy(operationDesc);
} // }
#endif // #endif
void cublasMMWrapper::setStream(cudaStream_t stream) void cublasMMWrapper::setStream(cudaStream_t stream)
{ {
stream_ = stream; stream_ = stream;
......
...@@ -207,20 +207,20 @@ public: ...@@ -207,20 +207,20 @@ public:
CublasDataType getCublasDataType(cudaDataType_t data_type); CublasDataType getCublasDataType(cudaDataType_t data_type);
#if (CUDART_VERSION >= 11000) // #if (CUDART_VERSION >= 11000)
void Gemm(cublasOperation_t transa, // void Gemm(cublasOperation_t transa,
cublasOperation_t transb, // cublasOperation_t transb,
const int m, // const int m,
const int n, // const int n,
const int k, // const int k,
const void* A, // const void* A,
const int lda, // const int lda,
const void* B, // const void* B,
const int ldb, // const int ldb,
const void* bias, // const void* bias,
void* C, // void* C,
const int ldc); // const int ldc);
#endif // #endif
void stridedBatchedGemm(cublasOperation_t transa, void stridedBatchedGemm(cublasOperation_t transa,
cublasOperation_t transb, cublasOperation_t transb,
......
...@@ -152,17 +152,17 @@ void initCustomAllReduceComm(std::vector<std::shared_ptr<AbstractCustomComm>>* c ...@@ -152,17 +152,17 @@ void initCustomAllReduceComm(std::vector<std::shared_ptr<AbstractCustomComm>>* c
return; return;
} }
#if defined(CUDART_VERSION) && CUDART_VERSION >= 11020 // #if defined(CUDART_VERSION) && CUDART_VERSION >= 11020
for (size_t i = 0; i < rank_size; i++) { // for (size_t i = 0; i < rank_size; i++) {
custom_all_reduce_comms->push_back(std::make_shared<CustomAllReduceComm<T>>(rank_size, i)); // custom_all_reduce_comms->push_back(std::make_shared<CustomAllReduceComm<T>>(rank_size, i));
} // }
custom_all_reduce_comms->at(0)->allocateAndExchangePeerAccessPointer(custom_all_reduce_comms); // custom_all_reduce_comms->at(0)->allocateAndExchangePeerAccessPointer(custom_all_reduce_comms);
#else // #else
TM_LOG_WARNING("Custom All Reduce is not supported before CUDA 11.2. Using NCCL as Comm."); TM_LOG_WARNING("Custom All Reduce is not supported before CUDA 11.2. Using NCCL as Comm.");
for (size_t i = 0; i < rank_size; i++) { for (size_t i = 0; i < rank_size; i++) {
custom_all_reduce_comms->push_back(nullptr); custom_all_reduce_comms->push_back(nullptr);
} }
#endif // #endif
} }
// Template instantiation // Template instantiation
......
...@@ -269,82 +269,82 @@ void Gemm::gemm(const GemmOp transa, ...@@ -269,82 +269,82 @@ void Gemm::gemm(const GemmOp transa,
} }
// if (using_cublasLt) { // if (using_cublasLt) {
if(0) { // if(0) {
const size_t a_rows = (a_op == getCublasOperation(GEMM_OP_N)) ? _m : k; // const size_t a_rows = (a_op == getCublasOperation(GEMM_OP_N)) ? _m : k;
const size_t a_cols = (a_op == getCublasOperation(GEMM_OP_N)) ? k : _m; // const size_t a_cols = (a_op == getCublasOperation(GEMM_OP_N)) ? k : _m;
const size_t b_rows = (b_op == getCublasOperation(GEMM_OP_N)) ? k : _n; // const size_t b_rows = (b_op == getCublasOperation(GEMM_OP_N)) ? k : _n;
const size_t b_cols = (b_op == getCublasOperation(GEMM_OP_N)) ? _n : k; // const size_t b_cols = (b_op == getCublasOperation(GEMM_OP_N)) ? _n : k;
cublasLtMatmulDesc_t matmul_desc = NULL; // cublasLtMatmulDesc_t matmul_desc = NULL;
cublasLtMatrixLayout_t a_desc = NULL, b_desc = NULL, c_desc = NULL; // cublasLtMatrixLayout_t a_desc = NULL, b_desc = NULL, c_desc = NULL;
cudaDataType_t scale_type = getCublasDataType(compute_type_); // cudaDataType_t scale_type = getCublasDataType(compute_type_);
auto compute_type = getCublasComputeType(compute_type_); // auto compute_type = getCublasComputeType(compute_type_);
// -------------------------------------- // // --------------------------------------
// Create descriptors for the original matrices // // Create descriptors for the original matrices
cublasLtMatrixLayoutCreate(&a_desc, a_type, a_rows, a_cols, _lda); // cublasLtMatrixLayoutCreate(&a_desc, a_type, a_rows, a_cols, _lda);
cublasLtMatrixLayoutCreate(&b_desc, b_type, b_rows, b_cols, _ldb); // cublasLtMatrixLayoutCreate(&b_desc, b_type, b_rows, b_cols, _ldb);
cublasLtMatrixLayoutCreate(&c_desc, c_type, _m, _n, ldc); // cublasLtMatrixLayoutCreate(&c_desc, c_type, _m, _n, ldc);
#if (CUDART_VERSION >= 11000) // #if (CUDART_VERSION >= 11000)
cublasLtMatmulDescCreate(&matmul_desc, compute_type, scale_type); // cublasLtMatmulDescCreate(&matmul_desc, compute_type, scale_type);
#else // #else
cublasLtMatmulDescCreate(&matmul_desc, compute_type); // cublasLtMatmulDescCreate(&matmul_desc, compute_type);
#endif // #endif
cublasLtMatmulDescSetAttribute(matmul_desc, CUBLASLT_MATMUL_DESC_TRANSA, &a_op, sizeof(cublasOperation_t)); // cublasLtMatmulDescSetAttribute(matmul_desc, CUBLASLT_MATMUL_DESC_TRANSA, &a_op, sizeof(cublasOperation_t));
cublasLtMatmulDescSetAttribute(matmul_desc, CUBLASLT_MATMUL_DESC_TRANSB, &b_op, sizeof(cublasOperation_t)); // cublasLtMatmulDescSetAttribute(matmul_desc, CUBLASLT_MATMUL_DESC_TRANSB, &b_op, sizeof(cublasOperation_t));
cublasLtMatmulAlgo_t algo; // cublasLtMatmulAlgo_t algo;
void* workspace = workspace_; // void* workspace = workspace_;
int workspace_size = workspace_ == nullptr ? 0 : CUBLAS_WORKSPACE_SIZE; // int workspace_size = workspace_ == nullptr ? 0 : CUBLAS_WORKSPACE_SIZE;
if (findAlgo) { // if (findAlgo) {
if (info.workspaceSize > workspace_size) { // if (info.workspaceSize > workspace_size) {
findAlgo = 0; // findAlgo = 0;
} // }
else { // else {
cublasLtMatmulAlgoInit( // cublasLtMatmulAlgoInit(
cublaslt_handle_, compute_type, scale_type, a_type, b_type, c_type, c_type, info.algoId, &algo); // cublaslt_handle_, compute_type, scale_type, a_type, b_type, c_type, c_type, info.algoId, &algo);
cublasLtMatmulAlgoConfigSetAttribute( // cublasLtMatmulAlgoConfigSetAttribute(
&algo, CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION, &(info.customOption), sizeof(info.customOption)); // &algo, CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION, &(info.customOption), sizeof(info.customOption));
cublasLtMatmulAlgoConfigSetAttribute( // cublasLtMatmulAlgoConfigSetAttribute(
&algo, CUBLASLT_ALGO_CONFIG_TILE_ID, &(info.tile), sizeof(info.tile)); // &algo, CUBLASLT_ALGO_CONFIG_TILE_ID, &(info.tile), sizeof(info.tile));
cublasLtMatmulAlgoConfigSetAttribute( // cublasLtMatmulAlgoConfigSetAttribute(
&algo, CUBLASLT_ALGO_CONFIG_SPLITK_NUM, &(info.splitK_val), sizeof(info.splitK_val)); // &algo, CUBLASLT_ALGO_CONFIG_SPLITK_NUM, &(info.splitK_val), sizeof(info.splitK_val));
cublasLtMatmulAlgoConfigSetAttribute( // cublasLtMatmulAlgoConfigSetAttribute(
&algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &(info.swizzle), sizeof(info.swizzle)); // &algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &(info.swizzle), sizeof(info.swizzle));
cublasLtMatmulAlgoConfigSetAttribute( // cublasLtMatmulAlgoConfigSetAttribute(
&algo, CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, &(info.reductionScheme), sizeof(int)); // &algo, CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, &(info.reductionScheme), sizeof(int));
#if (CUDART_VERSION >= 11000) // #if (CUDART_VERSION >= 11000)
cublasLtMatmulAlgoConfigSetAttribute( // cublasLtMatmulAlgoConfigSetAttribute(
&algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &(info.stages), sizeof(info.stages)); // &algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &(info.stages), sizeof(info.stages));
#endif // #endif
} // }
} // }
cublasLtMatmul(cublaslt_handle_, // cublasLtMatmul(cublaslt_handle_,
matmul_desc, // matmul_desc,
alpha_ptr, // alpha_ptr,
a_data_ptr, // a_data_ptr,
a_desc, // a_desc,
b_data_ptr, // b_data_ptr,
b_desc, // b_desc,
beta_ptr, // beta_ptr,
C, // C,
c_desc, // c_desc,
C, // C,
c_desc, // c_desc,
(findAlgo == 1 ? (&algo) : NULL), // (findAlgo == 1 ? (&algo) : NULL),
workspace, // workspace,
workspace_size, // workspace_size,
stream_); // stream_);
cublasLtMatmulDescDestroy(matmul_desc); // cublasLtMatmulDescDestroy(matmul_desc);
cublasLtMatrixLayoutDestroy(a_desc); // cublasLtMatrixLayoutDestroy(a_desc);
cublasLtMatrixLayoutDestroy(b_desc); // cublasLtMatrixLayoutDestroy(b_desc);
cublasLtMatrixLayoutDestroy(c_desc); // cublasLtMatrixLayoutDestroy(c_desc);
sync_check_cuda_error(); // sync_check_cuda_error();
} // }
else { // else {
cudaDataType_t compute_type = getCublasDataType(compute_type_); cudaDataType_t compute_type = getCublasDataType(compute_type_);
int cublas_algo = info.algoId; int cublas_algo = info.algoId;
check_cuda_error(cublasGemmEx(cublas_handle_, check_cuda_error(cublasGemmEx(cublas_handle_,
...@@ -367,7 +367,7 @@ void Gemm::gemm(const GemmOp transa, ...@@ -367,7 +367,7 @@ void Gemm::gemm(const GemmOp transa,
compute_type, compute_type,
static_cast<cublasGemmAlgo_t>(cublas_algo))); static_cast<cublasGemmAlgo_t>(cublas_algo)));
sync_check_cuda_error(); sync_check_cuda_error();
} // }
mutex_->unlock(); mutex_->unlock();
} }
...@@ -1035,19 +1035,19 @@ cudaDataType_t getCublasDataType(DataType dtype) ...@@ -1035,19 +1035,19 @@ cudaDataType_t getCublasDataType(DataType dtype)
} }
} }
#if (CUDART_VERSION >= 11000) // #if (CUDART_VERSION >= 11000)
cublasComputeType_t getCublasComputeType(DataType ctype) // cublasComputeType_t getCublasComputeType(DataType ctype)
{ // {
switch (ctype) { // switch (ctype) {
case TYPE_FP16: // case TYPE_FP16:
return CUBLAS_COMPUTE_16F; // return CUBLAS_COMPUTE_16F;
case TYPE_FP32: // case TYPE_FP32:
return CUBLAS_COMPUTE_32F; // return CUBLAS_COMPUTE_32F;
default: // default:
throw GemmNotSupportedException("Not supported cublas compute type."); // throw GemmNotSupportedException("Not supported cublas compute type.");
} // }
} // }
#else // #else
cudaDataType_t getCublasComputeType(DataType ctype) cudaDataType_t getCublasComputeType(DataType ctype)
{ {
switch (ctype) { switch (ctype) {
...@@ -1059,7 +1059,7 @@ cudaDataType_t getCublasComputeType(DataType ctype) ...@@ -1059,7 +1059,7 @@ cudaDataType_t getCublasComputeType(DataType ctype)
throw GemmNotSupportedException("Not supported cublas compute type."); throw GemmNotSupportedException("Not supported cublas compute type.");
} }
} }
#endif // #endif
cublasOperation_t getCublasOperation(GemmOp op) cublasOperation_t getCublasOperation(GemmOp op)
{ {
......
...@@ -622,11 +622,11 @@ std::shared_ptr<Gemm> ...@@ -622,11 +622,11 @@ std::shared_ptr<Gemm>
createGemm(IAllocator* allocator, cudaStream_t stream, bool sparse = false, bool quantized = false); createGemm(IAllocator* allocator, cudaStream_t stream, bool sparse = false, bool quantized = false);
cudaDataType_t getCublasDataType(DataType dtype); cudaDataType_t getCublasDataType(DataType dtype);
#if (CUDART_VERSION >= 11000) // #if (CUDART_VERSION >= 11000)
cublasComputeType_t getCublasComputeType(DataType dtype); // cublasComputeType_t getCublasComputeType(DataType dtype);
#else // #else
cudaDataType_t getCublasComputeType(DataType dtype); cudaDataType_t getCublasComputeType(DataType dtype);
#endif // #endif
cublasOperation_t getCublasOperation(GemmOp op); cublasOperation_t getCublasOperation(GemmOp op);
std::string getGemmOpString(const GemmOp& op); std::string getGemmOpString(const GemmOp& op);
......
...@@ -82,11 +82,11 @@ int printPerfStructure(int m, int n, int k, const customMatmulPerf_t& perf, FILE ...@@ -82,11 +82,11 @@ int printPerfStructure(int m, int n, int k, const customMatmulPerf_t& perf, FILE
matmulAlgo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &swizzle, sizeof(swizzle), NULL); matmulAlgo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &swizzle, sizeof(swizzle), NULL);
cublasLtMatmulAlgoConfigGetAttribute( cublasLtMatmulAlgoConfigGetAttribute(
matmulAlgo, CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION, &customOption, sizeof(customOption), NULL); matmulAlgo, CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION, &customOption, sizeof(customOption), NULL);
#if (CUDART_VERSION >= 11000) // #if (CUDART_VERSION >= 11000)
cublasLtMatmulAlgoConfigGetAttribute(matmulAlgo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &stages, sizeof(stages), NULL); // cublasLtMatmulAlgoConfigGetAttribute(matmulAlgo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &stages, sizeof(stages), NULL);
#else // #else
stages = 0; stages = 0;
#endif // #endif
printf("algo={ Id=%d, tileIdx=%d (%s) splitK=%d reduc=%d swizzle=%d custom=%d stages=%d} status %d " printf("algo={ Id=%d, tileIdx=%d (%s) splitK=%d reduc=%d swizzle=%d custom=%d stages=%d} status %d "
"time %f workspace=%d mathMode=%d waves=%f\n", "time %f workspace=%d mathMode=%d waves=%f\n",
...@@ -148,11 +148,11 @@ int printBatchPerfStructure( ...@@ -148,11 +148,11 @@ int printBatchPerfStructure(
matmulAlgo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &swizzle, sizeof(swizzle), NULL); matmulAlgo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &swizzle, sizeof(swizzle), NULL);
cublasLtMatmulAlgoConfigGetAttribute( cublasLtMatmulAlgoConfigGetAttribute(
matmulAlgo, CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION, &customOption, sizeof(customOption), NULL); matmulAlgo, CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION, &customOption, sizeof(customOption), NULL);
#if (CUDART_VERSION >= 11000) // #if (CUDART_VERSION >= 11000)
cublasLtMatmulAlgoConfigGetAttribute(matmulAlgo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &stages, sizeof(stages), NULL); // cublasLtMatmulAlgoConfigGetAttribute(matmulAlgo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &stages, sizeof(stages), NULL);
#else // #else
stages = 0; stages = 0;
#endif // #endif
printf("algo={ Id=%d, tileIdx=%d (%s) splitK=%d reduc=%d swizzle=%d custom=%d stages=%d} status %d " printf("algo={ Id=%d, tileIdx=%d (%s) splitK=%d reduc=%d swizzle=%d custom=%d stages=%d} status %d "
"time %f workspace=%d mathMode=%d waves=%f\n", "time %f workspace=%d mathMode=%d waves=%f\n",
...@@ -279,693 +279,693 @@ static cublasStatus_t customMatmulRun(cublasLtHandle_t ltHandle, // ...@@ -279,693 +279,693 @@ static cublasStatus_t customMatmulRun(cublasLtHandle_t ltHandle, //
// Sample wrapper running through multiple algo and config attributes combination for INT8 gemm using cublasLt low-level // Sample wrapper running through multiple algo and config attributes combination for INT8 gemm using cublasLt low-level
// API // API
template<typename T, typename scaleT> // template<typename T, typename scaleT>
int LtIgemmCustomFind(cublasLtHandle_t ltHandle, // int LtIgemmCustomFind(cublasLtHandle_t ltHandle,
int m, // int m,
int n, // int n,
int k, // int k,
const scaleT* alpha, /* host pointer */ // const scaleT* alpha, /* host pointer */
const int8_t* A, // const int8_t* A,
const int8_t* B, // const int8_t* B,
const scaleT* beta, /* host pointer */ // const scaleT* beta, /* host pointer */
T* C, // T* C,
void* workSpace, // void* workSpace,
size_t workSpaceSize, // size_t workSpaceSize,
FILE* fout) // FILE* fout)
{ // {
cublasStatus_t status = CUBLAS_STATUS_SUCCESS; // cublasStatus_t status = CUBLAS_STATUS_SUCCESS;
cublasLtMatmulDesc_t operationDesc = NULL; // cublasLtMatmulDesc_t operationDesc = NULL;
cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL; // cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL;
cudaStream_t stream = 0; // cudaStream_t stream = 0;
// SplitK value that we are going to try when SplitK is supported for a given algo // // SplitK value that we are going to try when SplitK is supported for a given algo
const int splitKSequenceA[] = {2, 3, 4, 5, 6, 8, 12, 16, 32}; // const int splitKSequenceA[] = {2, 3, 4, 5, 6, 8, 12, 16, 32};
// Let try a fixed number of combinations // // Let try a fixed number of combinations
#define ALGO_COMBINATIONS 50000 // #define ALGO_COMBINATIONS 50000
int AlgoCombinations = ALGO_COMBINATIONS; // int AlgoCombinations = ALGO_COMBINATIONS;
int AlgoCount = 0; // int AlgoCount = 0;
int kernelRepeats = 100; // number of time the CUDA kernels will be run back to back // int kernelRepeats = 100; // number of time the CUDA kernels will be run back to back
customMatmulPerf_t perfResults[ALGO_COMBINATIONS]; // customMatmulPerf_t perfResults[ALGO_COMBINATIONS];
int nbAlgoIds = 0; // int nbAlgoIds = 0;
#define ALGO_IDS 100 // #define ALGO_IDS 100
int algoIdA[ALGO_IDS]; // int algoIdA[ALGO_IDS];
cudaDataType_t Atype, Btype, Ctype, scaleType; // cudaDataType_t Atype, Btype, Ctype, scaleType;
Atype = CUDA_R_8I; // Atype = CUDA_R_8I;
Btype = CUDA_R_8I; // Btype = CUDA_R_8I;
if (std::is_same<T, int32_t>::value && std::is_same<scaleT, int>::value) { // if (std::is_same<T, int32_t>::value && std::is_same<scaleT, int>::value) {
Ctype = CUDA_R_32I; // Ctype = CUDA_R_32I;
scaleType = CUDA_R_32I; // scaleType = CUDA_R_32I;
} // }
else if (std::is_same<T, int8_t>::value && std::is_same<scaleT, float>::value) { // else if (std::is_same<T, int8_t>::value && std::is_same<scaleT, float>::value) {
Ctype = CUDA_R_8I; // Ctype = CUDA_R_8I;
scaleType = CUDA_R_32F; // scaleType = CUDA_R_32F;
} // }
else { // else {
printf("[ERROR]<T,scaleT> of igemm is invalid\n"); // printf("[ERROR]<T,scaleT> of igemm is invalid\n");
exit(-1); // exit(-1);
} // }
#if (CUDART_VERSION >= 11000) // // #if (CUDART_VERSION >= 11000)
cublasComputeType_t computeType = CUBLAS_COMPUTE_32I; // // cublasComputeType_t computeType = CUBLAS_COMPUTE_32I;
#else // // #else
cudaDataType_t computeType = CUDA_R_32I; // cudaDataType_t computeType = CUDA_R_32I;
#endif // // #endif
cublasOperation_t opTranspose = CUBLAS_OP_T; // cublasOperation_t opTranspose = CUBLAS_OP_T;
bool use_ORDER_COL32_2R_4R4 = false; // bool use_ORDER_COL32_2R_4R4 = false;
#if (CUDART_VERSION >= 11000) // // #if (CUDART_VERSION >= 11000)
int device{-1}; // // int device{-1};
cudaGetDevice(&device); // // cudaGetDevice(&device);
cudaDeviceProp props; // // cudaDeviceProp props;
cudaGetDeviceProperties(&props, device); // // cudaGetDeviceProperties(&props, device);
if (props.major * 10 + props.minor >= 80) { // // if (props.major * 10 + props.minor >= 80) {
use_ORDER_COL32_2R_4R4 = true; // // use_ORDER_COL32_2R_4R4 = true;
} // // }
#endif // // #endif
cublasLtOrder_t order_COL32 = CUBLASLT_ORDER_COL32; // cublasLtOrder_t order_COL32 = CUBLASLT_ORDER_COL32;
cublasLtOrder_t order_matrixB; // cublasLtOrder_t order_matrixB;
#if (CUDART_VERSION >= 11000) // // #if (CUDART_VERSION >= 11000)
if (use_ORDER_COL32_2R_4R4) { // // if (use_ORDER_COL32_2R_4R4) {
order_matrixB = CUBLASLT_ORDER_COL32_2R_4R4; // // order_matrixB = CUBLASLT_ORDER_COL32_2R_4R4;
} // // }
else { // // else {
order_matrixB = CUBLASLT_ORDER_COL4_4R2_8C; // // order_matrixB = CUBLASLT_ORDER_COL4_4R2_8C;
} // // }
#else // // #else
order_matrixB = CUBLASLT_ORDER_COL4_4R2_8C; // order_matrixB = CUBLASLT_ORDER_COL4_4R2_8C;
#endif // // #endif
int ldaTransform = 32 * m; // int ldaTransform = 32 * m;
int ldbTransform; // int ldbTransform;
if (use_ORDER_COL32_2R_4R4) { // if (use_ORDER_COL32_2R_4R4) {
ldbTransform = 32 * ((n + 32 - 1) / 32) * 32; // ldbTransform = 32 * ((n + 32 - 1) / 32) * 32;
} // }
else { // else {
ldbTransform = 32 * ((n + 8 - 1) / 8) * 8; // ldbTransform = 32 * ((n + 8 - 1) / 8) * 8;
} // }
int ldcTransform = 32 * m; // int ldcTransform = 32 * m;
#if (CUDART_VERSION >= 11000) // // #if (CUDART_VERSION >= 11000)
status = cublasLtMatmulDescCreate(&operationDesc, computeType, scaleType); // // status = cublasLtMatmulDescCreate(&operationDesc, computeType, scaleType);
#else // // #else
status = cublasLtMatmulDescCreate(&operationDesc, scaleType); // status = cublasLtMatmulDescCreate(&operationDesc, scaleType);
#endif // // #endif
if (status != CUBLAS_STATUS_SUCCESS) { // if (status != CUBLAS_STATUS_SUCCESS) {
goto CLEANUP; // goto CLEANUP;
} // }
cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &opTranspose, sizeof(cublasOperation_t)); // cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &opTranspose, sizeof(cublasOperation_t));
// Create matrix descriptors. // // Create matrix descriptors.
status = cublasLtMatrixLayoutCreate(&Adesc, Atype, m, k, ldaTransform); // status = cublasLtMatrixLayoutCreate(&Adesc, Atype, m, k, ldaTransform);
if (status != CUBLAS_STATUS_SUCCESS) { // if (status != CUBLAS_STATUS_SUCCESS) {
goto CLEANUP; // goto CLEANUP;
} // }
status = cublasLtMatrixLayoutSetAttribute(Adesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order_COL32, sizeof(order_COL32)); // status = cublasLtMatrixLayoutSetAttribute(Adesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order_COL32, sizeof(order_COL32));
if (status != CUBLAS_STATUS_SUCCESS) { // if (status != CUBLAS_STATUS_SUCCESS) {
goto CLEANUP; // goto CLEANUP;
} // }
status = cublasLtMatrixLayoutCreate(&Bdesc, Btype, n, k, ldbTransform); // status = cublasLtMatrixLayoutCreate(&Bdesc, Btype, n, k, ldbTransform);
if (status != CUBLAS_STATUS_SUCCESS) { // if (status != CUBLAS_STATUS_SUCCESS) {
goto CLEANUP; // goto CLEANUP;
} // }
status = // status =
cublasLtMatrixLayoutSetAttribute(Bdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order_matrixB, sizeof(order_matrixB)); // cublasLtMatrixLayoutSetAttribute(Bdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order_matrixB, sizeof(order_matrixB));
if (status != CUBLAS_STATUS_SUCCESS) { // if (status != CUBLAS_STATUS_SUCCESS) {
goto CLEANUP; // goto CLEANUP;
} // }
status = cublasLtMatrixLayoutCreate(&Cdesc, Ctype, m, n, ldcTransform); // status = cublasLtMatrixLayoutCreate(&Cdesc, Ctype, m, n, ldcTransform);
if (status != CUBLAS_STATUS_SUCCESS) { // if (status != CUBLAS_STATUS_SUCCESS) {
goto CLEANUP; // goto CLEANUP;
} // }
status = cublasLtMatrixLayoutSetAttribute(Cdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order_COL32, sizeof(order_COL32)); // status = cublasLtMatrixLayoutSetAttribute(Cdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order_COL32, sizeof(order_COL32));
if (status != CUBLAS_STATUS_SUCCESS) { // if (status != CUBLAS_STATUS_SUCCESS) {
goto CLEANUP; // goto CLEANUP;
} // }
// Request AlgoId available for IGEMM // // Request AlgoId available for IGEMM
status = cublasLtMatmulAlgoGetIds( // status = cublasLtMatmulAlgoGetIds(
ltHandle, computeType, scaleType, Atype, Btype, Ctype, Ctype, ALGO_IDS, algoIdA, &nbAlgoIds); // ltHandle, computeType, scaleType, Atype, Btype, Ctype, Ctype, ALGO_IDS, algoIdA, &nbAlgoIds);
if (status != CUBLAS_STATUS_SUCCESS) { // if (status != CUBLAS_STATUS_SUCCESS) {
goto CLEANUP; // goto CLEANUP;
} // }
// Loop over the Algo IDs // // Loop over the Algo IDs
for (int idx = 0; (idx < nbAlgoIds) && (AlgoCount < AlgoCombinations); idx++) { // for (int idx = 0; (idx < nbAlgoIds) && (AlgoCount < AlgoCombinations); idx++) {
cublasLtMatmulAlgo_t algo; // cublasLtMatmulAlgo_t algo;
size_t sizeWritten = 0; // size_t sizeWritten = 0;
/* Initialize algo structure with given Algp ID */ // /* Initialize algo structure with given Algp ID */
status = // status =
cublasLtMatmulAlgoInit(ltHandle, computeType, scaleType, Atype, Btype, Ctype, Ctype, algoIdA[idx], &algo); // cublasLtMatmulAlgoInit(ltHandle, computeType, scaleType, Atype, Btype, Ctype, Ctype, algoIdA[idx], &algo);
if (status != CUBLAS_STATUS_SUCCESS) { // if (status != CUBLAS_STATUS_SUCCESS) {
continue; // continue;
} // }
// Query the tiles enums supported by that algo // // Query the tiles enums supported by that algo
cublasLtMatmulAlgoCapGetAttribute(&algo, CUBLASLT_ALGO_CAP_TILE_IDS, NULL, 0, &sizeWritten); // cublasLtMatmulAlgoCapGetAttribute(&algo, CUBLASLT_ALGO_CAP_TILE_IDS, NULL, 0, &sizeWritten);
int nbTiles = int(sizeWritten / sizeof(int)); // int nbTiles = int(sizeWritten / sizeof(int));
int* tileA = new int[nbTiles == 0 ? 1 : nbTiles]; // int* tileA = new int[nbTiles == 0 ? 1 : nbTiles];
if (nbTiles == 0) { // if (nbTiles == 0) {
tileA[0] = CUBLASLT_MATMUL_TILE_UNDEFINED; // tileA[0] = CUBLASLT_MATMUL_TILE_UNDEFINED;
nbTiles = 1; // nbTiles = 1;
} // }
#if (CUDART_VERSION >= 11000) // // #if (CUDART_VERSION >= 11000)
cublasLtMatmulAlgoCapGetAttribute(&algo, CUBLASLT_ALGO_CAP_STAGES_IDS, NULL, 0, &sizeWritten); // // cublasLtMatmulAlgoCapGetAttribute(&algo, CUBLASLT_ALGO_CAP_STAGES_IDS, NULL, 0, &sizeWritten);
int nbStages = int(sizeWritten / sizeof(int)); // // int nbStages = int(sizeWritten / sizeof(int));
std::vector<int> stagesA(nbStages == 0 ? 1 : nbStages); // // std::vector<int> stagesA(nbStages == 0 ? 1 : nbStages);
if (nbStages == 0) { // // if (nbStages == 0) {
stagesA[0] = CUBLASLT_MATMUL_STAGES_UNDEFINED; // // stagesA[0] = CUBLASLT_MATMUL_STAGES_UNDEFINED;
nbStages = 1; // // nbStages = 1;
} // // }
else { // // else {
cublasLtMatmulAlgoCapGetAttribute( // // cublasLtMatmulAlgoCapGetAttribute(
&algo, CUBLASLT_ALGO_CAP_STAGES_IDS, stagesA.data(), sizeof(int) * nbStages, &sizeWritten); // // &algo, CUBLASLT_ALGO_CAP_STAGES_IDS, stagesA.data(), sizeof(int) * nbStages, &sizeWritten);
} // // }
#endif // // #endif
int splitkSupport, redMask, swizzlingMax, customOptionMax; // int splitkSupport, redMask, swizzlingMax, customOptionMax;
// Retrieve Algo Capabilities attributes to be able to setup loop over the different combinations // // Retrieve Algo Capabilities attributes to be able to setup loop over the different combinations
cublasLtMatmulAlgoCapGetAttribute( // cublasLtMatmulAlgoCapGetAttribute(
&algo, CUBLASLT_ALGO_CAP_TILE_IDS, tileA, sizeof(int) * nbTiles, &sizeWritten); // &algo, CUBLASLT_ALGO_CAP_TILE_IDS, tileA, sizeof(int) * nbTiles, &sizeWritten);
cublasLtMatmulAlgoCapGetAttribute( // cublasLtMatmulAlgoCapGetAttribute(
&algo, CUBLASLT_ALGO_CAP_SPLITK_SUPPORT, &splitkSupport, sizeof(splitkSupport), &sizeWritten); // &algo, CUBLASLT_ALGO_CAP_SPLITK_SUPPORT, &splitkSupport, sizeof(splitkSupport), &sizeWritten);
cublasLtMatmulAlgoCapGetAttribute( // cublasLtMatmulAlgoCapGetAttribute(
&algo, CUBLASLT_ALGO_CAP_REDUCTION_SCHEME_MASK, &redMask, sizeof(redMask), &sizeWritten); // &algo, CUBLASLT_ALGO_CAP_REDUCTION_SCHEME_MASK, &redMask, sizeof(redMask), &sizeWritten);
cublasLtMatmulAlgoCapGetAttribute( // cublasLtMatmulAlgoCapGetAttribute(
&algo, CUBLASLT_ALGO_CAP_CTA_SWIZZLING_SUPPORT, &swizzlingMax, sizeof(swizzlingMax), &sizeWritten); // &algo, CUBLASLT_ALGO_CAP_CTA_SWIZZLING_SUPPORT, &swizzlingMax, sizeof(swizzlingMax), &sizeWritten);
cublasLtMatmulAlgoCapGetAttribute( // cublasLtMatmulAlgoCapGetAttribute(
&algo, CUBLASLT_ALGO_CAP_CUSTOM_OPTION_MAX, &customOptionMax, sizeof(customOptionMax), &sizeWritten); // &algo, CUBLASLT_ALGO_CAP_CUSTOM_OPTION_MAX, &customOptionMax, sizeof(customOptionMax), &sizeWritten);
/* Loop over the different tiles */ // /* Loop over the different tiles */
for (int tileIdx = 0; tileIdx < nbTiles; tileIdx++) { // for (int tileIdx = 0; tileIdx < nbTiles; tileIdx++) {
#if (CUDART_VERSION >= 11000) // // #if (CUDART_VERSION >= 11000)
/* Loop over different stages count */ // // /* Loop over different stages count */
for (int stagesIdx = 0; stagesIdx < nbStages; stagesIdx++) { // // for (int stagesIdx = 0; stagesIdx < nbStages; stagesIdx++) {
cublasLtMatmulAlgoConfigSetAttribute( // // cublasLtMatmulAlgoConfigSetAttribute(
&algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &stagesA[stagesIdx], sizeof(stagesA[stagesIdx])); // // &algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &stagesA[stagesIdx], sizeof(stagesA[stagesIdx]));
#endif // // #endif
/* Loop over the different custom option if any */ // /* Loop over the different custom option if any */
for (int customOption = 0; customOption <= customOptionMax; customOption++) { // for (int customOption = 0; customOption <= customOptionMax; customOption++) {
cublasLtMatmulAlgoConfigSetAttribute( // cublasLtMatmulAlgoConfigSetAttribute(
&algo, CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION, &customOption, sizeof(customOption)); // &algo, CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION, &customOption, sizeof(customOption));
/* Loop over the CTAs swizzling support */ // /* Loop over the CTAs swizzling support */
for (int k = 0; k <= swizzlingMax; k++) { // for (int k = 0; k <= swizzlingMax; k++) {
int splitK_trial = 0; // int splitK_trial = 0;
if (splitkSupport) { // if (splitkSupport) {
splitK_trial += sizeof(splitKSequenceA) / sizeof(splitKSequenceA[0]); // splitK_trial += sizeof(splitKSequenceA) / sizeof(splitKSequenceA[0]);
} // }
// Loop over the splitK value over a fixed sequence splitKSequenceA in addition to the case // // Loop over the splitK value over a fixed sequence splitKSequenceA in addition to the case
// where splitK is not enabled // // where splitK is not enabled
for (int l = 0; (l < (1 + splitK_trial)) && (AlgoCount < AlgoCombinations); l++) { // for (int l = 0; (l < (1 + splitK_trial)) && (AlgoCount < AlgoCombinations); l++) {
/* Setup attribute of the algo to run */ // /* Setup attribute of the algo to run */
cublasLtMatmulAlgoConfigSetAttribute( // cublasLtMatmulAlgoConfigSetAttribute(
&algo, CUBLASLT_ALGO_CONFIG_TILE_ID, &tileA[tileIdx], sizeof(tileA[tileIdx])); // &algo, CUBLASLT_ALGO_CONFIG_TILE_ID, &tileA[tileIdx], sizeof(tileA[tileIdx]));
int splitK_val = 0; // int splitK_val = 0;
int redScheme = CUBLASLT_REDUCTION_SCHEME_NONE; // int redScheme = CUBLASLT_REDUCTION_SCHEME_NONE;
cublasLtMatmulAlgoConfigSetAttribute( // cublasLtMatmulAlgoConfigSetAttribute(
&algo, CUBLASLT_ALGO_CONFIG_SPLITK_NUM, &splitK_val, sizeof(splitK_val)); // &algo, CUBLASLT_ALGO_CONFIG_SPLITK_NUM, &splitK_val, sizeof(splitK_val));
cublasLtMatmulAlgoConfigSetAttribute( // cublasLtMatmulAlgoConfigSetAttribute(
&algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &k, sizeof(k)); // &algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &k, sizeof(k));
cublasLtMatmulAlgoConfigSetAttribute( // cublasLtMatmulAlgoConfigSetAttribute(
&algo, CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, &redScheme, sizeof(int)); // &algo, CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, &redScheme, sizeof(int));
if (l > 0) { // Split-K case // if (l > 0) { // Split-K case
splitK_val = splitKSequenceA[l - 1]; // splitK_val = splitKSequenceA[l - 1];
cublasLtMatmulAlgoConfigSetAttribute(&algo, // cublasLtMatmulAlgoConfigSetAttribute(&algo,
CUBLASLT_ALGO_CONFIG_SPLITK_NUM, // CUBLASLT_ALGO_CONFIG_SPLITK_NUM,
&splitKSequenceA[l - 1], // &splitKSequenceA[l - 1],
sizeof(splitKSequenceA[l - 1])); // sizeof(splitKSequenceA[l - 1]));
/* Going over all the reduction scheme */ // /* Going over all the reduction scheme */
for (redScheme = 1; // for (redScheme = 1;
redScheme <= (int)CUBLASLT_REDUCTION_SCHEME_MASK && (AlgoCount < AlgoCombinations); // redScheme <= (int)CUBLASLT_REDUCTION_SCHEME_MASK && (AlgoCount < AlgoCombinations);
redScheme = redScheme << 1) { // redScheme = redScheme << 1) {
if (redScheme & redMask) { // if (redScheme & redMask) {
cublasLtMatmulAlgoConfigSetAttribute(&algo, // cublasLtMatmulAlgoConfigSetAttribute(&algo,
CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, // CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME,
&redScheme, // &redScheme,
sizeof(redScheme)); // sizeof(redScheme));
status = customMatmulRun(ltHandle, // status = customMatmulRun(ltHandle,
operationDesc, // operationDesc,
alpha, /* host or device pointer */ // alpha, /* host or device pointer */
A, // A,
Adesc, // Adesc,
B, // B,
Bdesc, // Bdesc,
beta, /* host or device pointer */ // beta, /* host or device pointer */
C, // C,
Cdesc, // Cdesc,
C, // C,
Cdesc, // Cdesc,
algo, // algo,
kernelRepeats, // kernelRepeats,
workSpace, // workSpace,
workSpaceSize, // workSpaceSize,
perfResults[AlgoCount], // perfResults[AlgoCount],
stream); // stream);
perfResults[AlgoCount].status = status; // perfResults[AlgoCount].status = status;
if (status == CUBLAS_STATUS_SUCCESS) { // if (status == CUBLAS_STATUS_SUCCESS) {
AlgoCount++; // AlgoCount++;
} // }
} // end if // } // end if
} // end for // } // end for
} // }
else { // Non-splitK case // else { // Non-splitK case
/* if user preference is ok with workspace */ // /* if user preference is ok with workspace */
if (AlgoCount < AlgoCombinations) { // if (AlgoCount < AlgoCombinations) {
status = customMatmulRun(ltHandle, // status = customMatmulRun(ltHandle,
operationDesc, // operationDesc,
alpha, /* host or device pointer */ // alpha, /* host or device pointer */
A, // A,
Adesc, // Adesc,
B, // B,
Bdesc, // Bdesc,
beta, /* host or device pointer */ // beta, /* host or device pointer */
C, // C,
Cdesc, // Cdesc,
C, // C,
Cdesc, // Cdesc,
algo, // algo,
kernelRepeats, // kernelRepeats,
workSpace, // workSpace,
workSpaceSize, // workSpaceSize,
perfResults[AlgoCount], // perfResults[AlgoCount],
stream); // stream);
perfResults[AlgoCount].status = status; // perfResults[AlgoCount].status = status;
if (status == CUBLAS_STATUS_SUCCESS) { // if (status == CUBLAS_STATUS_SUCCESS) {
AlgoCount++; // AlgoCount++;
} // }
} // }
} // }
} // end l // } // end l
} // end k // } // end k
} // end customOption // } // end customOption
#if (CUDART_VERSION >= 11000) // // #if (CUDART_VERSION >= 11000)
} // end stagesIdx // // } // end stagesIdx
#endif // // #endif
} // end tileIdx // } // end tileIdx
delete[] tileA; // delete[] tileA;
} // end idx // } // end idx
// Sort the results per run duration // // Sort the results per run duration
std::sort(perfResults, perfResults + AlgoCount, time_compare); // std::sort(perfResults, perfResults + AlgoCount, time_compare);
// Print timing and perf details // // Print timing and perf details
for (int i = 0, hasPrint = 0; i < AlgoCount; i++) { // for (int i = 0, hasPrint = 0; i < AlgoCount; i++) {
printf("result %03d : ", i); // printf("result %03d : ", i);
hasPrint = printPerfStructure(m, n, k, perfResults[i], fout, hasPrint); // hasPrint = printPerfStructure(m, n, k, perfResults[i], fout, hasPrint);
} // }
CLEANUP: // CLEANUP:
// Descriptors are no longer needed as all GPU work was already enqueued // // Descriptors are no longer needed as all GPU work was already enqueued
if (Cdesc) { // if (Cdesc) {
cublasLtMatrixLayoutDestroy(Cdesc); // cublasLtMatrixLayoutDestroy(Cdesc);
} // }
if (Bdesc) { // if (Bdesc) {
cublasLtMatrixLayoutDestroy(Bdesc); // cublasLtMatrixLayoutDestroy(Bdesc);
} // }
if (Adesc) { // if (Adesc) {
cublasLtMatrixLayoutDestroy(Adesc); // cublasLtMatrixLayoutDestroy(Adesc);
} // }
if (operationDesc) { // if (operationDesc) {
cublasLtMatmulDescDestroy(operationDesc); // cublasLtMatmulDescDestroy(operationDesc);
} // }
return status == CUBLAS_STATUS_SUCCESS ? 0 : 1; // return status == CUBLAS_STATUS_SUCCESS ? 0 : 1;
} // }
template int LtIgemmCustomFind(cublasLtHandle_t ltHandle, // template int LtIgemmCustomFind(cublasLtHandle_t ltHandle,
int m, // int m,
int n, // int n,
int k, // int k,
const int* alpha, /* host pointer */ // const int* alpha, /* host pointer */
const int8_t* A, // const int8_t* A,
const int8_t* B, // const int8_t* B,
const int* beta, /* host pointer */ // const int* beta, /* host pointer */
int32_t* C, // int32_t* C,
void* workSpace, // void* workSpace,
size_t workSpaceSize, // size_t workSpaceSize,
FILE* fout); // FILE* fout);
template int LtIgemmCustomFind(cublasLtHandle_t ltHandle, // template int LtIgemmCustomFind(cublasLtHandle_t ltHandle,
int m, // int m,
int n, // int n,
int k, // int k,
const float* alpha, /* host pointer */ // const float* alpha, /* host pointer */
const int8_t* A, // const int8_t* A,
const int8_t* B, // const int8_t* B,
const float* beta, /* host pointer */ // const float* beta, /* host pointer */
int8_t* C, // int8_t* C,
void* workSpace, // void* workSpace,
size_t workSpaceSize, // size_t workSpaceSize,
FILE* fout); // FILE* fout);
template<typename T, typename scaleT> // template<typename T, typename scaleT>
int LtBatchIgemmCustomFind(cublasLtHandle_t ltHandle, // int LtBatchIgemmCustomFind(cublasLtHandle_t ltHandle,
int batchCount, // int batchCount,
int m, // int m,
int n, // int n,
int k, // int k,
const scaleT* alpha, /* host pointer */ // const scaleT* alpha, /* host pointer */
const int8_t* A, // const int8_t* A,
const int8_t* B, // const int8_t* B,
const scaleT* beta, /* host pointer */ // const scaleT* beta, /* host pointer */
T* C, // T* C,
void* workSpace, // void* workSpace,
size_t workSpaceSize, // size_t workSpaceSize,
FILE* fout) // FILE* fout)
{ // {
cublasStatus_t status = CUBLAS_STATUS_SUCCESS; // cublasStatus_t status = CUBLAS_STATUS_SUCCESS;
cublasLtMatmulDesc_t operationDesc = NULL; // cublasLtMatmulDesc_t operationDesc = NULL;
cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL; // cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL;
cudaStream_t stream = 0; // cudaStream_t stream = 0;
// SplitK value that we are going to try when SplitK is supported for a given algo // // SplitK value that we are going to try when SplitK is supported for a given algo
const int splitKSequenceA[] = {2, 3, 4, 5, 6, 8, 12, 16, 32}; // const int splitKSequenceA[] = {2, 3, 4, 5, 6, 8, 12, 16, 32};
// Let try a fixed number of combinations // // Let try a fixed number of combinations
#define ALGO_COMBINATIONS 50000 // #define ALGO_COMBINATIONS 50000
int AlgoCombinations = ALGO_COMBINATIONS; // int AlgoCombinations = ALGO_COMBINATIONS;
int AlgoCount = 0; // int AlgoCount = 0;
int kernelRepeats = 100; // number of time the CUDA kernels will be run back to back // int kernelRepeats = 100; // number of time the CUDA kernels will be run back to back
customMatmulPerf_t perfResults[ALGO_COMBINATIONS]; // customMatmulPerf_t perfResults[ALGO_COMBINATIONS];
int nbAlgoIds = 0; // int nbAlgoIds = 0;
#define ALGO_IDS 100 // #define ALGO_IDS 100
int algoIdA[ALGO_IDS]; // int algoIdA[ALGO_IDS];
cudaDataType_t Atype, Btype, Ctype, scaleType; // cudaDataType_t Atype, Btype, Ctype, scaleType;
Atype = CUDA_R_8I; // Atype = CUDA_R_8I;
Btype = CUDA_R_8I; // Btype = CUDA_R_8I;
if (std::is_same<T, int32_t>::value && std::is_same<scaleT, int>::value) { // if (std::is_same<T, int32_t>::value && std::is_same<scaleT, int>::value) {
Ctype = CUDA_R_32I; // Ctype = CUDA_R_32I;
scaleType = CUDA_R_32I; // scaleType = CUDA_R_32I;
} // }
else if (std::is_same<T, int8_t>::value && std::is_same<scaleT, float>::value) { // else if (std::is_same<T, int8_t>::value && std::is_same<scaleT, float>::value) {
Ctype = CUDA_R_8I; // Ctype = CUDA_R_8I;
scaleType = CUDA_R_32F; // scaleType = CUDA_R_32F;
} // }
else { // else {
printf("[ERROR]<T,scaleT> of igemm is invalid\n"); // printf("[ERROR]<T,scaleT> of igemm is invalid\n");
exit(-1); // exit(-1);
} // }
#if (CUDART_VERSION >= 11000) // // #if (CUDART_VERSION >= 11000)
cublasComputeType_t computeType = CUBLAS_COMPUTE_32I; // // cublasComputeType_t computeType = CUBLAS_COMPUTE_32I;
#else // // #else
cudaDataType_t computeType = CUDA_R_32I; // cudaDataType_t computeType = CUDA_R_32I;
#endif // // #endif
cublasOperation_t opTranspose = CUBLAS_OP_T; // cublasOperation_t opTranspose = CUBLAS_OP_T;
bool use_ORDER_COL32_2R_4R4 = false; // bool use_ORDER_COL32_2R_4R4 = false;
#if (CUDART_VERSION >= 11000) // // #if (CUDART_VERSION >= 11000)
int device{-1}; // // int device{-1};
cudaGetDevice(&device); // // cudaGetDevice(&device);
cudaDeviceProp props; // // cudaDeviceProp props;
cudaGetDeviceProperties(&props, device); // // cudaGetDeviceProperties(&props, device);
if (props.major * 10 + props.minor >= 80) { // // if (props.major * 10 + props.minor >= 80) {
use_ORDER_COL32_2R_4R4 = true; // // use_ORDER_COL32_2R_4R4 = true;
} // // }
#endif // // #endif
cublasLtOrder_t order_COL32 = CUBLASLT_ORDER_COL32; // cublasLtOrder_t order_COL32 = CUBLASLT_ORDER_COL32;
cublasLtOrder_t order_matrixB; // cublasLtOrder_t order_matrixB;
#if (CUDART_VERSION >= 11000) // // #if (CUDART_VERSION >= 11000)
if (use_ORDER_COL32_2R_4R4) { // // if (use_ORDER_COL32_2R_4R4) {
order_matrixB = CUBLASLT_ORDER_COL32_2R_4R4; // // order_matrixB = CUBLASLT_ORDER_COL32_2R_4R4;
} // // }
else { // // else {
order_matrixB = CUBLASLT_ORDER_COL4_4R2_8C; // // order_matrixB = CUBLASLT_ORDER_COL4_4R2_8C;
} // // }
#else // // #else
order_matrixB = CUBLASLT_ORDER_COL4_4R2_8C; // order_matrixB = CUBLASLT_ORDER_COL4_4R2_8C;
#endif // // #endif
int ldaTransform = 32 * m; // int ldaTransform = 32 * m;
int ldbTransform; // int ldbTransform;
if (use_ORDER_COL32_2R_4R4) { // if (use_ORDER_COL32_2R_4R4) {
ldbTransform = 32 * ((n + 32 - 1) / 32) * 32; // ldbTransform = 32 * ((n + 32 - 1) / 32) * 32;
} // }
else { // else {
ldbTransform = 32 * ((n + 8 - 1) / 8) * 8; // ldbTransform = 32 * ((n + 8 - 1) / 8) * 8;
} // }
int ldcTransform = 32 * m; // int ldcTransform = 32 * m;
int64_t stridea, strideb, stridec; // int64_t stridea, strideb, stridec;
stridea = m * k; // stridea = m * k;
strideb = n * k; // strideb = n * k;
stridec = m * n; // stridec = m * n;
#if (CUDART_VERSION >= 11000) // // #if (CUDART_VERSION >= 11000)
status = cublasLtMatmulDescCreate(&operationDesc, computeType, scaleType); // // status = cublasLtMatmulDescCreate(&operationDesc, computeType, scaleType);
#else // // #else
status = cublasLtMatmulDescCreate(&operationDesc, scaleType); // status = cublasLtMatmulDescCreate(&operationDesc, scaleType);
#endif // // #endif
if (status != CUBLAS_STATUS_SUCCESS) { // if (status != CUBLAS_STATUS_SUCCESS) {
goto CLEANUP; // goto CLEANUP;
} // }
cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &opTranspose, sizeof(cublasOperation_t)); // cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &opTranspose, sizeof(cublasOperation_t));
// Create matrix descriptors. // // Create matrix descriptors.
status = cublasLtMatrixLayoutCreate(&Adesc, Atype, m, k, ldaTransform); // status = cublasLtMatrixLayoutCreate(&Adesc, Atype, m, k, ldaTransform);
if (status != CUBLAS_STATUS_SUCCESS) { // if (status != CUBLAS_STATUS_SUCCESS) {
goto CLEANUP; // goto CLEANUP;
} // }
status = cublasLtMatrixLayoutSetAttribute(Adesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order_COL32, sizeof(order_COL32)); // status = cublasLtMatrixLayoutSetAttribute(Adesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order_COL32, sizeof(order_COL32));
if (status != CUBLAS_STATUS_SUCCESS) { // if (status != CUBLAS_STATUS_SUCCESS) {
goto CLEANUP; // goto CLEANUP;
} // }
cublasLtMatrixLayoutSetAttribute(Adesc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batchCount, sizeof(batchCount)); // cublasLtMatrixLayoutSetAttribute(Adesc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batchCount, sizeof(batchCount));
cublasLtMatrixLayoutSetAttribute(Adesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stridea, sizeof(stridea)); // cublasLtMatrixLayoutSetAttribute(Adesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stridea, sizeof(stridea));
status = cublasLtMatrixLayoutCreate(&Bdesc, Btype, n, k, ldbTransform); // status = cublasLtMatrixLayoutCreate(&Bdesc, Btype, n, k, ldbTransform);
if (status != CUBLAS_STATUS_SUCCESS) { // if (status != CUBLAS_STATUS_SUCCESS) {
goto CLEANUP; // goto CLEANUP;
} // }
status = // status =
cublasLtMatrixLayoutSetAttribute(Bdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order_matrixB, sizeof(order_matrixB)); // cublasLtMatrixLayoutSetAttribute(Bdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order_matrixB, sizeof(order_matrixB));
if (status != CUBLAS_STATUS_SUCCESS) { // if (status != CUBLAS_STATUS_SUCCESS) {
goto CLEANUP; // goto CLEANUP;
} // }
cublasLtMatrixLayoutSetAttribute(Bdesc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batchCount, sizeof(batchCount)); // cublasLtMatrixLayoutSetAttribute(Bdesc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batchCount, sizeof(batchCount));
cublasLtMatrixLayoutSetAttribute(Bdesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &strideb, sizeof(strideb)); // cublasLtMatrixLayoutSetAttribute(Bdesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &strideb, sizeof(strideb));
status = cublasLtMatrixLayoutCreate(&Cdesc, Ctype, m, n, ldcTransform); // status = cublasLtMatrixLayoutCreate(&Cdesc, Ctype, m, n, ldcTransform);
if (status != CUBLAS_STATUS_SUCCESS) { // if (status != CUBLAS_STATUS_SUCCESS) {
goto CLEANUP; // goto CLEANUP;
} // }
status = cublasLtMatrixLayoutSetAttribute(Cdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order_COL32, sizeof(order_COL32)); // status = cublasLtMatrixLayoutSetAttribute(Cdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order_COL32, sizeof(order_COL32));
if (status != CUBLAS_STATUS_SUCCESS) { // if (status != CUBLAS_STATUS_SUCCESS) {
goto CLEANUP; // goto CLEANUP;
} // }
cublasLtMatrixLayoutSetAttribute(Cdesc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batchCount, sizeof(batchCount)); // cublasLtMatrixLayoutSetAttribute(Cdesc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batchCount, sizeof(batchCount));
cublasLtMatrixLayoutSetAttribute(Cdesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stridec, sizeof(stridec)); // cublasLtMatrixLayoutSetAttribute(Cdesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stridec, sizeof(stridec));
// Request AlgoId available for IGEMM // // Request AlgoId available for IGEMM
status = cublasLtMatmulAlgoGetIds( // status = cublasLtMatmulAlgoGetIds(
ltHandle, computeType, scaleType, Atype, Btype, Ctype, Ctype, ALGO_IDS, algoIdA, &nbAlgoIds); // ltHandle, computeType, scaleType, Atype, Btype, Ctype, Ctype, ALGO_IDS, algoIdA, &nbAlgoIds);
if (status != CUBLAS_STATUS_SUCCESS) { // if (status != CUBLAS_STATUS_SUCCESS) {
goto CLEANUP; // goto CLEANUP;
} // }
// Loop over the Algo IDs // // Loop over the Algo IDs
for (int idx = 0; (idx < nbAlgoIds) && (AlgoCount < AlgoCombinations); idx++) { // for (int idx = 0; (idx < nbAlgoIds) && (AlgoCount < AlgoCombinations); idx++) {
cublasLtMatmulAlgo_t algo; // cublasLtMatmulAlgo_t algo;
size_t sizeWritten = 0; // size_t sizeWritten = 0;
/* Initialize algo structure with given Algp ID */ // /* Initialize algo structure with given Algp ID */
status = // status =
cublasLtMatmulAlgoInit(ltHandle, computeType, scaleType, Atype, Btype, Ctype, Ctype, algoIdA[idx], &algo); // cublasLtMatmulAlgoInit(ltHandle, computeType, scaleType, Atype, Btype, Ctype, Ctype, algoIdA[idx], &algo);
if (status != CUBLAS_STATUS_SUCCESS) { // if (status != CUBLAS_STATUS_SUCCESS) {
continue; // continue;
} // }
// Query the tiles enums supported by that algo // // Query the tiles enums supported by that algo
cublasLtMatmulAlgoCapGetAttribute(&algo, CUBLASLT_ALGO_CAP_TILE_IDS, NULL, 0, &sizeWritten); // cublasLtMatmulAlgoCapGetAttribute(&algo, CUBLASLT_ALGO_CAP_TILE_IDS, NULL, 0, &sizeWritten);
int nbTiles = int(sizeWritten / sizeof(int)); // int nbTiles = int(sizeWritten / sizeof(int));
int* tileA = new int[nbTiles == 0 ? 1 : nbTiles]; // int* tileA = new int[nbTiles == 0 ? 1 : nbTiles];
if (nbTiles == 0) { // if (nbTiles == 0) {
tileA[0] = CUBLASLT_MATMUL_TILE_UNDEFINED; // tileA[0] = CUBLASLT_MATMUL_TILE_UNDEFINED;
nbTiles = 1; // nbTiles = 1;
} // }
#if (CUDART_VERSION >= 11000) // // #if (CUDART_VERSION >= 11000)
cublasLtMatmulAlgoCapGetAttribute(&algo, CUBLASLT_ALGO_CAP_STAGES_IDS, NULL, 0, &sizeWritten); // // cublasLtMatmulAlgoCapGetAttribute(&algo, CUBLASLT_ALGO_CAP_STAGES_IDS, NULL, 0, &sizeWritten);
int nbStages = int(sizeWritten / sizeof(int)); // // int nbStages = int(sizeWritten / sizeof(int));
std::vector<int> stagesA(nbStages == 0 ? 1 : nbStages); // // std::vector<int> stagesA(nbStages == 0 ? 1 : nbStages);
if (nbStages == 0) { // // if (nbStages == 0) {
stagesA[0] = CUBLASLT_MATMUL_STAGES_UNDEFINED; // // stagesA[0] = CUBLASLT_MATMUL_STAGES_UNDEFINED;
nbStages = 1; // // nbStages = 1;
} // // }
else { // // else {
cublasLtMatmulAlgoCapGetAttribute( // // cublasLtMatmulAlgoCapGetAttribute(
&algo, CUBLASLT_ALGO_CAP_STAGES_IDS, stagesA.data(), sizeof(int) * nbStages, &sizeWritten); // // &algo, CUBLASLT_ALGO_CAP_STAGES_IDS, stagesA.data(), sizeof(int) * nbStages, &sizeWritten);
} // // }
#endif // // #endif
int splitkSupport, redMask, swizzlingMax, customOptionMax; // int splitkSupport, redMask, swizzlingMax, customOptionMax;
// Retrieve Algo Capabilities attributes to be able to setup loop over the different combinations // // Retrieve Algo Capabilities attributes to be able to setup loop over the different combinations
cublasLtMatmulAlgoCapGetAttribute( // cublasLtMatmulAlgoCapGetAttribute(
&algo, CUBLASLT_ALGO_CAP_TILE_IDS, tileA, sizeof(int) * nbTiles, &sizeWritten); // &algo, CUBLASLT_ALGO_CAP_TILE_IDS, tileA, sizeof(int) * nbTiles, &sizeWritten);
cublasLtMatmulAlgoCapGetAttribute( // cublasLtMatmulAlgoCapGetAttribute(
&algo, CUBLASLT_ALGO_CAP_SPLITK_SUPPORT, &splitkSupport, sizeof(splitkSupport), &sizeWritten); // &algo, CUBLASLT_ALGO_CAP_SPLITK_SUPPORT, &splitkSupport, sizeof(splitkSupport), &sizeWritten);
cublasLtMatmulAlgoCapGetAttribute( // cublasLtMatmulAlgoCapGetAttribute(
&algo, CUBLASLT_ALGO_CAP_REDUCTION_SCHEME_MASK, &redMask, sizeof(redMask), &sizeWritten); // &algo, CUBLASLT_ALGO_CAP_REDUCTION_SCHEME_MASK, &redMask, sizeof(redMask), &sizeWritten);
cublasLtMatmulAlgoCapGetAttribute( // cublasLtMatmulAlgoCapGetAttribute(
&algo, CUBLASLT_ALGO_CAP_CTA_SWIZZLING_SUPPORT, &swizzlingMax, sizeof(swizzlingMax), &sizeWritten); // &algo, CUBLASLT_ALGO_CAP_CTA_SWIZZLING_SUPPORT, &swizzlingMax, sizeof(swizzlingMax), &sizeWritten);
cublasLtMatmulAlgoCapGetAttribute( // cublasLtMatmulAlgoCapGetAttribute(
&algo, CUBLASLT_ALGO_CAP_CUSTOM_OPTION_MAX, &customOptionMax, sizeof(customOptionMax), &sizeWritten); // &algo, CUBLASLT_ALGO_CAP_CUSTOM_OPTION_MAX, &customOptionMax, sizeof(customOptionMax), &sizeWritten);
/* Loop over the different tiles */ // /* Loop over the different tiles */
for (int tileIdx = 0; tileIdx < nbTiles; tileIdx++) { // for (int tileIdx = 0; tileIdx < nbTiles; tileIdx++) {
#if (CUDART_VERSION >= 11000) // // #if (CUDART_VERSION >= 11000)
/* Loop over different stages count */ // // /* Loop over different stages count */
for (int stagesIdx = 0; stagesIdx < nbStages; stagesIdx++) { // // for (int stagesIdx = 0; stagesIdx < nbStages; stagesIdx++) {
cublasLtMatmulAlgoConfigSetAttribute( // // cublasLtMatmulAlgoConfigSetAttribute(
&algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &stagesA[stagesIdx], sizeof(stagesA[stagesIdx])); // // &algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &stagesA[stagesIdx], sizeof(stagesA[stagesIdx]));
#endif // // #endif
/* Loop over the different custom option if any */ // /* Loop over the different custom option if any */
for (int customOption = 0; customOption <= customOptionMax; customOption++) { // for (int customOption = 0; customOption <= customOptionMax; customOption++) {
cublasLtMatmulAlgoConfigSetAttribute( // cublasLtMatmulAlgoConfigSetAttribute(
&algo, CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION, &customOption, sizeof(customOption)); // &algo, CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION, &customOption, sizeof(customOption));
/* Loop over the CTAs swizzling support */ // /* Loop over the CTAs swizzling support */
for (int k = 0; k <= swizzlingMax; k++) { // for (int k = 0; k <= swizzlingMax; k++) {
int splitK_trial = 0; // int splitK_trial = 0;
if (splitkSupport) { // if (splitkSupport) {
splitK_trial += sizeof(splitKSequenceA) / sizeof(splitKSequenceA[0]); // splitK_trial += sizeof(splitKSequenceA) / sizeof(splitKSequenceA[0]);
} // }
// Loop over the splitK value over a fixed sequence splitKSequenceA in addition to the case // // Loop over the splitK value over a fixed sequence splitKSequenceA in addition to the case
// where splitK is not enabled // // where splitK is not enabled
for (int l = 0; (l < (1 + splitK_trial)) && (AlgoCount < AlgoCombinations); l++) { // for (int l = 0; (l < (1 + splitK_trial)) && (AlgoCount < AlgoCombinations); l++) {
/* Setup attribute of the algo to run */ // /* Setup attribute of the algo to run */
cublasLtMatmulAlgoConfigSetAttribute( // cublasLtMatmulAlgoConfigSetAttribute(
&algo, CUBLASLT_ALGO_CONFIG_TILE_ID, &tileA[tileIdx], sizeof(tileA[tileIdx])); // &algo, CUBLASLT_ALGO_CONFIG_TILE_ID, &tileA[tileIdx], sizeof(tileA[tileIdx]));
int splitK_val = 0; // int splitK_val = 0;
int redScheme = CUBLASLT_REDUCTION_SCHEME_NONE; // int redScheme = CUBLASLT_REDUCTION_SCHEME_NONE;
cublasLtMatmulAlgoConfigSetAttribute( // cublasLtMatmulAlgoConfigSetAttribute(
&algo, CUBLASLT_ALGO_CONFIG_SPLITK_NUM, &splitK_val, sizeof(splitK_val)); // &algo, CUBLASLT_ALGO_CONFIG_SPLITK_NUM, &splitK_val, sizeof(splitK_val));
cublasLtMatmulAlgoConfigSetAttribute( // cublasLtMatmulAlgoConfigSetAttribute(
&algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &k, sizeof(k)); // &algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &k, sizeof(k));
cublasLtMatmulAlgoConfigSetAttribute( // cublasLtMatmulAlgoConfigSetAttribute(
&algo, CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, &redScheme, sizeof(int)); // &algo, CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, &redScheme, sizeof(int));
if (l > 0) { // Split-K case // if (l > 0) { // Split-K case
splitK_val = splitKSequenceA[l - 1]; // splitK_val = splitKSequenceA[l - 1];
cublasLtMatmulAlgoConfigSetAttribute(&algo, // cublasLtMatmulAlgoConfigSetAttribute(&algo,
CUBLASLT_ALGO_CONFIG_SPLITK_NUM, // CUBLASLT_ALGO_CONFIG_SPLITK_NUM,
&splitKSequenceA[l - 1], // &splitKSequenceA[l - 1],
sizeof(splitKSequenceA[l - 1])); // sizeof(splitKSequenceA[l - 1]));
/* Going over all the reduction scheme */ // /* Going over all the reduction scheme */
for (redScheme = 1; // for (redScheme = 1;
redScheme <= (int)CUBLASLT_REDUCTION_SCHEME_MASK && (AlgoCount < AlgoCombinations); // redScheme <= (int)CUBLASLT_REDUCTION_SCHEME_MASK && (AlgoCount < AlgoCombinations);
redScheme = redScheme << 1) { // redScheme = redScheme << 1) {
if (redScheme & redMask) { // if (redScheme & redMask) {
cublasLtMatmulAlgoConfigSetAttribute(&algo, // cublasLtMatmulAlgoConfigSetAttribute(&algo,
CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, // CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME,
&redScheme, // &redScheme,
sizeof(redScheme)); // sizeof(redScheme));
status = customMatmulRun(ltHandle, // status = customMatmulRun(ltHandle,
operationDesc, // operationDesc,
alpha, /* host or device pointer */ // alpha, /* host or device pointer */
A, // A,
Adesc, // Adesc,
B, // B,
Bdesc, // Bdesc,
beta, /* host or device pointer */ // beta, /* host or device pointer */
C, // C,
Cdesc, // Cdesc,
C, // C,
Cdesc, // Cdesc,
algo, // algo,
kernelRepeats, // kernelRepeats,
workSpace, // workSpace,
workSpaceSize, // workSpaceSize,
perfResults[AlgoCount], // perfResults[AlgoCount],
stream); // stream);
perfResults[AlgoCount].status = status; // perfResults[AlgoCount].status = status;
if (status == CUBLAS_STATUS_SUCCESS) { // if (status == CUBLAS_STATUS_SUCCESS) {
AlgoCount++; // AlgoCount++;
} // }
} // end if // } // end if
} // end for // } // end for
} // }
else { // Non-splitK case // else { // Non-splitK case
/* if user preference is ok with workspace */ // /* if user preference is ok with workspace */
if (AlgoCount < AlgoCombinations) { // if (AlgoCount < AlgoCombinations) {
status = customMatmulRun(ltHandle, // status = customMatmulRun(ltHandle,
operationDesc, // operationDesc,
alpha, /* host or device pointer */ // alpha, /* host or device pointer */
A, // A,
Adesc, // Adesc,
B, // B,
Bdesc, // Bdesc,
beta, /* host or device pointer */ // beta, /* host or device pointer */
C, // C,
Cdesc, // Cdesc,
C, // C,
Cdesc, // Cdesc,
algo, // algo,
kernelRepeats, // kernelRepeats,
workSpace, // workSpace,
workSpaceSize, // workSpaceSize,
perfResults[AlgoCount], // perfResults[AlgoCount],
stream); // stream);
perfResults[AlgoCount].status = status; // perfResults[AlgoCount].status = status;
if (status == CUBLAS_STATUS_SUCCESS) { // if (status == CUBLAS_STATUS_SUCCESS) {
AlgoCount++; // AlgoCount++;
} // }
} // }
} // }
} // end l // } // end l
} // end k // } // end k
} // end customOption // } // end customOption
#if (CUDART_VERSION >= 11000) // // #if (CUDART_VERSION >= 11000)
} // end stagesIdx // // } // end stagesIdx
#endif // // #endif
} // end tileIdx // } // end tileIdx
delete[] tileA; // delete[] tileA;
} // end idx // } // end idx
// Sort the results per run duration // // Sort the results per run duration
std::sort(perfResults, perfResults + AlgoCount, time_compare); // std::sort(perfResults, perfResults + AlgoCount, time_compare);
// Print timing and perf details // // Print timing and perf details
for (int i = 0, hasPrint = 0; i < AlgoCount; i++) { // for (int i = 0, hasPrint = 0; i < AlgoCount; i++) {
printf("result %03d : ", i); // printf("result %03d : ", i);
hasPrint = printBatchPerfStructure(batchCount, m, n, k, perfResults[i], fout, hasPrint); // hasPrint = printBatchPerfStructure(batchCount, m, n, k, perfResults[i], fout, hasPrint);
} // }
CLEANUP: // CLEANUP:
// Descriptors are no longer needed as all GPU work was already enqueued // // Descriptors are no longer needed as all GPU work was already enqueued
if (Cdesc) { // if (Cdesc) {
cublasLtMatrixLayoutDestroy(Cdesc); // cublasLtMatrixLayoutDestroy(Cdesc);
} // }
if (Bdesc) { // if (Bdesc) {
cublasLtMatrixLayoutDestroy(Bdesc); // cublasLtMatrixLayoutDestroy(Bdesc);
} // }
if (Adesc) { // if (Adesc) {
cublasLtMatrixLayoutDestroy(Adesc); // cublasLtMatrixLayoutDestroy(Adesc);
} // }
if (operationDesc) { // if (operationDesc) {
cublasLtMatmulDescDestroy(operationDesc); // cublasLtMatmulDescDestroy(operationDesc);
} // }
return status == CUBLAS_STATUS_SUCCESS ? 0 : 1; // return status == CUBLAS_STATUS_SUCCESS ? 0 : 1;
} // }
template int LtBatchIgemmCustomFind(cublasLtHandle_t ltHandle, // template int LtBatchIgemmCustomFind(cublasLtHandle_t ltHandle,
int batchCount, // int batchCount,
int m, // int m,
int n, // int n,
int k, // int k,
const int* alpha, /* host pointer */ // const int* alpha, /* host pointer */
const int8_t* A, // const int8_t* A,
const int8_t* B, // const int8_t* B,
const int* beta, /* host pointer */ // const int* beta, /* host pointer */
int32_t* C, // int32_t* C,
void* workSpace, // void* workSpace,
size_t workSpaceSize, // size_t workSpaceSize,
FILE* fout); // FILE* fout);
template int LtBatchIgemmCustomFind(cublasLtHandle_t ltHandle, // template int LtBatchIgemmCustomFind(cublasLtHandle_t ltHandle,
int batchCount, // int batchCount,
int m, // int m,
int n, // int n,
int k, // int k,
const float* alpha, /* host pointer */ // const float* alpha, /* host pointer */
const int8_t* A, // const int8_t* A,
const int8_t* B, // const int8_t* B,
const float* beta, /* host pointer */ // const float* beta, /* host pointer */
int8_t* C, // int8_t* C,
void* workSpace, // void* workSpace,
size_t workSpaceSize, // size_t workSpaceSize,
FILE* fout); // FILE* fout);
// initialize matrix in column-major // initialize matrix in column-major
void matInit(int rows, int cols, int8_t* p, int ld) void matInit(int rows, int cols, int8_t* p, int ld)
......
...@@ -52,11 +52,11 @@ int printPerfStructure(int batch_size, ...@@ -52,11 +52,11 @@ int printPerfStructure(int batch_size,
matmulAlgo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &swizzle, sizeof(swizzle), NULL); matmulAlgo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &swizzle, sizeof(swizzle), NULL);
cublasLtMatmulAlgoConfigGetAttribute( cublasLtMatmulAlgoConfigGetAttribute(
matmulAlgo, CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION, &customOption, sizeof(customOption), NULL); matmulAlgo, CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION, &customOption, sizeof(customOption), NULL);
#if (CUDART_VERSION >= 11000) // #if (CUDART_VERSION >= 11000)
cublasLtMatmulAlgoConfigGetAttribute(matmulAlgo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &stages, sizeof(stages), NULL); // cublasLtMatmulAlgoConfigGetAttribute(matmulAlgo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &stages, sizeof(stages), NULL);
#else // #else
stages = 0; stages = 0;
#endif // #endif
#if (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH >= 3) #if (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH >= 3)
uint16_t inner_shapeId, cluster_shapeId; uint16_t inner_shapeId, cluster_shapeId;
cublasLtMatmulAlgoConfigGetAttribute( cublasLtMatmulAlgoConfigGetAttribute(
...@@ -74,9 +74,9 @@ int printPerfStructure(int batch_size, ...@@ -74,9 +74,9 @@ int printPerfStructure(int batch_size,
#endif #endif
printf("algo={ Id=%d, tileIdx=%d (%s) splitK=%d reduc=%d swizzle=%d custom=%d " printf("algo={ Id=%d, tileIdx=%d (%s) splitK=%d reduc=%d swizzle=%d custom=%d "
#if (CUDART_VERSION >= 11000) // #if (CUDART_VERSION >= 11000)
"stages=%d " // "stages=%d "
#endif // #endif
#if (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH >= 3) #if (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH >= 3)
"inner_shapeId=%d cluster_shapeId=%d" "inner_shapeId=%d cluster_shapeId=%d"
#elif (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH < 3) #elif (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH < 3)
...@@ -91,9 +91,9 @@ int printPerfStructure(int batch_size, ...@@ -91,9 +91,9 @@ int printPerfStructure(int batch_size,
reductionScheme, reductionScheme,
swizzle, swizzle,
customOption, customOption,
#if (CUDART_VERSION >= 11000) // #if (CUDART_VERSION >= 11000)
stages, // stages,
#endif // #endif
#if (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH >= 3) #if (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH >= 3)
inner_shapeId, inner_shapeId,
cluster_shapeId, cluster_shapeId,
...@@ -154,704 +154,704 @@ static inline bool time_compare(const customMatmulPerf_t& perf_a, const customMa ...@@ -154,704 +154,704 @@ static inline bool time_compare(const customMatmulPerf_t& perf_a, const customMa
return ((perf_a.status == CUBLAS_STATUS_SUCCESS) && (perf_a.time < perf_b.time)); return ((perf_a.status == CUBLAS_STATUS_SUCCESS) && (perf_a.time < perf_b.time));
} }
static cublasStatus_t customMatmulRun(cublasLtHandle_t ltHandle, // to get the capabilities (required a GPU) // static cublasStatus_t customMatmulRun(cublasLtHandle_t ltHandle, // to get the capabilities (required a GPU)
cublasLtMatmulDesc_t operationDesc, // cublasLtMatmulDesc_t operationDesc,
const void* alpha, /* host or device pointer */ // const void* alpha, /* host or device pointer */
const void* A, // const void* A,
cublasLtMatrixLayout_t Adesc, // cublasLtMatrixLayout_t Adesc,
const void* B, // const void* B,
cublasLtMatrixLayout_t Bdesc, // cublasLtMatrixLayout_t Bdesc,
const void* beta, /* host or device pointer */ // const void* beta, /* host or device pointer */
const void* C, // const void* C,
cublasLtMatrixLayout_t Cdesc, // cublasLtMatrixLayout_t Cdesc,
void* D, // void* D,
cublasLtMatrixLayout_t Ddesc, // cublasLtMatrixLayout_t Ddesc,
const cublasLtMatmulAlgo_t& algo, // const cublasLtMatmulAlgo_t& algo,
int kernelRepeats, // int kernelRepeats,
void* workSpace, // void* workSpace,
size_t workSpaceSizeInBytes, // size_t workSpaceSizeInBytes,
customMatmulPerf_t& perfResults, // customMatmulPerf_t& perfResults,
cudaStream_t stream, // cudaStream_t stream,
cudaEvent_t& startEvent, // cudaEvent_t& startEvent,
cudaEvent_t& stopEvent) // cudaEvent_t& stopEvent)
{ // {
cublasLtMatmulHeuristicResult_t heurResult; // cublasLtMatmulHeuristicResult_t heurResult;
/* Looping over the Algo */ // /* Looping over the Algo */
int repeats = kernelRepeats; // int repeats = kernelRepeats;
cublasStatus_t algoStatus = // cublasStatus_t algoStatus =
cublasLtMatmulAlgoCheck(ltHandle, operationDesc, Adesc, Bdesc, Cdesc, Ddesc, &algo, &heurResult); // cublasLtMatmulAlgoCheck(ltHandle, operationDesc, Adesc, Bdesc, Cdesc, Ddesc, &algo, &heurResult);
if (algoStatus == CUBLAS_STATUS_SUCCESS) { // if (algoStatus == CUBLAS_STATUS_SUCCESS) {
if (heurResult.workspaceSize <= workSpaceSizeInBytes) { // if (heurResult.workspaceSize <= workSpaceSizeInBytes) {
cudaError_t err, err1, err2, err3; // cudaError_t err, err1, err2, err3;
err = cudaEventRecord(startEvent, stream); // err = cudaEventRecord(startEvent, stream);
for (int loop = 0; loop < repeats; loop++) { // for (int loop = 0; loop < repeats; loop++) {
cublasStatus_t oneRunStatus = cublasLtMatmul(ltHandle, // cublasStatus_t oneRunStatus = cublasLtMatmul(ltHandle,
operationDesc, // operationDesc,
alpha, // alpha,
A, // A,
Adesc, // Adesc,
B, // B,
Bdesc, // Bdesc,
beta, // beta,
C, // C,
Cdesc, // Cdesc,
D, // D,
Ddesc, // Ddesc,
&algo, // &algo,
workSpace, // workSpace,
workSpaceSizeInBytes, // workSpaceSizeInBytes,
stream); // stream);
if (oneRunStatus != CUBLAS_STATUS_SUCCESS) { // if (oneRunStatus != CUBLAS_STATUS_SUCCESS) {
algoStatus = oneRunStatus; // algoStatus = oneRunStatus;
break; // break;
} // }
} // }
err1 = cudaEventRecord(stopEvent, stream); // err1 = cudaEventRecord(stopEvent, stream);
err2 = cudaEventSynchronize(stopEvent); // err2 = cudaEventSynchronize(stopEvent);
float time; // float time;
err3 = cudaEventElapsedTime(&time, startEvent, stopEvent); // err3 = cudaEventElapsedTime(&time, startEvent, stopEvent);
if ((err != cudaSuccess) || (err1 != cudaSuccess) || (err2 != cudaSuccess) || (err3 != cudaSuccess)) { // if ((err != cudaSuccess) || (err1 != cudaSuccess) || (err2 != cudaSuccess) || (err3 != cudaSuccess)) {
algoStatus = CUBLAS_STATUS_INTERNAL_ERROR; // algoStatus = CUBLAS_STATUS_INTERNAL_ERROR;
} // }
// For the moment only add successful findings // // For the moment only add successful findings
if (algoStatus == CUBLAS_STATUS_SUCCESS) { // if (algoStatus == CUBLAS_STATUS_SUCCESS) {
perfResults.algo = algo; // perfResults.algo = algo;
perfResults.time = time / repeats; // perfResults.time = time / repeats;
perfResults.workspaceSize = heurResult.workspaceSize; // perfResults.workspaceSize = heurResult.workspaceSize;
perfResults.wavesCount = heurResult.wavesCount; // perfResults.wavesCount = heurResult.wavesCount;
} // }
} // }
else { // else {
// printf("not enough workspace! %ld\n", heurResult.workspaceSize); // // printf("not enough workspace! %ld\n", heurResult.workspaceSize);
algoStatus = CUBLAS_STATUS_NOT_SUPPORTED; // Not enough workspace // algoStatus = CUBLAS_STATUS_NOT_SUPPORTED; // Not enough workspace
} // }
} // }
return algoStatus; // return algoStatus;
} // }
/**
 * Exhaustive cublasLt algorithm search for the GEMM C = alpha * op(A) * B + beta * C.
 *
 * Enumerates cublasLt matmul algorithm ids and, for each, sweeps the supported
 * tiles, pipeline stages, custom options, CTA swizzling and split-K/reduction
 * configurations. Every candidate that passes cublasLtMatmulAlgoCheck within
 * the workspace budget is timed via customMatmulRun (CUDA events, averaged over
 * kernelRepeats launches); results are sorted by runtime and printed/written
 * through printPerfStructure.
 *
 * @param ltHandle         cublasLt library handle (owned by the caller).
 * @param batch_size..size_per_head  shape metadata, only forwarded to printPerfStructure.
 * @param m, n, k          GEMM problem dimensions (column-major, cuBLAS convention).
 * @param alpha, beta      host pointers to scaling factors of type scaleT.
 * @param A, B             device input matrices; C is both the accumulator input and output.
 * @param workSpace        device scratch buffer of workSpaceSize bytes offered to algorithms.
 * @param fout             optional file handle that printPerfStructure appends results to.
 * @param perfResults      caller-provided array receiving up to AlgoCombinations timed entries.
 * @param AlgoCombinations capacity of perfResults / maximum candidate count.
 * @param dtype_fp8        output dtype used for D when T is __nv_fp8_e4m3
 *                         (unless FP8_GEMM_OUTPUT_QUANT_DISABLE forces BF16).
 * @param batchCount, strideA/B/D  strided-batched layout parameters (batchCount > 1 enables them).
 * @return 0 when the final cublasLt status is CUBLAS_STATUS_SUCCESS, 1 otherwise.
 */
template<typename T, typename scaleT>
int LtHgemmCustomFind(cublasLtHandle_t ltHandle,
                      int batch_size,
                      int seq_len,
                      int head_num,
                      int size_per_head,
                      int m,
                      int n,
                      int k,
                      const scaleT* alpha, /* host pointer */
                      const T* A,
                      const T* B,
                      const scaleT* beta, /* host pointer */
                      T* C,
                      void* workSpace,
                      size_t workSpaceSize,
                      FILE* fout,
                      customMatmulPerf_t perfResults[],
                      int AlgoCombinations,
                      cudaDataType_t dtype_fp8,
                      int batchCount,
                      int64_t strideA,
                      int64_t strideB,
                      int64_t strideD)
{
    cublasStatus_t status = CUBLAS_STATUS_SUCCESS;
    // NOTE(review): startEvent/stopEvent are left uninitialized here. If any
    // `goto CLEANUP` fires before cudaEventCreate below, the `if (startEvent)`
    // checks in CLEANUP read indeterminate values (UB). Consider initializing
    // both to NULL.
    cudaEvent_t startEvent;
    cudaEvent_t stopEvent;
    CublasDataType data_type;
    cublasLtMatmulDesc_t operationDesc = NULL;
    cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL, Ddesc = NULL;
    cudaStream_t stream = 0;
    // Split-K factors to try when an algorithm reports split-K support.
    const int splitKSequenceA[] = {2, 3, 4, 5, 6, 8, 12, 16, 32};
    // Candidate bookkeeping.
    int AlgoCount = 0;
    int AlgoCountRestrict = 0;  // candidates that require workspace == 0
    const int maxNumTraversal = 50;  // max number of candidates actually timed
    std::vector<cublasLtMatmulAlgo_t> algos(AlgoCombinations);  // 0 <= workspace <= 32MB
    std::vector<cublasLtMatmulAlgo_t> algosRestrict(AlgoCombinations);  // workspace == 0
    const int kernelRepeats = 100;  // number of back-to-back launches per timing
    int nbAlgoIds = 0;  // number of algorithm ids actually returned by
                        // cublasLtMatmulAlgoGetIds
#define ALGO_IDS 100    // number of algorithm ids requested
    int algoIdA[ALGO_IDS];  // algorithm ids returned by cublasLtMatmulAlgoGetIds
    cudaDataType_t Atype, Btype, Ctype, scaleType, Dtype;
#if (CUDART_VERSION >= 11000)
    cublasComputeType_t computeType;
#else
    cudaDataType_t computeType;
#endif
    // Map the template type T onto cublasLt data-type enums.
    if (std::is_same<T, float>::value) {
        data_type = FLOAT_DATATYPE;
        Atype = CUDA_R_32F, Btype = CUDA_R_32F, Ctype = CUDA_R_32F, Dtype = CUDA_R_32F;
    }
    else if (std::is_same<T, half>::value) {
        data_type = HALF_DATATYPE;
        Atype = CUDA_R_16F, Btype = CUDA_R_16F, Ctype = CUDA_R_16F, Dtype = CUDA_R_16F;
    }
#ifdef ENABLE_BF16
    else if (std::is_same<T, __nv_bfloat16>::value) {
        data_type = BFLOAT16_DATATYPE;
        Atype = CUDA_R_16BF, Btype = CUDA_R_16BF, Ctype = CUDA_R_16BF, Dtype = CUDA_R_16BF;
    }
#endif
#ifdef ENABLE_FP8
    else if (std::is_same<T, __nv_fp8_e4m3>::value) {
        // FP8 inputs accumulate into BF16 C; D's dtype is configurable.
        data_type = FP8_DATATYPE;
        Atype = CUDA_R_8F_E4M3, Btype = CUDA_R_8F_E4M3, Ctype = CUDA_R_16BF;
#ifdef FP8_GEMM_OUTPUT_QUANT_DISABLE
        Dtype = CUDA_R_16BF;
#else
        Dtype = dtype_fp8;
#endif
    }
#endif
    // Scale factors decide the compute precision (FP32 vs FP16 accumulation).
    if (sizeof(scaleT) == sizeof(float)) {
        scaleType = CUDA_R_32F;
#if (CUDART_VERSION >= 11000)
        computeType = CUBLAS_COMPUTE_32F;
#else
        computeType = CUDA_R_32F;
#endif
    }
    else {
        scaleType = CUDA_R_16F;
#if (CUDART_VERSION >= 11000)
        computeType = CUBLAS_COMPUTE_16F;
#else
        computeType = CUDA_R_16F;
#endif
    }
    // FP8 GEMM requires A transposed (TN layout); all other types use NN.
    const cublasOperation_t tA = data_type == FP8_DATATYPE ? CUBLAS_OP_T : CUBLAS_OP_N;
    // Create the matmul descriptor; defaults are fine except op(A).
#if (CUDART_VERSION >= 11000)
    status = cublasLtMatmulDescCreate(&operationDesc, computeType,
                                      scaleType);  // creates a matrix multiply descriptor
#else
    status = cublasLtMatmulDescCreate(&operationDesc, computeType);
#endif
    if (status != CUBLAS_STATUS_SUCCESS) {
        goto CLEANUP;
    }
    status = cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &tA, sizeof(tA));
    if (status != CUBLAS_STATUS_SUCCESS) {
        goto CLEANUP;
    }
#ifdef ENABLE_FP8
    if (data_type == FP8_DATATYPE) {
        const int8_t fastAccuMode = 1;  // enable fast imprecise accum
        status = cublasLtMatmulDescSetAttribute(
            operationDesc, CUBLASLT_MATMUL_DESC_FAST_ACCUM, &fastAccuMode, sizeof(decltype(fastAccuMode)));
        if (status != CUBLAS_STATUS_SUCCESS) {
            goto CLEANUP;
        }
    }
#endif
    // Matrix layouts (column-major). FP8's A is k x m because op(A) = T above.
    if (data_type == FP8_DATATYPE) {
        status = cublasLtMatrixLayoutCreate(&Adesc, Atype, k, m, k);
    }
    else {
        status = cublasLtMatrixLayoutCreate(&Adesc, Atype, m, k, m);
    }
    if (status != CUBLAS_STATUS_SUCCESS) {
        goto CLEANUP;
    }
    status = cublasLtMatrixLayoutCreate(&Bdesc, Btype, k, n, k);
    if (status != CUBLAS_STATUS_SUCCESS) {
        goto CLEANUP;
    }
    status = cublasLtMatrixLayoutCreate(&Cdesc, Ctype, m, n, m);
    if (status != CUBLAS_STATUS_SUCCESS) {
        goto CLEANUP;
    }
    status = cublasLtMatrixLayoutCreate(&Ddesc, Dtype, m, n, m);
    if (status != CUBLAS_STATUS_SUCCESS) {
        goto CLEANUP;
    }
    // Strided-batched mode: attach batch count and per-matrix strides.
    if (batchCount > 1) {
        check_cuda_error(cublasLtMatrixLayoutSetAttribute(
            Adesc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batchCount, sizeof(batchCount)));
        check_cuda_error(cublasLtMatrixLayoutSetAttribute(
            Bdesc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batchCount, sizeof(batchCount)));
        check_cuda_error(cublasLtMatrixLayoutSetAttribute(
            Cdesc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batchCount, sizeof(batchCount)));
        check_cuda_error(cublasLtMatrixLayoutSetAttribute(
            Ddesc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batchCount, sizeof(batchCount)));
        check_cuda_error(cublasLtMatrixLayoutSetAttribute(
            Adesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &strideA, sizeof(strideA)));
        check_cuda_error(cublasLtMatrixLayoutSetAttribute(
            Bdesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &strideB, sizeof(strideB)));
        // NOTE(review): C receives strideD, presumably because C doubles as D in
        // the timing runs below -- confirm against callers that pass strideD.
        check_cuda_error(cublasLtMatrixLayoutSetAttribute(
            Cdesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &strideD, sizeof(strideD)));
        check_cuda_error(cublasLtMatrixLayoutSetAttribute(
            Ddesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &strideD, sizeof(strideD)));
    }
    // CUDA events used to time each candidate algorithm.
    if (cudaEventCreate(&startEvent, cudaEventBlockingSync) != cudaSuccess) {
        goto CLEANUP;
    }
    if (cudaEventCreate(&stopEvent, cudaEventBlockingSync) != cudaSuccess) {
        goto CLEANUP;
    }
    // Request the first ALGO_IDS algorithm ids available for this type combo.
    status = cublasLtMatmulAlgoGetIds(
        ltHandle, computeType, scaleType, Atype, Btype, Ctype, Dtype, ALGO_IDS, algoIdA, &nbAlgoIds);
    if (status != CUBLAS_STATUS_SUCCESS) {
        goto CLEANUP;
    }
    if (nbAlgoIds > ALGO_IDS) {
        printf(
            "Warning: the algo id count is not large enough to guarantee the best algo %d, %d\n", nbAlgoIds, ALGO_IDS);
    }
    // Enumerate every configuration of every algorithm id (capped by
    // AlgoCombinations). This loop doesn't work for fp8 gemm.
    for (int idx = 0; (idx < nbAlgoIds) && (AlgoCount < AlgoCombinations); idx++) {
        cublasLtMatmulAlgo_t algo;
        size_t sizeWritten = 0;
        /* Initialize algo structure with the given algo id */
        status =
            cublasLtMatmulAlgoInit(ltHandle, computeType, scaleType, Atype, Btype, Ctype, Dtype, algoIdA[idx], &algo);
        if (status != CUBLAS_STATUS_SUCCESS) {
            continue;
        }
        // Query the tile ids supported by that algo (first call sizes the array).
        cublasLtMatmulAlgoCapGetAttribute(&algo, CUBLASLT_ALGO_CAP_TILE_IDS, NULL, 0, &sizeWritten);
        int nbTiles = int(sizeWritten / sizeof(int));
        int* tileA = new int[nbTiles == 0 ? 1 : nbTiles];
        if (nbTiles == 0) {
            tileA[0] = CUBLASLT_MATMUL_TILE_UNDEFINED;
            nbTiles = 1;
        }
#if (CUDART_VERSION >= 11000)
        // Query the pipeline-stage configurations supported by that algo.
        cublasLtMatmulAlgoCapGetAttribute(&algo, CUBLASLT_ALGO_CAP_STAGES_IDS, NULL, 0, &sizeWritten);
        int nbStages = int(sizeWritten / sizeof(int));
        std::vector<int> stagesA(nbStages == 0 ? 1 : nbStages);
        if (nbStages == 0) {
            stagesA[0] = CUBLASLT_MATMUL_STAGES_UNDEFINED;
            nbStages = 1;
        }
        else {
            cublasLtMatmulAlgoCapGetAttribute(
                &algo, CUBLASLT_ALGO_CAP_STAGES_IDS, stagesA.data(), sizeof(int) * nbStages, &sizeWritten);
        }
#endif
        int splitkSupport, redMask, swizzlingMax, customOptionMax;
        // Retrieve the capability attributes that bound the sweep below.
        cublasLtMatmulAlgoCapGetAttribute(
            &algo, CUBLASLT_ALGO_CAP_TILE_IDS, tileA, sizeof(int) * nbTiles, &sizeWritten);
        cublasLtMatmulAlgoCapGetAttribute(
            &algo, CUBLASLT_ALGO_CAP_SPLITK_SUPPORT, &splitkSupport, sizeof(splitkSupport), &sizeWritten);
        cublasLtMatmulAlgoCapGetAttribute(
            &algo, CUBLASLT_ALGO_CAP_REDUCTION_SCHEME_MASK, &redMask, sizeof(redMask), &sizeWritten);
        cublasLtMatmulAlgoCapGetAttribute(
            &algo, CUBLASLT_ALGO_CAP_CTA_SWIZZLING_SUPPORT, &swizzlingMax, sizeof(swizzlingMax), &sizeWritten);
        cublasLtMatmulAlgoCapGetAttribute(
            &algo, CUBLASLT_ALGO_CAP_CUSTOM_OPTION_MAX, &customOptionMax, sizeof(customOptionMax), &sizeWritten);
        /* Loop over the different tiles */
        for (int tileIdx = 0; tileIdx < nbTiles; tileIdx++) {
#if (CUDART_VERSION >= 11000)
            /* Loop over different stages count */
            for (int stagesIdx = 0; stagesIdx < nbStages; stagesIdx++) {
                cublasLtMatmulAlgoConfigSetAttribute(
                    &algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &stagesA[stagesIdx], sizeof(stagesA[stagesIdx]));
#endif
                /* Loop over the different custom option if any */
                for (int customOption = 0; customOption <= customOptionMax; customOption++) {
                    cublasLtMatmulAlgoConfigSetAttribute(
                        &algo, CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION, &customOption, sizeof(customOption));
                    /* Loop over the CTAs swizzling support */
                    // NOTE(review): this `k` shadows the GEMM dimension `k`
                    // from the enclosing scope -- intentional upstream, but
                    // easy to misread.
                    for (int k = 0; k <= swizzlingMax; k++) {
                        int splitK_trial = 0;
                        if (splitkSupport) {
                            splitK_trial += sizeof(splitKSequenceA) / sizeof(splitKSequenceA[0]);
                        }
                        // l == 0 is the non-split-K baseline; l > 0 indexes
                        // into splitKSequenceA.
                        for (int l = 0; (l < (1 + splitK_trial)) && (AlgoCount < AlgoCombinations); l++) {
                            /* Setup attributes of the algo to run */
                            cublasLtMatmulAlgoConfigSetAttribute(
                                &algo, CUBLASLT_ALGO_CONFIG_TILE_ID, &tileA[tileIdx], sizeof(tileA[tileIdx]));
                            int splitK_val = 0;
                            int redScheme = CUBLASLT_REDUCTION_SCHEME_NONE;
                            cublasLtMatmulAlgoConfigSetAttribute(
                                &algo, CUBLASLT_ALGO_CONFIG_SPLITK_NUM, &splitK_val, sizeof(splitK_val));
                            cublasLtMatmulAlgoConfigSetAttribute(
                                &algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &k, sizeof(k));
                            cublasLtMatmulAlgoConfigSetAttribute(
                                &algo, CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, &redScheme, sizeof(int));
                            if (l > 0) {  // Split-K case
                                splitK_val = splitKSequenceA[l - 1];
                                cublasLtMatmulAlgoConfigSetAttribute(&algo,
                                                                     CUBLASLT_ALGO_CONFIG_SPLITK_NUM,
                                                                     &splitKSequenceA[l - 1],
                                                                     sizeof(splitKSequenceA[l - 1]));
                                /* Going over all the reduction schemes */
                                for (redScheme = 1;
                                     redScheme < (int)CUBLASLT_REDUCTION_SCHEME_MASK && (AlgoCount < AlgoCombinations);
                                     redScheme = redScheme << 1) {
                                    if (redScheme & redMask) {
                                        cublasLtMatmulAlgoConfigSetAttribute(&algo,
                                                                             CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME,
                                                                             &redScheme,
                                                                             sizeof(redScheme));
                                        cublasLtMatmulHeuristicResult_t heurResult;
                                        // NOTE(review): Cdesc is passed for the
                                        // Ddesc parameter; this matches the
                                        // C==D timing runs below, but differs
                                        // from the Ddesc created above (they
                                        // diverge for FP8) -- verify.
                                        cublasStatus_t algoStatus = cublasLtMatmulAlgoCheck(
                                            ltHandle, operationDesc, Adesc, Bdesc, Cdesc, Cdesc, &algo, &heurResult);
                                        if (heurResult.workspaceSize > workSpaceSize) {
                                            algoStatus = CUBLAS_STATUS_NOT_SUPPORTED;  // Not enough workspace
                                        }
                                        else if (heurResult.workspaceSize == 0) {
                                            if (algoStatus == CUBLAS_STATUS_SUCCESS) {
                                                algosRestrict[AlgoCountRestrict++] = algo;
                                            }
                                        }
                                        if (algoStatus == CUBLAS_STATUS_SUCCESS) {
                                            algos[AlgoCount++] = algo;
                                        }
                                    }  // end if
                                }  // end for
                            }
                            else {  // Non-splitK case
                                /* if user preference is ok with workspace */
                                if (AlgoCount < AlgoCombinations) {
                                    cublasLtMatmulHeuristicResult_t heurResult;
                                    // NOTE(review): same Cdesc-for-Ddesc
                                    // substitution as in the split-K branch.
                                    cublasStatus_t algoStatus = cublasLtMatmulAlgoCheck(
                                        ltHandle, operationDesc, Adesc, Bdesc, Cdesc, Cdesc, &algo, &heurResult);
                                    if (heurResult.workspaceSize > workSpaceSize) {
                                        algoStatus = CUBLAS_STATUS_NOT_SUPPORTED;  // Not enough workspace
                                    }
                                    else if (heurResult.workspaceSize == 0) {
                                        if (algoStatus == CUBLAS_STATUS_SUCCESS) {
                                            algosRestrict[AlgoCountRestrict++] = algo;
                                        }
                                    }
                                    if (algoStatus == CUBLAS_STATUS_SUCCESS) {
                                        algos[AlgoCount++] = algo;
                                    }
                                }
                            }
                        }  // end l
                    }  // end k
                }  // end customOption
#if (CUDART_VERSION >= 11000)
            }  // end stagesIdx
#endif
        }  // end tileIdx
        delete[] tileA;
    }  // end idx
    printf("AlgoCount: %d\n", AlgoCount);
    // The enumeration above never yields FP8 candidates; FP8 relies solely on
    // the heuristic path below.
    if (data_type == FP8_DATATYPE) {
        assert(AlgoCount == 0);
    }
    if (AlgoCount < maxNumTraversal && data_type != FP8_DATATYPE) {
        // Few enough candidates: time every enumerated algo
        // (0 <= workspacesize <= 32MB). Note C is passed as both C and D.
        for (int i = 0; i < AlgoCount; i++) {
            status = customMatmulRun(ltHandle,
                                     operationDesc,
                                     alpha, /* host or device pointer */
                                     A,
                                     Adesc,
                                     B,
                                     Bdesc,
                                     beta, /* host or device pointer */
                                     C,
                                     Cdesc,
                                     C,
                                     Cdesc,
                                     algos[i],
                                     kernelRepeats,
                                     workSpace,
                                     workSpaceSize,
                                     perfResults[i],
                                     stream,
                                     startEvent,
                                     stopEvent);
            perfResults[i].status = status;
        }
    }
    else {
        // Too many candidates (or FP8): fall back to heuristic suggestions
        // plus the workspace-free candidates collected above.
        AlgoCount = 0;
        nbAlgoIds = 0;
        cublasLtMatmulPreference_t pref;
        cublasLtMatmulPreferenceCreate(&pref);
        uint64_t maxWorkSpaceSize = workSpaceSize;  //(32MB)
        cublasLtMatmulPreferenceSetAttribute(
            pref, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &maxWorkSpaceSize, sizeof(maxWorkSpaceSize));
        cublasLtMatmulHeuristicResult_t heuristicResultsArray[maxNumTraversal];
        cublasLtMatmulAlgoGetHeuristic(ltHandle,
                                       operationDesc,
                                       Adesc,
                                       Bdesc,
                                       Cdesc,
                                       Ddesc,
                                       pref,
                                       maxNumTraversal,
                                       heuristicResultsArray,
                                       &nbAlgoIds);
        cublasLtMatmulPreferenceDestroy(pref);
        printf("return %d and run heuristic algo\n", nbAlgoIds);
        for (int i = 0; i < nbAlgoIds; i++) {
            if (heuristicResultsArray[i].state == CUBLAS_STATUS_SUCCESS) {
                status = customMatmulRun(ltHandle,
                                         operationDesc,
                                         alpha, /* host or device pointer */
                                         A,
                                         Adesc,
                                         B,
                                         Bdesc,
                                         beta, /* host or device pointer */
                                         C,
                                         Cdesc,
                                         C,
                                         Ddesc,
                                         heuristicResultsArray[i].algo,
                                         kernelRepeats,
                                         workSpace,
                                         workSpaceSize,
                                         perfResults[AlgoCount],
                                         stream,
                                         startEvent,
                                         stopEvent);
                perfResults[AlgoCount].status = status;
                if (status == CUBLAS_STATUS_SUCCESS) {
                    AlgoCount++;
                }
            }
        }
        // Also time the enumerated candidates that need no workspace.
        printf("workspacesize==0, run %d algos\n", AlgoCountRestrict);
        for (int i = 0; i < AlgoCountRestrict && i < (maxNumTraversal - nbAlgoIds); i++) {
            status = customMatmulRun(ltHandle,
                                     operationDesc,
                                     alpha, /* host or device pointer */
                                     A,
                                     Adesc,
                                     B,
                                     Bdesc,
                                     beta, /* host or device pointer */
                                     C,
                                     Cdesc,
                                     C,
                                     Ddesc,
                                     algosRestrict[i],
                                     kernelRepeats,
                                     NULL,
                                     0,
                                     perfResults[AlgoCount],
                                     stream,
                                     startEvent,
                                     stopEvent);
            perfResults[AlgoCount].status = status;
            if (status == CUBLAS_STATUS_SUCCESS) {
                AlgoCount++;
            }
        }
    }
    // Sort the timed results by run duration (fastest first).
    std::sort(perfResults, perfResults + AlgoCount, time_compare);
    // Print timing and perf details; printPerfStructure's return value gates
    // the file output of subsequent entries.
    for (int i = 0, hasPrint = 1; i < AlgoCount; i++) {
        printf("result %03d : ", i);
        hasPrint = printPerfStructure(batch_size,
                                      seq_len,
                                      head_num,
                                      size_per_head,
                                      m,
                                      n,
                                      k,
                                      perfResults[i],
                                      fout,
                                      data_type,
                                      hasPrint,
                                      batchCount);
    }
CLEANUP:
    // Descriptors are no longer needed as all GPU work was already enqueued.
    // NOTE(review): Ddesc is created above but never destroyed here -- leak.
    if (Cdesc) {
        cublasLtMatrixLayoutDestroy(Cdesc);
    }
    if (Bdesc) {
        cublasLtMatrixLayoutDestroy(Bdesc);
    }
    if (Adesc) {
        cublasLtMatrixLayoutDestroy(Adesc);
    }
    if (operationDesc) {
        cublasLtMatmulDescDestroy(operationDesc);
    }
    if (startEvent) {
        cudaEventDestroy(startEvent);
    }
    if (stopEvent) {
        cudaEventDestroy(stopEvent);
    }
    return status == CUBLAS_STATUS_SUCCESS ? 0 : 1;
}
// Explicit instantiation of the cublasLt algorithm-search routine for
// all-float GEMM: float A/B/C matrices with float alpha/beta scaling.
// alpha/beta are host pointers (CUBLASLT_POINTER_MODE_HOST semantics).
template int LtHgemmCustomFind(cublasLtHandle_t   ltHandle,
                               int                batch_size,
                               int                seq_len,
                               int                head_num,
                               int                size_per_head,
                               int                m,
                               int                n,
                               int                k,
                               const float*       alpha, /* host pointer */
                               const float*       A,
                               const float*       B,
                               const float*       beta, /* host pointer */
                               float*             C,
                               void*              workSpace,
                               size_t             workSpaceSize,
                               FILE*              fout,
                               customMatmulPerf_t perfResults[],
                               int                AlgoCombinations,
                               cudaDataType_t     dtype_fp8,
                               int                batchCount,
                               int64_t            strideA,
                               int64_t            strideB,
                               int64_t            strideD);
// Explicit instantiation for all-half GEMM: half A/B/C matrices with
// half alpha/beta scaling factors (alpha/beta are host pointers).
template int LtHgemmCustomFind(cublasLtHandle_t   ltHandle,
                               int                batch_size,
                               int                seq_len,
                               int                head_num,
                               int                size_per_head,
                               int                m,
                               int                n,
                               int                k,
                               const half*        alpha, /* host pointer */
                               const half*        A,
                               const half*        B,
                               const half*        beta, /* host pointer */
                               half*              C,
                               void*              workSpace,
                               size_t             workSpaceSize,
                               FILE*              fout,
                               customMatmulPerf_t perfResults[],
                               int                AlgoCombinations,
                               cudaDataType_t     dtype_fp8,
                               int                batchCount,
                               int64_t            strideA,
                               int64_t            strideB,
                               int64_t            strideD);
#ifdef ENABLE_BF16
// Explicit instantiation for bfloat16 GEMM: __nv_bfloat16 A/B/C matrices
// with float alpha/beta scaling (host pointers). Compiled only when
// bfloat16 support is enabled at build time.
template int LtHgemmCustomFind(cublasLtHandle_t     ltHandle,
                               int                  batch_size,
                               int                  seq_len,
                               int                  head_num,
                               int                  size_per_head,
                               int                  m,
                               int                  n,
                               int                  k,
                               const float*         alpha, /* host pointer */
                               const __nv_bfloat16* A,
                               const __nv_bfloat16* B,
                               const float*         beta, /* host pointer */
                               __nv_bfloat16*       C,
                               void*                workSpace,
                               size_t               workSpaceSize,
                               FILE*                fout,
                               customMatmulPerf_t   perfResults[],
                               int                  AlgoCombinations,
                               cudaDataType_t       dtype_fp8,
                               int                  batchCount,
                               int64_t              strideA,
                               int64_t              strideB,
                               int64_t              strideD);
#endif  // ENABLE_BF16
#ifdef ENABLE_FP8
// Explicit instantiation for FP8 GEMM: __nv_fp8_e4m3 A/B/C matrices with
// float alpha/beta scaling (host pointers). dtype_fp8 selects the cuBLASLt
// data type used for the FP8 path. Compiled only when FP8 support is
// enabled at build time (requires SM89+ hardware at runtime).
template int LtHgemmCustomFind(cublasLtHandle_t     ltHandle,
                               int                  batch_size,
                               int                  seq_len,
                               int                  head_num,
                               int                  size_per_head,
                               int                  m,
                               int                  n,
                               int                  k,
                               const float*         alpha, /* host pointer */
                               const __nv_fp8_e4m3* A,
                               const __nv_fp8_e4m3* B,
                               const float*         beta, /* host pointer */
                               __nv_fp8_e4m3*       C,
                               void*                workSpace,
                               size_t               workSpaceSize,
                               FILE*                fout,
                               customMatmulPerf_t   perfResults[],
                               int                  AlgoCombinations,
                               cudaDataType_t       dtype_fp8,
                               int                  batchCount,
                               int64_t              strideA,
                               int64_t              strideB,
                               int64_t              strideD);
#endif  // ENABLE_FP8
// Explicit instantiation for mixed-precision GEMM: half A/B/C matrices
// with float alpha/beta scaling factors (host pointers) — i.e. FP16 data
// with FP32 compute-type scaling.
template int LtHgemmCustomFind(cublasLtHandle_t   ltHandle,
                               int                batch_size,
                               int                seq_len,
                               int                head_num,
                               int                size_per_head,
                               int                m,
                               int                n,
                               int                k,
                               const float*       alpha, /* host pointer */
                               const half*        A,
                               const half*        B,
                               const float*       beta, /* host pointer */
                               half*              C,
                               void*              workSpace,
                               size_t             workSpaceSize,
                               FILE*              fout,
                               customMatmulPerf_t perfResults[],
                               int                AlgoCombinations,
                               cudaDataType_t     dtype_fp8,
                               int                batchCount,
                               int64_t            strideA,
                               int64_t            strideB,
                               int64_t            strideD);
size_t calGemmTestBufSizeInByte(int batch_size, size_t calGemmTestBufSizeInByte(int batch_size,
int seq_len, int seq_len,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment