Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
5cd13ff8
Commit
5cd13ff8
authored
Mar 05, 2025
by
YdrMaster
Browse files
issue/63/fix: 移除 cuda mat mul 中无意义的模板
Signed-off-by:
YdrMaster
<
ydrml@hotmail.com
>
parent
75b89b17
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
33 additions
and
49 deletions
+33
-49
src/infiniop/ops/matmul/cuda/matmul_cuda.cu
src/infiniop/ops/matmul/cuda/matmul_cuda.cu
+33
-49
No files found.
src/infiniop/ops/matmul/cuda/matmul_cuda.cu
View file @
5cd13ff8
...
...
@@ -38,90 +38,74 @@ infiniStatus_t Descriptor::create(
return
INFINI_STATUS_SUCCESS
;
}
template
<
typename
Tdata
>
void
calculate
(
const
MatmulInfo
&
info
,
std
::
shared_ptr
<
Pool
<
cublasHandle_t
>>
&
cublas_handle_pool
,
infiniStatus_t
Descriptor
::
calculate
(
void
*
workspace
,
size_t
workspace_size
,
void
*
c
,
float
beta
,
const
void
*
a
,
const
void
*
b
,
float
alpha
,
cudaStream_t
stream
)
{
if
(
info
.
is_transed
)
{
std
::
swap
(
a
,
b
);
}
void
*
stream
)
const
{
cudaDataType
a_type
,
b_type
,
c_type
;
cublasComputeType_t
compute_type
;
if
constexpr
(
std
::
is_same
<
Tdata
,
half
>::
value
)
{
switch
(
_dtype
)
{
case
INFINI_DTYPE_F16
:
a_type
=
b_type
=
c_type
=
CUDA_R_16F
;
compute_type
=
CUBLAS_COMPUTE_32F
;
}
else
{
break
;
case
INFINI_DTYPE_F32
:
a_type
=
b_type
=
c_type
=
CUDA_R_32F
;
#ifdef ENABLE_SUGON_CUDA_API
compute_type
=
CUBLAS_COMPUTE_32F
;
#else
compute_type
=
CUBLAS_COMPUTE_32F_FAST_TF32
;
#endif
break
;
default:
return
INFINI_STATUS_BAD_TENSOR_DTYPE
;
}
auto
op_a
=
info
.
a_matrix
.
row_stride
==
1
?
CUBLAS_OP_N
:
CUBLAS_OP_T
;
auto
op_b
=
info
.
b_matrix
.
row_stride
==
1
?
CUBLAS_OP_N
:
CUBLAS_OP_T
;
if
(
_info
.
is_transed
)
{
std
::
swap
(
a
,
b
);
}
auto
op_a
=
_info
.
a_matrix
.
row_stride
==
1
?
CUBLAS_OP_N
:
CUBLAS_OP_T
;
auto
op_b
=
_info
.
b_matrix
.
row_stride
==
1
?
CUBLAS_OP_N
:
CUBLAS_OP_T
;
use_cublas
(
cublas_handle_pool
,
stream
,
use_cublas
(
_opaque
->
cublas_handle_pool
,
(
cudaStream_t
)
stream
,
[
&
](
cublasHandle_t
handle
)
{
cublasGemmStridedBatchedEx
(
handle
,
op_a
,
op_b
,
static_cast
<
int
>
(
info
.
m
),
static_cast
<
int
>
(
info
.
n
),
static_cast
<
int
>
(
info
.
k
),
static_cast
<
int
>
(
_
info
.
m
),
static_cast
<
int
>
(
_
info
.
n
),
static_cast
<
int
>
(
_
info
.
k
),
&
alpha
,
a
,
a_type
,
static_cast
<
int
>
(
info
.
a_matrix
.
ld
()),
info
.
a_matrix
.
stride
,
static_cast
<
int
>
(
_
info
.
a_matrix
.
ld
()),
_
info
.
a_matrix
.
stride
,
b
,
b_type
,
static_cast
<
int
>
(
info
.
b_matrix
.
ld
()),
info
.
b_matrix
.
stride
,
static_cast
<
int
>
(
_
info
.
b_matrix
.
ld
()),
_
info
.
b_matrix
.
stride
,
&
beta
,
c
,
c_type
,
static_cast
<
int
>
(
info
.
c_matrix
.
ld
()),
info
.
c_matrix
.
stride
,
static_cast
<
int
>
(
info
.
batch
),
static_cast
<
int
>
(
_
info
.
c_matrix
.
ld
()),
_
info
.
c_matrix
.
stride
,
static_cast
<
int
>
(
_
info
.
batch
),
compute_type
,
CUBLAS_GEMM_DEFAULT_TENSOR_OP
);
});
}
infiniStatus_t
Descriptor
::
calculate
(
void
*
workspace
,
size_t
workspace_size
,
void
*
c
,
float
beta
,
const
void
*
a
,
const
void
*
b
,
float
alpha
,
void
*
stream
)
const
{
switch
(
_dtype
)
{
case
INFINI_DTYPE_F16
:
cuda
::
calculate
<
uint16_t
>
(
_info
,
_opaque
->
cublas_handle_pool
,
c
,
beta
,
a
,
b
,
alpha
,
(
cudaStream_t
)
stream
);
return
INFINI_STATUS_SUCCESS
;
case
INFINI_DTYPE_F32
:
cuda
::
calculate
<
float
>
(
_info
,
_opaque
->
cublas_handle_pool
,
c
,
beta
,
a
,
b
,
alpha
,
(
cudaStream_t
)
stream
);
return
INFINI_STATUS_SUCCESS
;
default:
return
INFINI_STATUS_BAD_TENSOR_DTYPE
;
}
return
INFINI_STATUS_SUCCESS
;
}
}
// namespace matmul::cuda
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment