Unverified Commit 9874946c authored by PanZezhong1725's avatar PanZezhong1725 Committed by GitHub
Browse files

Merge pull request #83 from YdrMaster/main

issue/63/fix: 移除 cuda mat mul 中无意义的模板
parents 75b89b17 5cd13ff8
......@@ -38,90 +38,74 @@ infiniStatus_t Descriptor::create(
return INFINI_STATUS_SUCCESS;
}
// Runs the batched GEMM described by this descriptor via cuBLAS.
//
// Parameters:
//   workspace/workspace_size - unused by the cuBLAS path (kept for the common
//                              operator interface).
//   c          - output matrix (device pointer), updated as c = alpha*a*b + beta*c.
//   beta/alpha - GEMM scaling factors (always float; accumulation is in FP32
//                even for FP16 storage, per the compute types chosen below).
//   a, b       - input matrices (device pointers).
//   stream     - opaque stream handle, actually a cudaStream_t.
//
// Returns INFINI_STATUS_SUCCESS, or INFINI_STATUS_BAD_TENSOR_DTYPE when the
// descriptor dtype is neither F16 nor F32.
infiniStatus_t Descriptor::calculate(
    void *workspace,
    size_t workspace_size,
    void *c,
    float beta,
    const void *a,
    const void *b,
    float alpha,
    void *stream) const {
    // Map the descriptor dtype onto cuBLAS storage and compute types.
    cudaDataType a_type, b_type, c_type;
    cublasComputeType_t compute_type;
    switch (_dtype) {
    case INFINI_DTYPE_F16:
        a_type = b_type = c_type = CUDA_R_16F;
        // FP16 storage with FP32 accumulation.
        compute_type = CUBLAS_COMPUTE_32F;
        break;
    case INFINI_DTYPE_F32:
        a_type = b_type = c_type = CUDA_R_32F;
#ifdef ENABLE_SUGON_CUDA_API
        // Sugon's cuBLAS port lacks the TF32 fast path; use plain FP32.
        compute_type = CUBLAS_COMPUTE_32F;
#else
        // Allow TF32 tensor cores for FP32 inputs (slightly reduced precision).
        compute_type = CUBLAS_COMPUTE_32F_FAST_TF32;
#endif
        break;
    default:
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }

    // The descriptor may have pre-swapped the operand roles (e.g. to express a
    // row-major problem through column-major cuBLAS); honor that here.
    if (_info.is_transed) {
        std::swap(a, b);
    }

    // row_stride == 1 means the matrix is laid out contiguously along rows as
    // cuBLAS expects for OP_N; otherwise ask cuBLAS to transpose on the fly.
    auto op_a = _info.a_matrix.row_stride == 1 ? CUBLAS_OP_N : CUBLAS_OP_T;
    auto op_b = _info.b_matrix.row_stride == 1 ? CUBLAS_OP_N : CUBLAS_OP_T;

    use_cublas(_opaque->cublas_handle_pool,
               (cudaStream_t)stream,
               [&](cublasHandle_t handle) {
                   // NOTE(review): the cublasStatus_t returned here is ignored;
                   // consider propagating it through use_cublas if the pool
                   // supports that.
                   cublasGemmStridedBatchedEx(
                       handle,
                       op_a,
                       op_b,
                       static_cast<int>(_info.m),
                       static_cast<int>(_info.n),
                       static_cast<int>(_info.k),
                       &alpha,
                       a,
                       a_type,
                       static_cast<int>(_info.a_matrix.ld()),
                       _info.a_matrix.stride,
                       b,
                       b_type,
                       static_cast<int>(_info.b_matrix.ld()),
                       _info.b_matrix.stride,
                       &beta,
                       c,
                       c_type,
                       static_cast<int>(_info.c_matrix.ld()),
                       _info.c_matrix.stride,
                       static_cast<int>(_info.batch),
                       compute_type,
                       CUBLAS_GEMM_DEFAULT_TENSOR_OP);
               });
    return INFINI_STATUS_SUCCESS;
}
} // namespace matmul::cuda
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment