"git@developer.sourcefind.cn:tsoc/superbenchmark.git" did not exist on "fdc33f406c29324a47fd4b4a4f49355c8b385034"
Commit 5cd13ff8 authored by YdrMaster's avatar YdrMaster
Browse files

issue/63/fix: 移除 cuda mat mul 中无意义的模板


Signed-off-by: default avatarYdrMaster <ydrml@hotmail.com>
parent 75b89b17
...@@ -38,90 +38,74 @@ infiniStatus_t Descriptor::create( ...@@ -38,90 +38,74 @@ infiniStatus_t Descriptor::create(
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
} }
// Runs the strided-batched GEMM described by this descriptor on the given
// CUDA stream using cublasGemmStridedBatchedEx:
//   C = alpha * op(A) * op(B) + beta * C   (per batch)
//
// Parameters:
//   workspace / workspace_size — unused by this cuBLAS-backed path.
//   c                          — device pointer to the output matrices.
//   beta, alpha                — scalar GEMM coefficients (host floats).
//   a, b                       — device pointers to the input matrices.
//   stream                     — opaque CUDA stream handle (cudaStream_t).
//
// Returns INFINI_STATUS_BAD_TENSOR_DTYPE for unsupported dtypes, otherwise
// INFINI_STATUS_SUCCESS.
infiniStatus_t Descriptor::calculate(
    void *workspace,
    size_t workspace_size,
    void *c,
    float beta,
    const void *a,
    const void *b,
    float alpha,
    void *stream) const {
    // Translate the descriptor dtype into cuBLAS element/compute types.
    // Both F16 and F32 accumulate in 32-bit floats; the F32 path enables the
    // TF32 fast compute mode except when building against the Sugon CUDA API,
    // which only provides plain F32 compute.
    cudaDataType a_type, b_type, c_type;
    cublasComputeType_t compute_type;
    switch (_dtype) {
    case INFINI_DTYPE_F16:
        a_type = b_type = c_type = CUDA_R_16F;
        compute_type = CUBLAS_COMPUTE_32F;
        break;
    case INFINI_DTYPE_F32:
        a_type = b_type = c_type = CUDA_R_32F;
#ifdef ENABLE_SUGON_CUDA_API
        compute_type = CUBLAS_COMPUTE_32F;
#else
        compute_type = CUBLAS_COMPUTE_32F_FAST_TF32;
#endif
        break;
    default:
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }

    // The problem may have been recorded with A and B exchanged (presumably to
    // express C^T = B^T * A^T for the column-major cuBLAS convention — the
    // exact meaning of is_transed lives in MatmulInfo, outside this file);
    // undo/apply that swap before issuing the call.
    if (_info.is_transed) {
        std::swap(a, b);
    }

    // A matrix whose rows are contiguous (row_stride == 1) is passed as-is;
    // otherwise ask cuBLAS to transpose it on the fly.
    const auto op_a = _info.a_matrix.row_stride == 1 ? CUBLAS_OP_N : CUBLAS_OP_T;
    const auto op_b = _info.b_matrix.row_stride == 1 ? CUBLAS_OP_N : CUBLAS_OP_T;

    const auto cuda_stream = (cudaStream_t)stream;

    // Borrow a cuBLAS handle from the shared pool, bind it to the caller's
    // stream, and issue one strided-batched GEMM covering all batches.
    // NOTE(review): the cublasStatus_t returned by cublasGemmStridedBatchedEx
    // is discarded here (as in the original); a failed GEMM is not reported
    // to the caller — consider propagating it.
    use_cublas(_opaque->cublas_handle_pool,
               cuda_stream,
               [&](cublasHandle_t handle) {
                   cublasGemmStridedBatchedEx(
                       handle,
                       op_a,
                       op_b,
                       static_cast<int>(_info.m),
                       static_cast<int>(_info.n),
                       static_cast<int>(_info.k),
                       &alpha,
                       a,
                       a_type,
                       static_cast<int>(_info.a_matrix.ld()),
                       _info.a_matrix.stride,
                       b,
                       b_type,
                       static_cast<int>(_info.b_matrix.ld()),
                       _info.b_matrix.stride,
                       &beta,
                       c,
                       c_type,
                       static_cast<int>(_info.c_matrix.ld()),
                       _info.c_matrix.stride,
                       static_cast<int>(_info.batch),
                       compute_type,
                       CUBLAS_GEMM_DEFAULT_TENSOR_OP);
               });
    return INFINI_STATUS_SUCCESS;
}
} // namespace matmul::cuda } // namespace matmul::cuda
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment