Commit ee33e2e7 authored by zhouxiang

support dtk23.10

parent e432dbb0
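This commit adapts the CUDA-only code paths to the DCU DTK 23.10 toolchain (abi0, torch 1.13): the `CUDART_VERSION` / `__CUDACC_VER_*` guards are commented out so the portable fallback branches are always compiled, the cublasLt, INT8 and FP8 wrapper targets are disabled in favor of the plain `cublasGemmEx` path, custom all-reduce falls back to NCCL, the `__dcu_version__` string is extended, and the `transformers` pin is bumped to 4.33.2.

As a quick sanity check of the expf-based `tanh_opt` fallback that every target now compiles (a host-side sketch written for this note, not code from the repository):

// Host-side check of the expf-based tanh fallback kept by this commit.
// The device version in the kernels uses __expf and copysignf_pos.
#include <cmath>
#include <cstdio>

static float copysignf_pos(float a, float b) { return std::copysign(a, b); }  // copy sign of b onto non-negative a

static float tanh_opt(float x)
{
    const float exp_val = -1.f * std::fabs(2 * x);
    return copysignf_pos((1.0f - std::exp(exp_val)) / (std::exp(exp_val) + 1.0f), x);
}

int main()
{
    for (float x : {-3.f, -0.5f, 0.f, 0.5f, 3.f}) {
        std::printf("x=%+.2f  tanh_opt=%+.6f  std::tanh=%+.6f\n", x, tanh_opt(x), std::tanh(x));
    }
    return 0;
}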
 # Copyright (c) OpenMMLab. All rights reserved.
 from typing import Tuple
-__dcu_version__ = '0.0.13'
+__dcu_version__ = '0.0.13+gite432dbb.abi0.dtk2310.torch1.13'
 __version__ = '0.0.13'
 short_version = __version__
...
@@ -12,6 +12,6 @@ setuptools
 shortuuid
 tiktoken
 torch
-transformers>=4.33.0
+transformers>=4.33.2
 tritonclient[all]
 uvicorn
@@ -37,14 +37,14 @@ __forceinline__ __device__ float copysignf_pos(float a, float b)
 __inline__ __device__ float tanh_opt(float x)
 {
-#if (__CUDA_ARCH__ >= 750 && CUDART_VERSION >= 11000)
-    float r;
-    asm("tanh.approx.f32 %0,%1; \n\t" : "=f"(r) : "f"(x));
-    return r;
-#else
+    // #if (__CUDA_ARCH__ >= 750 && CUDART_VERSION >= 11000)
+    // float r;
+    // asm("tanh.approx.f32 %0,%1; \n\t" : "=f"(r) : "f"(x));
+    // return r;
+    // #else
     const float exp_val = -1.f * fabs(2 * x);
     return copysignf_pos((1.0f - __expf(exp_val)) / (__expf(exp_val) + 1.0f), x);
-#endif
+    // #endif
 }
 template<typename T>
...
@@ -7,11 +7,11 @@
 namespace turbomind {
-#if (__CUDACC_VER_MAJOR__ >= 11) && (__CUDACC_VER_MINOR__ >= 4)
-#define L2_CACHEHINT(size) ".L2::" #size "B"
-#else
+// #if (__CUDACC_VER_MAJOR__ >= 11) && (__CUDACC_VER_MINOR__ >= 4)
+// #define L2_CACHEHINT(size) ".L2::" #size "B"
+// #else
 #define L2_CACHEHINT(size)
-#endif
+// #endif
 template<typename T>
 __inline__ __device__ void cp_async_cg_A(uint32_t smem_int_ptr, const T* __restrict__ src, bool mask)
...
@@ -61,12 +61,12 @@ __inline__ __device__ uint transpose_m8n8_b16_movmatrix(uint a)
 __inline__ __device__ uint transpose_m8n8_b16(uint a, int lane_id)
 {
-#if (__CUDACC_VER_MAJOR__ >= 11) && (__CUDACC_VER_MINOR__ >= 8)
-    (void)lane_id;
-    return transpose_m8n8_b16_movmatrix(a);
-#else
+    // #if (__CUDACC_VER_MAJOR__ >= 11) && (__CUDACC_VER_MINOR__ >= 8)
+    // (void)lane_id;
+    // return transpose_m8n8_b16_movmatrix(a);
+    // #else
     return transpose_m8n8_b16_warp_shuffle(a, lane_id);
-#endif
+    // #endif
 }
 namespace ops {
...
@@ -16,11 +16,11 @@
 #pragma once
 #include <array>
 #include <assert.h>
-#if ((__CUDACC_VER_MAJOR__ > 11) || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0))
-#include <cooperative_groups/reduce.h>
-#else
+// #if ((__CUDACC_VER_MAJOR__ > 11) || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0))
+// #include <cooperative_groups/reduce.h>
+// #else
 #include <cooperative_groups.h>
-#endif
+// #endif
 #include "src/turbomind/utils/cuda_bf16_wrapper.h"
 #include "src/turbomind/utils/cuda_type_utils.cuh"
 #include <cuda_fp16.h>
@@ -244,15 +244,15 @@ __inline__ __device__ void cgBlockReduceSumElements(float* element_list, float*
     const int tid    = cta.thread_rank();
     const int blockz = blockDim.x;
     for (int i = 0; i < NUM; i++) {
-#if ((__CUDACC_VER_MAJOR__ > 11) || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0))
-        cgBlockReduceSumElements_shm[i * blockz + tid] = cg::reduce(tile, element_list[i], cg::plus<float>());
-#else
+        // #if ((__CUDACC_VER_MAJOR__ > 11) || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0))
+        // cgBlockReduceSumElements_shm[i * blockz + tid] = cg::reduce(tile, element_list[i], cg::plus<float>());
+        // #else
         // TODO Add implementation here
         if (threadIdx.x == 0 && blockIdx.x == 0) {
             printf("[ERROR] Not support cgBlockReduceSumElements when CUDA < 11 \n");
             assert(false);
         }
-#endif
+        // #endif
     }
     cg::sync(cta);
     if (tid == 0) {
...
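Note that with the `cg::reduce` line commented out, `cgBlockReduceSumElements` is left with only the `// TODO Add implementation here` error branch on this toolchain. A hedged sketch of the kind of shared-memory block reduction that could eventually fill that TODO (illustrative only; the kernel and buffer names below are invented for this note, this is not what the commit adds):

// Illustrative block-wide sum via shared memory; compiles with nvcc (or after hipify).
#include <cstdio>
#include <cuda_runtime.h>

__global__ void block_reduce_sum(const float* in, float* out, int n)
{
    extern __shared__ float shm[];
    const int tid = threadIdx.x;
    shm[tid] = (tid < n) ? in[tid] : 0.f;
    __syncthreads();
    // Tree reduction over the thread block (blockDim.x assumed to be a power of two).
    for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
        if (tid < stride) {
            shm[tid] += shm[tid + stride];
        }
        __syncthreads();
    }
    if (tid == 0) {
        out[blockIdx.x] = shm[0];
    }
}

int main()
{
    const int n = 256;
    float h_in[n], h_out = 0.f;
    for (int i = 0; i < n; ++i) h_in[i] = 1.f;
    float *d_in, *d_out;
    cudaMalloc(&d_in, n * sizeof(float));
    cudaMalloc(&d_out, sizeof(float));
    cudaMemcpy(d_in, h_in, n * sizeof(float), cudaMemcpyHostToDevice);
    block_reduce_sum<<<1, n, n * sizeof(float)>>>(d_in, d_out, n);
    cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
    std::printf("sum = %f (expected %d)\n", h_out, n);
    cudaFree(d_in);
    cudaFree(d_out);
    return 0;
}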
@@ -77,11 +77,11 @@ if (BUILD_MULTI_GPU)
     target_link_libraries(nccl_utils PUBLIC ${NCCL_LIBRARIES} logger)
 endif()
-add_library(cublasINT8MMWrapper STATIC cublasINT8MMWrapper.cc)
+# add_library(cublasINT8MMWrapper STATIC cublasINT8MMWrapper.cc)
 #set_property(TARGET cublasINT8MMWrapper PROPERTY POSITION_INDEPENDENT_CODE ON)
 #set_property(TARGET cublasINT8MMWrapper PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
 #target_link_libraries(cublasINT8MMWrapper PUBLIC cublasLt cudart curand cublasAlgoMap cublasMMWrapper cuda_utils logger)
-target_link_libraries(cublasINT8MMWrapper PUBLIC cudart curand cublasAlgoMap cublasMMWrapper cuda_utils logger)
+# target_link_libraries(cublasINT8MMWrapper PUBLIC cudart curand cublasAlgoMap cublasMMWrapper cuda_utils logger)
 if(ENABLE_FP8)
     add_library(cublasFP8MMWrapper STATIC cublasFP8MMWrapper.cu)
@@ -108,7 +108,7 @@ if (SPARSITY_SUPPORT)
     target_link_libraries(gemm PUBLIC cusparse -lcusparseLt)
 endif()
-add_library(cuda_fp8_utils STATIC cuda_fp8_utils.cu)
+# add_library(cuda_fp8_utils STATIC cuda_fp8_utils.cu)
 #set_property(TARGET cuda_fp8_utils PROPERTY POSITION_INDEPENDENT_CODE ON)
 #set_property(TARGET cuda_fp8_utils PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
...
@@ -44,9 +44,9 @@
 #include "src/turbomind/utils/logger.h"
-#if defined(CUDART_VERSION) && CUDART_VERSION < 11020
+// #if defined(CUDART_VERSION) && CUDART_VERSION < 11020
 #define CUDA_MEMORY_POOL_DISABLED
-#endif
+// #endif
 namespace turbomind {
...
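With the guard removed, `CUDA_MEMORY_POOL_DISABLED` is now defined unconditionally, which keeps the allocator off the CUDA 11.2+ stream-ordered memory pool. A sketch of how such a guard is typically consumed (the helper below is hypothetical; the actual allocator code is outside this hunk):

// Hypothetical allocation helper showing the usual effect of CUDA_MEMORY_POOL_DISABLED.
#include <cstddef>
#include <cuda_runtime.h>

void* device_malloc(size_t size, cudaStream_t stream)
{
    void* ptr = nullptr;
#ifdef CUDA_MEMORY_POOL_DISABLED
    cudaMalloc(&ptr, size);               // synchronous allocation, works on DTK / older CUDA
#else
    cudaMallocAsync(&ptr, size, stream);  // stream-ordered pool allocation, CUDA >= 11.2
#endif
    return ptr;
}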
@@ -237,10 +237,10 @@ void cublasFP8MMWrapper::Gemm(__nv_bfloat16* res,
     cublasLtMatmulAlgoConfigSetAttribute(
         &algo, CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, &(info.reductionScheme), sizeof(info.reductionScheme));
-#if (CUDART_VERSION >= 11000)
-    cublasLtMatmulAlgoConfigSetAttribute(
-        &algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &(info.stages), sizeof(info.stages));
-#endif
+    // #if (CUDART_VERSION >= 11000)
+    // cublasLtMatmulAlgoConfigSetAttribute(
+    //     &algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &(info.stages), sizeof(info.stages));
+    // #endif
 #if (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH >= 3)
     cublasLtMatmulAlgoConfigSetAttribute(
@@ -462,10 +462,10 @@ void cublasFP8MMWrapper::Gemm(__nv_fp8_e4m3* res,
     cublasLtMatmulAlgoConfigSetAttribute(
         &algo, CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, &(info.reductionScheme), sizeof(info.reductionScheme));
-#if (CUDART_VERSION >= 11000)
-    cublasLtMatmulAlgoConfigSetAttribute(
-        &algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &(info.stages), sizeof(info.stages));
-#endif
+    // #if (CUDART_VERSION >= 11000)
+    // cublasLtMatmulAlgoConfigSetAttribute(
+    //     &algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &(info.stages), sizeof(info.stages));
+    // #endif
 #if (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH >= 3)
     cublasLtMatmulAlgoConfigSetAttribute(
...
@@ -94,11 +94,11 @@ void cublasINT8MMWrapper::Gemm(int* res,
 {
     mu_->lock();
     cublasOperation_t opTranspose = CUBLAS_OP_T;
-#if (CUDART_VERSION >= 11000)
-    cublasComputeType_t computeType = CUBLAS_COMPUTE_32I;
-#else
+    // #if (CUDART_VERSION >= 11000)
+    // cublasComputeType_t computeType = CUBLAS_COMPUTE_32I;
+    // #else
     cudaDataType_t computeType = CUDA_R_32I;
-#endif
+    // #endif
     cublasLtMatmulDesc_t matmulDesc;
     cublasLtMatrixLayout_t AtransformDesc = NULL;
     cublasLtMatrixLayout_t BtransformDesc = NULL;
@@ -106,16 +106,16 @@ void cublasINT8MMWrapper::Gemm(int* res,
     cublasLtOrder_t order_COL32 = CUBLASLT_ORDER_COL32;
     cublasLtOrder_t order_matrixB;
-#if (CUDART_VERSION >= 11000)
-    if (use_ORDER_COL32_2R_4R4_) {
-        order_matrixB = CUBLASLT_ORDER_COL32_2R_4R4;
-    }
-    else {
-        order_matrixB = CUBLASLT_ORDER_COL4_4R2_8C;
-    }
-#else
+    // #if (CUDART_VERSION >= 11000)
+    // if (use_ORDER_COL32_2R_4R4_) {
+    //     order_matrixB = CUBLASLT_ORDER_COL32_2R_4R4;
+    // }
+    // else {
+    //     order_matrixB = CUBLASLT_ORDER_COL4_4R2_8C;
+    // }
+    // #else
     order_matrixB = CUBLASLT_ORDER_COL4_4R2_8C;
-#endif
+    // #endif
     int ldaTransform = 32 * m;
     int ldbTransform;
@@ -128,11 +128,11 @@ void cublasINT8MMWrapper::Gemm(int* res,
     int ldcTransform = 32 * m;
     // create matmulDesc
-#if (CUDART_VERSION >= 11000)
-    cublasLtMatmulDescCreate(&matmulDesc, computeType, CUDA_R_32I);
-#else
+    // #if (CUDART_VERSION >= 11000)
+    // cublasLtMatmulDescCreate(&matmulDesc, computeType, CUDA_R_32I);
+    // #else
     cublasLtMatmulDescCreate(&matmulDesc, computeType);
-#endif
+    // #endif
     cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_TRANSB, &opTranspose, sizeof(cublasOperation_t));
     cublasLtMatrixLayoutCreate(&AtransformDesc, CUDA_R_8I, m, k, ldaTransform);
     cublasLtMatrixLayoutSetAttribute(AtransformDesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order_COL32, sizeof(order_COL32));
@@ -187,10 +187,10 @@ void cublasINT8MMWrapper::Gemm(int* res,
             &algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &(tmp_info.swizzle), sizeof(tmp_info.swizzle));
         cublasLtMatmulAlgoConfigSetAttribute(
             &algo, CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, &(tmp_info.reductionScheme), sizeof(int));
-#if (CUDART_VERSION >= 11000)
-        cublasLtMatmulAlgoConfigSetAttribute(
-            &algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &(tmp_info.stages), sizeof(tmp_info.stages));
-#endif
+        // #if (CUDART_VERSION >= 11000)
+        // cublasLtMatmulAlgoConfigSetAttribute(
+        //     &algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &(tmp_info.stages), sizeof(tmp_info.stages));
+        // #endif
     }
     else {
         findAlgo = 1;
@@ -215,16 +215,16 @@ void cublasINT8MMWrapper::Gemm(int* res,
         cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &(swizzle), sizeof(swizzle));
         cublasLtMatmulAlgoConfigSetAttribute(
             &algo, CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, &(reductionScheme), sizeof(int));
-#if (CUDART_VERSION >= 11000)
-        int stages;
-        if (use_ORDER_COL32_2R_4R4_) {
-            stages = 15;
-        }
-        else {
-            stages = 13;
-        }
-        cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &(stages), sizeof(stages));
-#endif
+        // #if (CUDART_VERSION >= 11000)
+        // int stages;
+        // if (use_ORDER_COL32_2R_4R4_) {
+        //     stages = 15;
+        // }
+        // else {
+        //     stages = 13;
+        // }
+        // cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &(stages), sizeof(stages));
+        // #endif
     }
     cublasLtMatmul(cublaslt_handle_,
@@ -273,11 +273,11 @@ void cublasINT8MMWrapper::Gemm(int8_t* res,
     // int8 gemm does not support CUBLAS_POINTER_MODE_DEVICE
     // cublasLtPointerMode_t pointerMode = CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_ZERO;
     cudaDataType_t scaleType = CUDA_R_32F;
-#if (CUDART_VERSION >= 11000)
-    cublasComputeType_t computeType = CUBLAS_COMPUTE_32I;
-#else
+    // #if (CUDART_VERSION >= 11000)
+    // cublasComputeType_t computeType = CUBLAS_COMPUTE_32I;
+    // #else
    cudaDataType_t computeType = CUDA_R_32I;
-#endif
+    // #endif
     cublasLtMatmulDesc_t matmulDesc;
     cublasLtMatrixLayout_t AtransformDesc = NULL;
     cublasLtMatrixLayout_t BtransformDesc = NULL;
@@ -285,16 +285,16 @@ void cublasINT8MMWrapper::Gemm(int8_t* res,
     cublasLtOrder_t order_COL32 = CUBLASLT_ORDER_COL32;
     cublasLtOrder_t order_matrixB;
-#if (CUDART_VERSION >= 11000)
-    if (use_ORDER_COL32_2R_4R4_) {
-        order_matrixB = CUBLASLT_ORDER_COL32_2R_4R4;
-    }
-    else {
-        order_matrixB = CUBLASLT_ORDER_COL4_4R2_8C;
-    }
-#else
+    // #if (CUDART_VERSION >= 11000)
+    // if (use_ORDER_COL32_2R_4R4_) {
+    //     order_matrixB = CUBLASLT_ORDER_COL32_2R_4R4;
+    // }
+    // else {
+    //     order_matrixB = CUBLASLT_ORDER_COL4_4R2_8C;
+    // }
+    // #else
     order_matrixB = CUBLASLT_ORDER_COL4_4R2_8C;
-#endif
+    // #endif
     int ldaTransform = 32 * m;
@@ -309,11 +309,11 @@ void cublasINT8MMWrapper::Gemm(int8_t* res,
     int ldcTransform = 32 * m;
     // create matmulDesc
-#if (CUDART_VERSION >= 11000)
-    cublasLtMatmulDescCreate(&matmulDesc, computeType, scaleType);
-#else
+    // #if (CUDART_VERSION >= 11000)
+    // cublasLtMatmulDescCreate(&matmulDesc, computeType, scaleType);
+    // #else
     cublasLtMatmulDescCreate(&matmulDesc, computeType);
-#endif
+    // #endif
     cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_TRANSB, &opTranspose, sizeof(cublasOperation_t));
     cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_SCALE_TYPE, &scaleType, sizeof(scaleType));
     // cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_POINTER_MODE, &pointerMode,
@@ -367,10 +367,10 @@ void cublasINT8MMWrapper::Gemm(int8_t* res,
             &algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &(tmp_info.swizzle), sizeof(tmp_info.swizzle));
         cublasLtMatmulAlgoConfigSetAttribute(
             &algo, CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, &(tmp_info.reductionScheme), sizeof(int));
-#if (CUDART_VERSION >= 11000)
-        cublasLtMatmulAlgoConfigSetAttribute(
-            &algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &(tmp_info.stages), sizeof(tmp_info.stages));
-#endif
+        // #if (CUDART_VERSION >= 11000)
+        // cublasLtMatmulAlgoConfigSetAttribute(
+        //     &algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &(tmp_info.stages), sizeof(tmp_info.stages));
+        // #endif
     }
     else {
         findAlgo = 1;
@@ -395,16 +395,16 @@ void cublasINT8MMWrapper::Gemm(int8_t* res,
         cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &(swizzle), sizeof(swizzle));
         cublasLtMatmulAlgoConfigSetAttribute(
             &algo, CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, &(reductionScheme), sizeof(int));
-#if (CUDART_VERSION >= 11000)
-        int stages;
-        if (use_ORDER_COL32_2R_4R4_) {
-            stages = 15;
-        }
-        else {
-            stages = 13;
-        }
-        cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &(stages), sizeof(stages));
-#endif
+        // #if (CUDART_VERSION >= 11000)
+        // int stages;
+        // if (use_ORDER_COL32_2R_4R4_) {
+        //     stages = 15;
+        // }
+        // else {
+        //     stages = 13;
+        // }
+        // cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &(stages), sizeof(stages));
+        // #endif
     }
     float beta = 0.0f;
...
@@ -192,118 +192,119 @@ void cublasMMWrapper::Gemm(cublasOperation_t transa,
         }
     }
-    // if (using_cublasLt) {
-    if (0) {
-        cublasLtMatmulDesc_t operationDesc = NULL;
-        cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL;
-        cudaDataType_t scaleType;
-#if (CUDART_VERSION >= 11000)
-        cublasComputeType_t computeType;
-#else
-        cudaDataType_t computeType;
-#endif
-        if (is_fp16_computeType) {
-#if (CUDART_VERSION >= 11000)
-            computeType = CUBLAS_COMPUTE_16F;
-#else
-            computeType = CUDA_R_16F;
-#endif
-            scaleType = CUDA_R_16F;
-        }
-        else {
-#if (CUDART_VERSION >= 11000)
-            computeType = CUBLAS_COMPUTE_32F;
-#else
-            computeType = CUDA_R_32F;
-#endif
-            scaleType = CUDA_R_32F;
-        }
-        // --------------------------------------
-        // Create descriptors for the original matrices
-        cublasLtMatrixLayoutCreate(&Adesc, Atype_, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda);
-        cublasLtMatrixLayoutCreate(&Bdesc, Btype_, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb);
-        cublasLtMatrixLayoutCreate(&Cdesc, Ctype_, m, n, ldc);
-#if (CUDART_VERSION >= 11000)
-        cublasLtMatmulDescCreate(&operationDesc, computeType, scaleType);
-#else
-        cublasLtMatmulDescCreate(&operationDesc, computeType);
-#endif
-        cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(cublasOperation_t));
-        cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(cublasOperation_t));
-        cublasLtMatmulAlgo_t algo;
-        void* workSpace = cublas_workspace_;
-        int workspaceSize = cublas_workspace_ == NULL ? 0 : CUBLAS_WORKSPACE_SIZE;
-        if (findAlgo) {
-            if (info.workspaceSize > workspaceSize) {
-                findAlgo = 0;
-            }
-            else {
-                cublasLtMatmulAlgoInit(
-                    cublaslt_handle_, computeType, scaleType, Atype_, Btype_, Ctype_, Ctype_, info.algoId, &algo);
-                cublasLtMatmulAlgoConfigSetAttribute(
-                    &algo, CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION, &(info.customOption), sizeof(info.customOption));
-                cublasLtMatmulAlgoConfigSetAttribute(
-                    &algo, CUBLASLT_ALGO_CONFIG_TILE_ID, &(info.tile), sizeof(info.tile));
-                cublasLtMatmulAlgoConfigSetAttribute(
-                    &algo, CUBLASLT_ALGO_CONFIG_SPLITK_NUM, &(info.splitK_val), sizeof(info.splitK_val));
-                cublasLtMatmulAlgoConfigSetAttribute(
-                    &algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &(info.swizzle), sizeof(info.swizzle));
-                cublasLtMatmulAlgoConfigSetAttribute(&algo,
-                                                     CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME,
-                                                     &(info.reductionScheme),
-                                                     sizeof(info.reductionScheme));
-#if (CUDART_VERSION >= 11000)
-                cublasLtMatmulAlgoConfigSetAttribute(
-                    &algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &(info.stages), sizeof(info.stages));
-#endif
-#if (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH >= 3)
-                cublasLtMatmulAlgoConfigSetAttribute(
-                    &algo, CUBLASLT_ALGO_CONFIG_INNER_SHAPE_ID, &(info.inner_shapeId), sizeof(info.inner_shapeId));
-                cublasLtMatmulAlgoConfigSetAttribute(&algo,
-                                                     CUBLASLT_ALGO_CONFIG_CLUSTER_SHAPE_ID,
-                                                     &(info.cluster_shapeId),
-                                                     sizeof(info.cluster_shapeId));
-#elif (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH < 3)
-                cublasLtMatmulAlgoConfigSetAttribute(
-                    &algo, CUBLASLT_ALGO_CONFIG_MMA_SHAPE_ID, &(info.mma_shapeId), sizeof(info.mma_shapeId));
-                cublasLtMatmulAlgoConfigSetAttribute(
-                    &algo, CUBLASLT_ALGO_CONFIG_CGA_SHAPE_ID, &(info.cga_shapeId), sizeof(info.cga_shapeId));
-                cublasLtMatmulAlgoConfigSetAttribute(
-                    &algo, CUBLASLT_ALGO_CONFIG_SCHEDULING_MODE, &(info.sche_mode), sizeof(info.sche_mode));
-#endif
-            }
-        }
-        // cublasLtMatmul(cublaslt_handle_,
-        //                operationDesc,
-        //                alpha,
-        //                A,
-        //                Adesc,
-        //                B,
-        //                Bdesc,
-        //                beta,
-        //                C,
-        //                Cdesc,
-        //                C,
-        //                Cdesc,
-        //                (findAlgo == 1 ? (&algo) : NULL),
-        //                workSpace,
-        //                workspaceSize,
-        //                stream_);
-        cublasLtMatmulDescDestroy(operationDesc);
-        cublasLtMatrixLayoutDestroy(Adesc);
-        cublasLtMatrixLayoutDestroy(Bdesc);
-        cublasLtMatrixLayoutDestroy(Cdesc);
-        sync_check_cuda_error();
-    }
-    else {
+    // if (using_cublasLt) {
+    // if (0) {
+    //     cublasLtMatmulDesc_t operationDesc = NULL;
+    //     cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL;
+    //     cudaDataType_t scaleType;
+    // #if (CUDART_VERSION >= 11000)
+    //     cublasComputeType_t computeType;
+    // #else
+    //     cudaDataType_t computeType;
+    // #endif
+    //     if (is_fp16_computeType) {
+    // #if (CUDART_VERSION >= 11000)
+    //         computeType = CUBLAS_COMPUTE_16F;
+    // #else
+    //         computeType = CUDA_R_16F;
+    // #endif
+    //         scaleType = CUDA_R_16F;
+    //     }
+    //     else {
+    // #if (CUDART_VERSION >= 11000)
+    //         computeType = CUBLAS_COMPUTE_32F;
+    // #else
+    //         computeType = CUDA_R_32F;
+    // #endif
+    //         scaleType = CUDA_R_32F;
+    //     }
+    //     // --------------------------------------
+    //     // Create descriptors for the original matrices
+    //     cublasLtMatrixLayoutCreate(&Adesc, Atype_, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda);
+    //     cublasLtMatrixLayoutCreate(&Bdesc, Btype_, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb);
+    //     cublasLtMatrixLayoutCreate(&Cdesc, Ctype_, m, n, ldc);
+    // #if (CUDART_VERSION >= 11000)
+    //     cublasLtMatmulDescCreate(&operationDesc, computeType, scaleType);
+    // #else
+    //     cublasLtMatmulDescCreate(&operationDesc, computeType);
+    // #endif
+    //     cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(cublasOperation_t));
+    //     cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(cublasOperation_t));
+    //     cublasLtMatmulAlgo_t algo;
+    //     void* workSpace = cublas_workspace_;
+    //     int workspaceSize = cublas_workspace_ == NULL ? 0 : CUBLAS_WORKSPACE_SIZE;
+    //     if (findAlgo) {
+    //         if (info.workspaceSize > workspaceSize) {
+    //             findAlgo = 0;
+    //         }
+    //         else {
+    //             cublasLtMatmulAlgoInit(
+    //                 cublaslt_handle_, computeType, scaleType, Atype_, Btype_, Ctype_, Ctype_, info.algoId, &algo);
+    //             cublasLtMatmulAlgoConfigSetAttribute(
+    //                 &algo, CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION, &(info.customOption), sizeof(info.customOption));
+    //             cublasLtMatmulAlgoConfigSetAttribute(
+    //                 &algo, CUBLASLT_ALGO_CONFIG_TILE_ID, &(info.tile), sizeof(info.tile));
+    //             cublasLtMatmulAlgoConfigSetAttribute(
+    //                 &algo, CUBLASLT_ALGO_CONFIG_SPLITK_NUM, &(info.splitK_val), sizeof(info.splitK_val));
+    //             cublasLtMatmulAlgoConfigSetAttribute(
+    //                 &algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &(info.swizzle), sizeof(info.swizzle));
+    //             cublasLtMatmulAlgoConfigSetAttribute(&algo,
+    //                                                  CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME,
+    //                                                  &(info.reductionScheme),
+    //                                                  sizeof(info.reductionScheme));
+    //             // #if (CUDART_VERSION >= 11000)
+    //             // cublasLtMatmulAlgoConfigSetAttribute(
+    //             //     &algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &(info.stages), sizeof(info.stages));
+    //             // #endif
+    // #if (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH >= 3)
+    //             cublasLtMatmulAlgoConfigSetAttribute(
+    //                 &algo, CUBLASLT_ALGO_CONFIG_INNER_SHAPE_ID, &(info.inner_shapeId), sizeof(info.inner_shapeId));
+    //             cublasLtMatmulAlgoConfigSetAttribute(&algo,
+    //                                                  CUBLASLT_ALGO_CONFIG_CLUSTER_SHAPE_ID,
+    //                                                  &(info.cluster_shapeId),
+    //                                                  sizeof(info.cluster_shapeId));
+    // #elif (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH < 3)
+    //             cublasLtMatmulAlgoConfigSetAttribute(
+    //                 &algo, CUBLASLT_ALGO_CONFIG_MMA_SHAPE_ID, &(info.mma_shapeId), sizeof(info.mma_shapeId));
+    //             cublasLtMatmulAlgoConfigSetAttribute(
+    //                 &algo, CUBLASLT_ALGO_CONFIG_CGA_SHAPE_ID, &(info.cga_shapeId), sizeof(info.cga_shapeId));
+    //             cublasLtMatmulAlgoConfigSetAttribute(
+    //                 &algo, CUBLASLT_ALGO_CONFIG_SCHEDULING_MODE, &(info.sche_mode), sizeof(info.sche_mode));
+    // #endif
+    //         }
+    //     }
+    //     // cublasLtMatmul(cublaslt_handle_,
+    //     //                operationDesc,
+    //     //                alpha,
+    //     //                A,
+    //     //                Adesc,
+    //     //                B,
+    //     //                Bdesc,
+    //     //                beta,
+    //     //                C,
+    //     //                Cdesc,
+    //     //                C,
+    //     //                Cdesc,
+    //     //                (findAlgo == 1 ? (&algo) : NULL),
+    //     //                workSpace,
+    //     //                workspaceSize,
+    //     //                stream_);
+    //     cublasLtMatmulDescDestroy(operationDesc);
+    //     cublasLtMatrixLayoutDestroy(Adesc);
+    //     cublasLtMatrixLayoutDestroy(Bdesc);
+    //     cublasLtMatrixLayoutDestroy(Cdesc);
+    //     sync_check_cuda_error();
+    // }
+    // else {
         int cublasAlgo = info.algoId;
         check_cuda_error(cublasGemmEx(cublas_handle_,
                                       transa,
@@ -325,7 +326,7 @@ void cublasMMWrapper::Gemm(cublasOperation_t transa,
                                       computeType_,
                                       static_cast<cublasGemmAlgo_t>(cublasAlgo)));
         sync_check_cuda_error();
-    }
+    // }
     mu_->unlock();
 }
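After this hunk the cublasLt branch of `cublasMMWrapper::Gemm` is fully commented out, so every call falls through to the `cublasGemmEx` path shown above (mapped to hipblasGemmEx by the DTK toolchain). A minimal standalone sketch of that entry point, with made-up FP32 sizes and a default algorithm (an illustration, not code from this repository):

// Minimal cublasGemmEx call mirroring the fallback path kept by this commit.
// Matrix sizes, data type and algorithm choice are illustrative assumptions.
#include <cublas_v2.h>
#include <cuda_runtime.h>
#include <cstdio>
#include <vector>

int main()
{
    const int m = 4, n = 4, k = 4;
    std::vector<float> hA(m * k, 1.f), hB(k * n, 1.f), hC(m * n, 0.f);
    float *dA, *dB, *dC;
    cudaMalloc(&dA, hA.size() * sizeof(float));
    cudaMalloc(&dB, hB.size() * sizeof(float));
    cudaMalloc(&dC, hC.size() * sizeof(float));
    cudaMemcpy(dA, hA.data(), hA.size() * sizeof(float), cudaMemcpyHostToDevice);
    cudaMemcpy(dB, hB.data(), hB.size() * sizeof(float), cudaMemcpyHostToDevice);

    cublasHandle_t handle;
    cublasCreate(&handle);
    const float alpha = 1.f, beta = 0.f;
    // Same entry point the wrapper uses once the cublasLt branch is disabled (CUDA 11-style compute type).
    cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k,
                 &alpha, dA, CUDA_R_32F, m, dB, CUDA_R_32F, k,
                 &beta, dC, CUDA_R_32F, m,
                 CUBLAS_COMPUTE_32F, CUBLAS_GEMM_DEFAULT);

    cudaMemcpy(hC.data(), dC, hC.size() * sizeof(float), cudaMemcpyDeviceToHost);
    std::printf("C[0] = %f (expected %d)\n", hC[0], k);  // each entry is a sum of k ones

    cublasDestroy(handle);
    cudaFree(dA); cudaFree(dB); cudaFree(dC);
    return 0;
}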
@@ -382,81 +383,81 @@ CublasDataType cublasMMWrapper::getCublasDataType(cudaDataType_t data_type)
         return FLOAT_DATATYPE;
     }
-#if (CUDART_VERSION >= 11000)
-// input, weight, output are row-major
-// only works for cublas 11.x
-void cublasMMWrapper::Gemm(cublasOperation_t transa,
-                           cublasOperation_t transb,
-                           const int m,
-                           const int n,
-                           const int k,
-                           const void* A,
-                           const int lda,
-                           const void* B,
-                           const int ldb,
-                           const void* bias,
-                           void* C,
-                           const int ldc)
-{
-    TM_LOG_DEBUG(__PRETTY_FUNCTION__);
-    cudaDataType_t Atype, Btype, Ctype;
-    cublasComputeType_t computeType;
-    cudaDataType_t scaleType;
-    float alpha_float = 1.0f;
-    float beta_float = 0.0f;
-    half alpha_half = half(1.0f);
-    half beta_half = half(0.0f);
-    void * alpha, *beta;
-    // int is_fp16_computeType = computeType_ == CUDA_R_16F ? 1 : 0;
-    if (Atype_ == CUDA_R_32F) {
-        computeType = CUBLAS_COMPUTE_32F_FAST_TF32;
-        Atype = CUDA_R_32F;
-        Btype = CUDA_R_32F;
-        Ctype = CUDA_R_32F;
-        scaleType = CUDA_R_32F;
-        alpha = &alpha_float;
-        beta = &beta_float;
-    }
-    else if (Atype_ == CUDA_R_16BF) {
-        computeType = CUBLAS_COMPUTE_32F_FAST_TF32;
-        Atype = CUDA_R_16BF;
-        Btype = CUDA_R_16BF;
-        Ctype = CUDA_R_16BF;
-        scaleType = CUDA_R_32F;
-        alpha = &alpha_float;
-        beta = &beta_float;
-    }
-    else {
-        computeType = CUBLAS_COMPUTE_16F;
-        Atype = CUDA_R_16F;
-        Btype = CUDA_R_16F;
-        Ctype = CUDA_R_16F;
-        scaleType = CUDA_R_16F;
-        alpha = &alpha_half;
-        beta = &beta_half;
-    }
-    cublasLtMatmulDesc_t operationDesc = NULL;
-    cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL;
-    cublasLtEpilogue_t epi = CUBLASLT_EPILOGUE_BIAS;
-    cublasLtMatrixLayoutCreate(&Adesc, Atype, (transa == CUBLAS_OP_N) ? m : k, (transa == CUBLAS_OP_N) ? k : m, lda);
-    cublasLtMatrixLayoutCreate(&Bdesc, Btype, (transb == CUBLAS_OP_N) ? k : n, (transb == CUBLAS_OP_N) ? n : k, ldb);
-    cublasLtMatrixLayoutCreate(&Cdesc, Ctype, m, n, ldc);
-    cublasLtMatmulDescCreate(&operationDesc, computeType, scaleType);
-    cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(cublasOperation_t));
-    cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(cublasOperation_t));
-    cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epi, sizeof(cublasLtEpilogue_t));
-    cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(const void*));
-    // check_cuda_error(cublasLtMatmul(
-    //     cublaslt_handle_, operationDesc, alpha, A, Adesc, B, Bdesc, beta, C, Cdesc, C, Cdesc, NULL, NULL, 0, stream_));
-    cublasLtMatrixLayoutDestroy(Adesc);
-    cublasLtMatrixLayoutDestroy(Bdesc);
-    cublasLtMatrixLayoutDestroy(Cdesc);
-    cublasLtMatmulDescDestroy(operationDesc);
-}
-#endif
+// #if (CUDART_VERSION >= 11000)
+// // input, weight, output are row-major
+// // only works for cublas 11.x
+// void cublasMMWrapper::Gemm(cublasOperation_t transa,
+//                            cublasOperation_t transb,
+//                            const int m,
+//                            const int n,
+//                            const int k,
+//                            const void* A,
+//                            const int lda,
+//                            const void* B,
+//                            const int ldb,
+//                            const void* bias,
+//                            void* C,
+//                            const int ldc)
+// {
+//     TM_LOG_DEBUG(__PRETTY_FUNCTION__);
+//     cudaDataType_t Atype, Btype, Ctype;
+//     cublasComputeType_t computeType;
+//     cudaDataType_t scaleType;
+//     float alpha_float = 1.0f;
+//     float beta_float = 0.0f;
+//     half alpha_half = half(1.0f);
+//     half beta_half = half(0.0f);
+//     void * alpha, *beta;
+//     // int is_fp16_computeType = computeType_ == CUDA_R_16F ? 1 : 0;
+//     if (Atype_ == CUDA_R_32F) {
+//         computeType = CUBLAS_COMPUTE_32F_FAST_TF32;
+//         Atype = CUDA_R_32F;
+//         Btype = CUDA_R_32F;
+//         Ctype = CUDA_R_32F;
+//         scaleType = CUDA_R_32F;
+//         alpha = &alpha_float;
+//         beta = &beta_float;
+//     }
+//     else if (Atype_ == CUDA_R_16BF) {
+//         computeType = CUBLAS_COMPUTE_32F_FAST_TF32;
+//         Atype = CUDA_R_16BF;
+//         Btype = CUDA_R_16BF;
+//         Ctype = CUDA_R_16BF;
+//         scaleType = CUDA_R_32F;
+//         alpha = &alpha_float;
+//         beta = &beta_float;
+//     }
+//     else {
+//         computeType = CUBLAS_COMPUTE_16F;
+//         Atype = CUDA_R_16F;
+//         Btype = CUDA_R_16F;
+//         Ctype = CUDA_R_16F;
+//         scaleType = CUDA_R_16F;
+//         alpha = &alpha_half;
+//         beta = &beta_half;
+//     }
+//     cublasLtMatmulDesc_t operationDesc = NULL;
+//     cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL;
+//     cublasLtEpilogue_t epi = CUBLASLT_EPILOGUE_BIAS;
+//     cublasLtMatrixLayoutCreate(&Adesc, Atype, (transa == CUBLAS_OP_N) ? m : k, (transa == CUBLAS_OP_N) ? k : m, lda);
+//     cublasLtMatrixLayoutCreate(&Bdesc, Btype, (transb == CUBLAS_OP_N) ? k : n, (transb == CUBLAS_OP_N) ? n : k, ldb);
+//     cublasLtMatrixLayoutCreate(&Cdesc, Ctype, m, n, ldc);
+//     cublasLtMatmulDescCreate(&operationDesc, computeType, scaleType);
+//     cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(cublasOperation_t));
+//     cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(cublasOperation_t));
+//     cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epi, sizeof(cublasLtEpilogue_t));
+//     cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(const void*));
+//     // check_cuda_error(cublasLtMatmul(
+//     //     cublaslt_handle_, operationDesc, alpha, A, Adesc, B, Bdesc, beta, C, Cdesc, C, Cdesc, NULL, NULL, 0, stream_));
+//     cublasLtMatrixLayoutDestroy(Adesc);
+//     cublasLtMatrixLayoutDestroy(Bdesc);
+//     cublasLtMatrixLayoutDestroy(Cdesc);
+//     cublasLtMatmulDescDestroy(operationDesc);
+// }
+// #endif
 void cublasMMWrapper::setStream(cudaStream_t stream)
 {
     stream_ = stream;
...
@@ -207,20 +207,20 @@ public:
     CublasDataType getCublasDataType(cudaDataType_t data_type);
-#if (CUDART_VERSION >= 11000)
-    void Gemm(cublasOperation_t transa,
-              cublasOperation_t transb,
-              const int m,
-              const int n,
-              const int k,
-              const void* A,
-              const int lda,
-              const void* B,
-              const int ldb,
-              const void* bias,
-              void* C,
-              const int ldc);
-#endif
+    // #if (CUDART_VERSION >= 11000)
+    // void Gemm(cublasOperation_t transa,
+    //           cublasOperation_t transb,
+    //           const int m,
+    //           const int n,
+    //           const int k,
+    //           const void* A,
+    //           const int lda,
+    //           const void* B,
+    //           const int ldb,
+    //           const void* bias,
+    //           void* C,
+    //           const int ldc);
+    // #endif
     void stridedBatchedGemm(cublasOperation_t transa,
                             cublasOperation_t transb,
...
@@ -152,17 +152,17 @@ void initCustomAllReduceComm(std::vector<std::shared_ptr<AbstractCustomComm>>* c
         return;
     }
-#if defined(CUDART_VERSION) && CUDART_VERSION >= 11020
-    for (size_t i = 0; i < rank_size; i++) {
-        custom_all_reduce_comms->push_back(std::make_shared<CustomAllReduceComm<T>>(rank_size, i));
-    }
-    custom_all_reduce_comms->at(0)->allocateAndExchangePeerAccessPointer(custom_all_reduce_comms);
-#else
+    // #if defined(CUDART_VERSION) && CUDART_VERSION >= 11020
+    // for (size_t i = 0; i < rank_size; i++) {
+    //     custom_all_reduce_comms->push_back(std::make_shared<CustomAllReduceComm<T>>(rank_size, i));
+    // }
+    // custom_all_reduce_comms->at(0)->allocateAndExchangePeerAccessPointer(custom_all_reduce_comms);
+    // #else
     TM_LOG_WARNING("Custom All Reduce is not supported before CUDA 11.2. Using NCCL as Comm.");
     for (size_t i = 0; i < rank_size; i++) {
         custom_all_reduce_comms->push_back(nullptr);
     }
-#endif
+    // #endif
 }
 // Template instantiation
...
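With the `CustomAllReduceComm` construction commented out, every entry pushed here is `nullptr`, so the runtime logs the warning above and performs all-reduce through NCCL (RCCL on DCU). For reference, a hedged sketch of the collective call that takes over (single-process setup; device count and buffer size are assumptions, not turbomind code):

// Illustrative ncclAllReduce call standing in for the disabled custom all-reduce.
#include <nccl.h>
#include <cuda_runtime.h>
#include <cstdio>

int main()
{
    int ndev = 0;
    cudaGetDeviceCount(&ndev);
    if (ndev < 1) return 0;
    if (ndev > 8) ndev = 8;

    int devs[8];
    for (int i = 0; i < ndev; ++i) devs[i] = i;
    ncclComm_t comms[8];
    ncclCommInitAll(comms, ndev, devs);  // one communicator per visible GPU

    const int count = 1024;
    float* buf[8];
    cudaStream_t streams[8];
    for (int i = 0; i < ndev; ++i) {
        cudaSetDevice(i);
        cudaMalloc(&buf[i], count * sizeof(float));
        cudaMemset(buf[i], 0, count * sizeof(float));
        cudaStreamCreate(&streams[i]);
    }

    // Sum the buffers across ranks, the same role the custom kernel played before.
    ncclGroupStart();
    for (int i = 0; i < ndev; ++i) {
        ncclAllReduce(buf[i], buf[i], count, ncclFloat, ncclSum, comms[i], streams[i]);
    }
    ncclGroupEnd();

    for (int i = 0; i < ndev; ++i) {
        cudaSetDevice(i);
        cudaStreamSynchronize(streams[i]);
        ncclCommDestroy(comms[i]);
    }
    std::printf("all-reduce issued on %d device(s)\n", ndev);
    return 0;
}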
@@ -269,82 +269,82 @@ void Gemm::gemm(const GemmOp transa,
     }
     // if (using_cublasLt) {
-    if(0) {
-        const size_t a_rows = (a_op == getCublasOperation(GEMM_OP_N)) ? _m : k;
-        const size_t a_cols = (a_op == getCublasOperation(GEMM_OP_N)) ? k : _m;
-        const size_t b_rows = (b_op == getCublasOperation(GEMM_OP_N)) ? k : _n;
-        const size_t b_cols = (b_op == getCublasOperation(GEMM_OP_N)) ? _n : k;
-        cublasLtMatmulDesc_t matmul_desc = NULL;
-        cublasLtMatrixLayout_t a_desc = NULL, b_desc = NULL, c_desc = NULL;
-        cudaDataType_t scale_type = getCublasDataType(compute_type_);
-        auto compute_type = getCublasComputeType(compute_type_);
-        // --------------------------------------
-        // Create descriptors for the original matrices
-        cublasLtMatrixLayoutCreate(&a_desc, a_type, a_rows, a_cols, _lda);
-        cublasLtMatrixLayoutCreate(&b_desc, b_type, b_rows, b_cols, _ldb);
-        cublasLtMatrixLayoutCreate(&c_desc, c_type, _m, _n, ldc);
-#if (CUDART_VERSION >= 11000)
-        cublasLtMatmulDescCreate(&matmul_desc, compute_type, scale_type);
-#else
-        cublasLtMatmulDescCreate(&matmul_desc, compute_type);
-#endif
-        cublasLtMatmulDescSetAttribute(matmul_desc, CUBLASLT_MATMUL_DESC_TRANSA, &a_op, sizeof(cublasOperation_t));
-        cublasLtMatmulDescSetAttribute(matmul_desc, CUBLASLT_MATMUL_DESC_TRANSB, &b_op, sizeof(cublasOperation_t));
-        cublasLtMatmulAlgo_t algo;
-        void* workspace = workspace_;
-        int workspace_size = workspace_ == nullptr ? 0 : CUBLAS_WORKSPACE_SIZE;
-        if (findAlgo) {
-            if (info.workspaceSize > workspace_size) {
-                findAlgo = 0;
-            }
-            else {
-                cublasLtMatmulAlgoInit(
-                    cublaslt_handle_, compute_type, scale_type, a_type, b_type, c_type, c_type, info.algoId, &algo);
-                cublasLtMatmulAlgoConfigSetAttribute(
-                    &algo, CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION, &(info.customOption), sizeof(info.customOption));
-                cublasLtMatmulAlgoConfigSetAttribute(
-                    &algo, CUBLASLT_ALGO_CONFIG_TILE_ID, &(info.tile), sizeof(info.tile));
-                cublasLtMatmulAlgoConfigSetAttribute(
-                    &algo, CUBLASLT_ALGO_CONFIG_SPLITK_NUM, &(info.splitK_val), sizeof(info.splitK_val));
-                cublasLtMatmulAlgoConfigSetAttribute(
-                    &algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &(info.swizzle), sizeof(info.swizzle));
-                cublasLtMatmulAlgoConfigSetAttribute(
-                    &algo, CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, &(info.reductionScheme), sizeof(int));
-#if (CUDART_VERSION >= 11000)
-                cublasLtMatmulAlgoConfigSetAttribute(
-                    &algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &(info.stages), sizeof(info.stages));
-#endif
-            }
-        }
-        cublasLtMatmul(cublaslt_handle_,
-                       matmul_desc,
-                       alpha_ptr,
-                       a_data_ptr,
-                       a_desc,
-                       b_data_ptr,
-                       b_desc,
-                       beta_ptr,
-                       C,
-                       c_desc,
-                       C,
-                       c_desc,
-                       (findAlgo == 1 ? (&algo) : NULL),
-                       workspace,
-                       workspace_size,
-                       stream_);
-        cublasLtMatmulDescDestroy(matmul_desc);
-        cublasLtMatrixLayoutDestroy(a_desc);
-        cublasLtMatrixLayoutDestroy(b_desc);
-        cublasLtMatrixLayoutDestroy(c_desc);
-        sync_check_cuda_error();
-    }
-    else {
+    // if(0) {
+    //     const size_t a_rows = (a_op == getCublasOperation(GEMM_OP_N)) ? _m : k;
+    //     const size_t a_cols = (a_op == getCublasOperation(GEMM_OP_N)) ? k : _m;
+    //     const size_t b_rows = (b_op == getCublasOperation(GEMM_OP_N)) ? k : _n;
+    //     const size_t b_cols = (b_op == getCublasOperation(GEMM_OP_N)) ? _n : k;
+    //     cublasLtMatmulDesc_t matmul_desc = NULL;
+    //     cublasLtMatrixLayout_t a_desc = NULL, b_desc = NULL, c_desc = NULL;
+    //     cudaDataType_t scale_type = getCublasDataType(compute_type_);
+    //     auto compute_type = getCublasComputeType(compute_type_);
+    //     // --------------------------------------
+    //     // Create descriptors for the original matrices
+    //     cublasLtMatrixLayoutCreate(&a_desc, a_type, a_rows, a_cols, _lda);
+    //     cublasLtMatrixLayoutCreate(&b_desc, b_type, b_rows, b_cols, _ldb);
+    //     cublasLtMatrixLayoutCreate(&c_desc, c_type, _m, _n, ldc);
+    // #if (CUDART_VERSION >= 11000)
+    //     cublasLtMatmulDescCreate(&matmul_desc, compute_type, scale_type);
+    // #else
+    //     cublasLtMatmulDescCreate(&matmul_desc, compute_type);
+    // #endif
+    //     cublasLtMatmulDescSetAttribute(matmul_desc, CUBLASLT_MATMUL_DESC_TRANSA, &a_op, sizeof(cublasOperation_t));
+    //     cublasLtMatmulDescSetAttribute(matmul_desc, CUBLASLT_MATMUL_DESC_TRANSB, &b_op, sizeof(cublasOperation_t));
+    //     cublasLtMatmulAlgo_t algo;
+    //     void* workspace = workspace_;
+    //     int workspace_size = workspace_ == nullptr ? 0 : CUBLAS_WORKSPACE_SIZE;
+    //     if (findAlgo) {
+    //         if (info.workspaceSize > workspace_size) {
+    //             findAlgo = 0;
+    //         }
+    //         else {
+    //             cublasLtMatmulAlgoInit(
+    //                 cublaslt_handle_, compute_type, scale_type, a_type, b_type, c_type, c_type, info.algoId, &algo);
+    //             cublasLtMatmulAlgoConfigSetAttribute(
+    //                 &algo, CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION, &(info.customOption), sizeof(info.customOption));
+    //             cublasLtMatmulAlgoConfigSetAttribute(
+    //                 &algo, CUBLASLT_ALGO_CONFIG_TILE_ID, &(info.tile), sizeof(info.tile));
+    //             cublasLtMatmulAlgoConfigSetAttribute(
+    //                 &algo, CUBLASLT_ALGO_CONFIG_SPLITK_NUM, &(info.splitK_val), sizeof(info.splitK_val));
+    //             cublasLtMatmulAlgoConfigSetAttribute(
+    //                 &algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &(info.swizzle), sizeof(info.swizzle));
+    //             cublasLtMatmulAlgoConfigSetAttribute(
+    //                 &algo, CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, &(info.reductionScheme), sizeof(int));
+    //             #if (CUDART_VERSION >= 11000)
+    //             cublasLtMatmulAlgoConfigSetAttribute(
+    //                 &algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &(info.stages), sizeof(info.stages));
+    //             #endif
+    //         }
+    //     }
+    //     cublasLtMatmul(cublaslt_handle_,
+    //                    matmul_desc,
+    //                    alpha_ptr,
+    //                    a_data_ptr,
+    //                    a_desc,
+    //                    b_data_ptr,
+    //                    b_desc,
+    //                    beta_ptr,
+    //                    C,
+    //                    c_desc,
+    //                    C,
+    //                    c_desc,
+    //                    (findAlgo == 1 ? (&algo) : NULL),
+    //                    workspace,
+    //                    workspace_size,
+    //                    stream_);
+    //     cublasLtMatmulDescDestroy(matmul_desc);
+    //     cublasLtMatrixLayoutDestroy(a_desc);
+    //     cublasLtMatrixLayoutDestroy(b_desc);
+    //     cublasLtMatrixLayoutDestroy(c_desc);
+    //     sync_check_cuda_error();
+    // }
+    // else {
         cudaDataType_t compute_type = getCublasDataType(compute_type_);
         int cublas_algo = info.algoId;
         check_cuda_error(cublasGemmEx(cublas_handle_,
@@ -367,7 +367,7 @@ void Gemm::gemm(const GemmOp transa,
                                       compute_type,
                                       static_cast<cublasGemmAlgo_t>(cublas_algo)));
         sync_check_cuda_error();
-    }
+    // }
     mutex_->unlock();
 }
@@ -1035,19 +1035,19 @@ cudaDataType_t getCublasDataType(DataType dtype)
     }
 }
-#if (CUDART_VERSION >= 11000)
-cublasComputeType_t getCublasComputeType(DataType ctype)
-{
-    switch (ctype) {
-        case TYPE_FP16:
-            return CUBLAS_COMPUTE_16F;
-        case TYPE_FP32:
-            return CUBLAS_COMPUTE_32F;
-        default:
-            throw GemmNotSupportedException("Not supported cublas compute type.");
-    }
-}
-#else
+// #if (CUDART_VERSION >= 11000)
+// cublasComputeType_t getCublasComputeType(DataType ctype)
+// {
+//     switch (ctype) {
+//         case TYPE_FP16:
+//             return CUBLAS_COMPUTE_16F;
+//         case TYPE_FP32:
+//             return CUBLAS_COMPUTE_32F;
+//         default:
+//             throw GemmNotSupportedException("Not supported cublas compute type.");
+//     }
+// }
+// #else
 cudaDataType_t getCublasComputeType(DataType ctype)
 {
     switch (ctype) {
@@ -1059,7 +1059,7 @@ cudaDataType_t getCublasComputeType(DataType ctype)
             throw GemmNotSupportedException("Not supported cublas compute type.");
     }
 }
-#endif
+// #endif
 cublasOperation_t getCublasOperation(GemmOp op)
 {
...
@@ -622,11 +622,11 @@ std::shared_ptr<Gemm>
 createGemm(IAllocator* allocator, cudaStream_t stream, bool sparse = false, bool quantized = false);
 cudaDataType_t getCublasDataType(DataType dtype);
-#if (CUDART_VERSION >= 11000)
-cublasComputeType_t getCublasComputeType(DataType dtype);
-#else
+// #if (CUDART_VERSION >= 11000)
+// cublasComputeType_t getCublasComputeType(DataType dtype);
+// #else
 cudaDataType_t getCublasComputeType(DataType dtype);
-#endif
+// #endif
 cublasOperation_t getCublasOperation(GemmOp op);
 std::string getGemmOpString(const GemmOp& op);
...