Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
bc981f39
Commit
bc981f39
authored
Apr 07, 2025
by
xgqdut2016
Browse files
issue/130: delete collapse
parent
a32a8776
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
27 additions
and
24 deletions
+27
-24
src/infiniop/ops/gemm/cpu/gemm_cpu.cc
src/infiniop/ops/gemm/cpu/gemm_cpu.cc
+27
-24
No files found.
src/infiniop/ops/gemm/cpu/gemm_cpu.cc
View file @
bc981f39
...
...
@@ -42,32 +42,35 @@ void calculate(
if
(
info
.
is_transed
)
{
std
::
swap
(
a
,
b
);
}
#pragma omp parallel for collapse(3)
for
(
int
i
=
0
;
i
<
static_cast
<
int
>
(
info
.
batch
);
++
i
)
{
for
(
int
m_
=
0
;
m_
<
static_cast
<
int
>
(
info
.
m
);
++
m_
)
{
for
(
int
n_
=
0
;
n_
<
static_cast
<
int
>
(
info
.
n
);
++
n_
)
{
auto
c_
=
reinterpret_cast
<
Tdata
*>
(
c
)
+
i
*
info
.
c_matrix
.
stride
+
m_
*
info
.
c_matrix
.
row_stride
+
n_
*
info
.
c_matrix
.
col_stride
;
float
sum
=
0
;
for
(
int
k_
=
0
;
k_
<
static_cast
<
int
>
(
info
.
k
);
++
k_
)
{
auto
a_
=
reinterpret_cast
<
const
Tdata
*>
(
a
)
+
i
*
info
.
a_matrix
.
stride
+
m_
*
info
.
a_matrix
.
row_stride
+
k_
*
info
.
a_matrix
.
col_stride
;
auto
b_
=
reinterpret_cast
<
const
Tdata
*>
(
b
)
+
i
*
info
.
b_matrix
.
stride
+
n_
*
info
.
b_matrix
.
col_stride
+
k_
*
info
.
b_matrix
.
row_stride
;
if
constexpr
(
std
::
is_same
<
Tdata
,
fp16_t
>::
value
)
{
sum
+=
utils
::
cast
<
float
>
(
*
a_
)
*
utils
::
cast
<
float
>
(
*
b_
);
}
else
{
sum
+=
*
a_
*
(
*
b_
);
}
}
if
constexpr
(
std
::
is_same
<
Tdata
,
fp16_t
>::
value
)
{
if
(
beta
==
0
)
{
*
c_
=
utils
::
cast
<
fp16_t
>
(
alpha
*
sum
);
}
else
{
*
c_
=
utils
::
cast
<
fp16_t
>
(
beta
*
utils
::
cast
<
float
>
(
*
c_
)
+
alpha
*
sum
);
}
}
else
{
*
c_
=
beta
*
(
*
c_
)
+
alpha
*
sum
;
}
const
size_t
m_n
=
info
.
m
*
info
.
n
;
const
size_t
n
=
info
.
n
;
#pragma omp parallel for
for
(
ptrdiff_t
index
=
0
;
index
<
ptrdiff_t
(
info
.
batch
*
info
.
m
*
info
.
n
);
++
index
)
{
size_t
i
,
m_
,
n_
;
i
=
index
/
m_n
;
size_t
rem
=
index
-
i
*
m_n
;
// use subtraction instead of `%` to get the remainder within the batch
m_
=
rem
/
n
;
n_
=
rem
-
m_
*
n
;
// use subtraction instead of `%` to get the column index within the row
auto
c_
=
reinterpret_cast
<
Tdata
*>
(
c
)
+
i
*
info
.
c_matrix
.
stride
+
m_
*
info
.
c_matrix
.
row_stride
+
n_
*
info
.
c_matrix
.
col_stride
;
float
sum
=
0
;
for
(
int
k_
=
0
;
k_
<
static_cast
<
int
>
(
info
.
k
);
++
k_
)
{
auto
a_
=
reinterpret_cast
<
const
Tdata
*>
(
a
)
+
i
*
info
.
a_matrix
.
stride
+
m_
*
info
.
a_matrix
.
row_stride
+
k_
*
info
.
a_matrix
.
col_stride
;
auto
b_
=
reinterpret_cast
<
const
Tdata
*>
(
b
)
+
i
*
info
.
b_matrix
.
stride
+
n_
*
info
.
b_matrix
.
col_stride
+
k_
*
info
.
b_matrix
.
row_stride
;
if
constexpr
(
std
::
is_same
<
Tdata
,
fp16_t
>::
value
)
{
sum
+=
utils
::
cast
<
float
>
(
*
a_
)
*
utils
::
cast
<
float
>
(
*
b_
);
}
else
{
sum
+=
*
a_
*
(
*
b_
);
}
}
if
constexpr
(
std
::
is_same
<
Tdata
,
fp16_t
>::
value
)
{
if
(
beta
==
0
)
{
*
c_
=
utils
::
cast
<
fp16_t
>
(
alpha
*
sum
);
}
else
{
*
c_
=
utils
::
cast
<
fp16_t
>
(
beta
*
utils
::
cast
<
float
>
(
*
c_
)
+
alpha
*
sum
);
}
}
else
{
*
c_
=
beta
*
(
*
c_
)
+
alpha
*
sum
;
}
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment