Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
a32a8776
Commit
a32a8776
authored
Apr 07, 2025
by
xgqdut2016
Browse files
issue/130: use int
parent
54b47924
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
4 additions
and
4 deletions
+4
-4
src/infiniop/ops/gemm/cpu/gemm_cpu.cc
src/infiniop/ops/gemm/cpu/gemm_cpu.cc
+4
-4
No files found.
src/infiniop/ops/gemm/cpu/gemm_cpu.cc
View file @
a32a8776
...
@@ -43,12 +43,12 @@ void calculate(
...
@@ -43,12 +43,12 @@ void calculate(
std
::
swap
(
a
,
b
);
std
::
swap
(
a
,
b
);
}
}
#pragma omp parallel for collapse(3)
#pragma omp parallel for collapse(3)
for
(
ptrdiff_
t
i
=
0
;
i
<
ptrdiff_t
(
info
.
batch
);
++
i
)
{
for
(
in
t
i
=
0
;
i
<
static_cast
<
int
>
(
info
.
batch
);
++
i
)
{
for
(
ptrdiff_
t
m_
=
0
;
m_
<
ptrdiff_t
(
info
.
m
);
++
m_
)
{
for
(
in
t
m_
=
0
;
m_
<
static_cast
<
int
>
(
info
.
m
);
++
m_
)
{
for
(
ptrdiff_
t
n_
=
0
;
n_
<
ptrdiff_t
(
info
.
n
);
++
n_
)
{
for
(
in
t
n_
=
0
;
n_
<
static_cast
<
int
>
(
info
.
n
);
++
n_
)
{
auto
c_
=
reinterpret_cast
<
Tdata
*>
(
c
)
+
i
*
info
.
c_matrix
.
stride
+
m_
*
info
.
c_matrix
.
row_stride
+
n_
*
info
.
c_matrix
.
col_stride
;
auto
c_
=
reinterpret_cast
<
Tdata
*>
(
c
)
+
i
*
info
.
c_matrix
.
stride
+
m_
*
info
.
c_matrix
.
row_stride
+
n_
*
info
.
c_matrix
.
col_stride
;
float
sum
=
0
;
float
sum
=
0
;
for
(
size_
t
k_
=
0
;
k_
<
info
.
k
;
++
k_
)
{
for
(
in
t
k_
=
0
;
k_
<
static_cast
<
int
>
(
info
.
k
)
;
++
k_
)
{
auto
a_
=
reinterpret_cast
<
const
Tdata
*>
(
a
)
+
i
*
info
.
a_matrix
.
stride
+
m_
*
info
.
a_matrix
.
row_stride
+
k_
*
info
.
a_matrix
.
col_stride
;
auto
a_
=
reinterpret_cast
<
const
Tdata
*>
(
a
)
+
i
*
info
.
a_matrix
.
stride
+
m_
*
info
.
a_matrix
.
row_stride
+
k_
*
info
.
a_matrix
.
col_stride
;
auto
b_
=
reinterpret_cast
<
const
Tdata
*>
(
b
)
+
i
*
info
.
b_matrix
.
stride
+
n_
*
info
.
b_matrix
.
col_stride
+
k_
*
info
.
b_matrix
.
row_stride
;
auto
b_
=
reinterpret_cast
<
const
Tdata
*>
(
b
)
+
i
*
info
.
b_matrix
.
stride
+
n_
*
info
.
b_matrix
.
col_stride
+
k_
*
info
.
b_matrix
.
row_stride
;
if
constexpr
(
std
::
is_same
<
Tdata
,
fp16_t
>::
value
)
{
if
constexpr
(
std
::
is_same
<
Tdata
,
fp16_t
>::
value
)
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment