Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
724d9692
Commit
724d9692
authored
Aug 19, 2025
by
wooway777
Browse files
issue/240 - removed some redundant operations
parent
64b5a2bc
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
35 additions
and
105 deletions
+35
-105
src/infiniop/reduce/bang/reduce_bang.h
src/infiniop/reduce/bang/reduce_bang.h
+35
-105
No files found.
src/infiniop/reduce/bang/reduce_bang.h
View file @
724d9692
...
...
@@ -86,36 +86,21 @@ __mlu_func__ float sumBatched(const T *source, T *src, float *dst, int num_eleme
// Copy data to NRAM
__memcpy
(
src
+
offset
,
source
+
processed
,
curr_batch
*
sizeof
(
T
),
GDRAM2NRAM
);
if
constexpr
(
std
::
is_same_v
<
T
,
float
>
)
{
// Process aligned portion
if
(
aligned_batch
>
0
)
{
sumInternal
(
dst
,
(
float
*
)(
src
+
offset
),
aligned_batch
);
res
+=
dst
[
0
];
}
// Process unaligned tail
if
(
remainder
>
0
)
{
for
(
size_t
i
=
aligned_batch
;
i
<
curr_batch
;
++
i
)
{
res
+=
((
float
*
)(
src
+
offset
))[
i
];
}
}
}
else
{
// half/bfloat16 processing path
if
constexpr
(
std
::
is_same_v
<
T
,
half
>
)
{
__bang_half2float
((
float
*
)(
src
+
offset
),
src
+
offset
,
curr_batch
);
}
else
if
constexpr
(
std
::
is_same_v
<
T
,
bfloat16_t
>
)
{
__bang_bfloat162float
((
float
*
)(
src
+
offset
),
src
+
offset
,
curr_batch
);
}
if
constexpr
(
std
::
is_same_v
<
T
,
half
>
)
{
__bang_half2float
((
float
*
)(
src
+
offset
),
src
+
offset
,
curr_batch
);
}
else
if
constexpr
(
std
::
is_same_v
<
T
,
bfloat16_t
>
)
{
__bang_bfloat162float
((
float
*
)(
src
+
offset
),
src
+
offset
,
curr_batch
);
}
// Process aligned portion
if
(
aligned_batch
>
0
)
{
sumInternal
(
dst
,
(
float
*
)(
src
+
offset
),
aligned_batch
);
res
+=
dst
[
0
];
}
// Process unaligned tail
if
(
remainder
>
0
)
{
for
(
size_t
i
=
aligned_batch
;
i
<
curr_batch
;
++
i
)
{
res
+=
((
float
*
)(
src
+
offset
))[
i
];
}
// Process aligned portion
if
(
aligned_batch
>
0
)
{
sumInternal
(
dst
,
(
float
*
)(
src
+
offset
),
aligned_batch
);
res
+=
dst
[
0
];
}
// Process unaligned tail
if
(
remainder
>
0
)
{
for
(
size_t
i
=
aligned_batch
;
i
<
curr_batch
;
++
i
)
{
res
+=
((
float
*
)(
src
+
offset
))[
i
];
}
}
...
...
@@ -140,37 +125,21 @@ __mlu_func__ float sumSquared(const T *source, T *src, float *dst, int num_eleme
__memcpy
(
src
+
offset
,
source
+
processed
,
curr_batch
*
sizeof
(
T
),
GDRAM2NRAM
);
// Find max absolute value
float
max_val
=
0.0
f
;
for
(
size_t
i
=
0
;
i
<
curr_batch
;
++
i
)
{
float
val
=
0.0
f
;
if
constexpr
(
std
::
is_same_v
<
T
,
half
>
)
{
val
=
fabs
(
__half2float
(
src
[
offset
+
i
]));
}
else
if
constexpr
(
std
::
is_same_v
<
T
,
bfloat16_t
>
)
{
val
=
fabs
(
__bfloat162float
(
src
[
offset
+
i
]));
}
else
{
val
=
fabs
(
src
[
offset
+
i
]);
}
max_val
=
std
::
max
(
val
,
max_val
);
}
float
scale
=
(
max_val
>
1e3
f
)
?
1e3
f
/
max_val
:
1.0
f
;
// Prevent overflow
float
sum
=
0.0
f
;
// Scaled computation
for
(
size_t
i
=
0
;
i
<
curr_batch
;
++
i
)
{
float
val
=
0.0
f
;
if
constexpr
(
std
::
is_same_v
<
T
,
half
>
)
{
val
=
__half2float
(
src
[
offset
+
i
])
*
scale
;
val
=
__half2float
(
src
[
offset
+
i
]);
}
else
if
constexpr
(
std
::
is_same_v
<
T
,
bfloat16_t
>
)
{
val
=
__bfloat162float
(
src
[
offset
+
i
])
*
scale
;
val
=
__bfloat162float
(
src
[
offset
+
i
]);
}
else
{
val
=
src
[
offset
+
i
]
*
scale
;
val
=
src
[
offset
+
i
];
}
sum
+=
val
*
val
;
}
res
+=
sum
/
(
scale
*
scale
)
;
res
+=
sum
;
processed
+=
curr_batch
;
}
...
...
@@ -201,64 +170,25 @@ __mlu_func__ float sumSquaredBatched(const T *source, T *src, float *dst, int nu
// Copy data to NRAM
__memcpy
(
src
+
offset
,
source
+
processed
,
curr_batch
*
sizeof
(
T
),
GDRAM2NRAM
);
if
constexpr
(
std
::
is_same_v
<
T
,
float
>
)
{
// float32 processing path
if
(
aligned_batch
>
0
)
{
__bang_mul
((
float
*
)(
src
+
offset
),
(
float
*
)(
src
+
offset
),
(
float
*
)(
src
+
offset
),
aligned_batch
);
sumInternal
(
dst
,
(
float
*
)(
src
+
offset
),
aligned_batch
);
res
+=
dst
[
0
];
}
// Process unaligned tail
if
(
remainder
>
0
)
{
for
(
size_t
i
=
aligned_batch
;
i
<
curr_batch
;
++
i
)
{
float
val
=
((
float
*
)(
src
+
offset
))[
i
];
res
+=
val
*
val
;
}
}
}
else
{
// half/bfloat16 processing path
if
constexpr
(
std
::
is_same_v
<
T
,
half
>
)
{
__bang_half2float
((
float
*
)(
src
+
offset
),
src
+
offset
,
curr_batch
);
}
else
if
constexpr
(
std
::
is_same_v
<
T
,
bfloat16_t
>
)
{
__bang_bfloat162float
((
float
*
)(
src
+
offset
),
src
+
offset
,
curr_batch
);
}
// Find maximum absolute value
float
max_val
=
0.0
f
;
if
(
aligned_batch
>
0
)
{
__bang_abs
((
float
*
)(
src
+
offset
),
(
float
*
)(
src
+
offset
),
aligned_batch
);
sumInternal
(
dst
,
(
float
*
)(
src
+
offset
),
aligned_batch
);
max_val
=
dst
[
0
]
/
(
aligned_batch
/
batch_size
);
}
// Check for max value in tail elements
if
(
remainder
>
0
)
{
for
(
size_t
i
=
aligned_batch
;
i
<
curr_batch
;
++
i
)
{
float
val
=
fabs
(((
float
*
)(
src
+
offset
))[
i
]);
max_val
=
std
::
max
(
max_val
,
val
);
}
}
// Scale and compute squared sum
float
scale
=
(
max_val
>
1e3
f
)
?
1e3
f
/
max_val
:
1.0
f
;
if
constexpr
(
std
::
is_same_v
<
T
,
half
>
)
{
__bang_half2float
((
float
*
)(
src
+
offset
),
src
+
offset
,
curr_batch
);
}
else
if
constexpr
(
std
::
is_same_v
<
T
,
bfloat16_t
>
)
{
__bang_bfloat162float
((
float
*
)(
src
+
offset
),
src
+
offset
,
curr_batch
);
}
// Process aligned portion
if
(
aligned_batch
>
0
)
{
__bang_mul_scalar
((
float
*
)(
src
+
offset
),
(
float
*
)(
src
+
offset
),
scale
,
aligned_batch
);
__bang_mul
((
float
*
)(
src
+
offset
),
(
float
*
)(
src
+
offset
),
(
float
*
)(
src
+
offset
),
aligned_batch
);
sumInternal
(
dst
,
(
float
*
)(
src
+
offset
),
aligned_batch
);
res
+=
dst
[
0
]
/
(
scale
*
scale
);
}
// Process aligned portion
if
(
aligned_batch
>
0
)
{
__bang_mul
((
float
*
)(
src
+
offset
),
(
float
*
)(
src
+
offset
),
(
float
*
)(
src
+
offset
),
aligned_batch
);
sumInternal
(
dst
,
(
float
*
)(
src
+
offset
),
aligned_batch
);
res
+=
dst
[
0
];
}
// Process unaligned tail
if
(
remainder
>
0
)
{
for
(
size_t
i
=
aligned_batch
;
i
<
curr_batch
;
++
i
)
{
float
val
=
((
float
*
)(
src
+
offset
))[
i
]
*
scale
;
res
+=
val
*
val
/
(
scale
*
scale
);
}
// Process unaligned tail
if
(
remainder
>
0
)
{
for
(
size_t
i
=
aligned_batch
;
i
<
curr_batch
;
++
i
)
{
float
val
=
((
float
*
)(
src
+
offset
))[
i
];
res
+=
val
*
val
;
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment