Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4de87866
Unverified
Commit
4de87866
authored
Nov 24, 2025
by
R3hankhan
Committed by
GitHub
Nov 24, 2025
Browse files
[CPU][IBM Z] Fix BF16 support and vectorize math operations for s390x (#28926)
Signed-off-by:
Rehan Khan
<
Rehan.Khan7@ibm.com
>
parent
eca7a8fb
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
531 additions
and
57 deletions
+531
-57
csrc/cpu/cpu_attn_impl.hpp
csrc/cpu/cpu_attn_impl.hpp
+1
-1
csrc/cpu/cpu_types_vxe.hpp
csrc/cpu/cpu_types_vxe.hpp
+530
-56
No files found.
csrc/cpu/cpu_attn_impl.hpp
View file @
4de87866
...
...
@@ -847,7 +847,7 @@ struct VecTypeTrait<c10::BFloat16> {
};
#endif
#if !defined(__powerpc__)
#if !defined(__powerpc__)
&& !defined(__s390x__)
template
<
>
struct
VecTypeTrait
<
c10
::
Half
>
{
using
vec_t
=
vec_op
::
FP16Vec16
;
...
...
csrc/cpu/cpu_types_vxe.hpp
View file @
4de87866
...
...
@@ -4,6 +4,7 @@
#include <vecintrin.h>
#include <cmath>
#include <limits>
#include <torch/all.h>
namespace
vec_op
{
...
...
@@ -174,8 +175,9 @@ struct FP32Vec8 : public Vec<FP32Vec8> {
}
explicit
FP32Vec8
(
const
BF16Vec8
&
v
)
{
reg
.
val
[
0
]
=
(
__vector
float
)
vec_mergeh
(
zero
,
v
.
reg
);
reg
.
val
[
1
]
=
(
__vector
float
)
vec_mergel
(
zero
,
v
.
reg
);
// On big-endian s390x, place BF16 first to get correct byte order
reg
.
val
[
0
]
=
(
__vector
float
)
vec_mergeh
(
v
.
reg
,
zero
);
reg
.
val
[
1
]
=
(
__vector
float
)
vec_mergel
(
v
.
reg
,
zero
);
}
float
reduce_sum
()
const
{
...
...
@@ -189,51 +191,257 @@ struct FP32Vec8 : public Vec<FP32Vec8> {
}
FP32Vec8
exp
()
const
{
// TODO: Vectorize this
AliasReg
ar
;
ar
.
reg
=
reg
;
f32x4x4_t
ret
;
ret
.
val
[
0
][
0
]
=
std
::
exp
(
ar
.
values
[
0
]);
ret
.
val
[
0
][
1
]
=
std
::
exp
(
ar
.
values
[
1
]);
ret
.
val
[
0
][
2
]
=
std
::
exp
(
ar
.
values
[
2
]);
ret
.
val
[
0
][
3
]
=
std
::
exp
(
ar
.
values
[
3
]);
ret
.
val
[
1
][
0
]
=
std
::
exp
(
ar
.
values
[
4
]);
ret
.
val
[
1
][
1
]
=
std
::
exp
(
ar
.
values
[
5
]);
ret
.
val
[
1
][
2
]
=
std
::
exp
(
ar
.
values
[
6
]);
ret
.
val
[
1
][
3
]
=
std
::
exp
(
ar
.
values
[
7
]);
return
FP32Vec8
(
f32x4x2_t
({
ret
.
val
[
0
],
ret
.
val
[
1
]}));
f32x4x2_t
out
;
const
__vector
float
log2e
=
vec_splats
(
1.44269504088896341
f
);
const
__vector
float
one
=
vec_splats
(
1.0
f
);
const
__vector
float
min_x
=
vec_splats
(
-
87.3
f
);
const
__vector
float
max_x
=
vec_splats
(
88.7
f
);
// 5th-degree minimax polynomial for 2^r (r in [0,1))
const
__vector
float
c1
=
vec_splats
(
0.6931471805599453
f
);
const
__vector
float
c2
=
vec_splats
(
0.240226506959101
f
);
const
__vector
float
c3
=
vec_splats
(
0.05550410866482158
f
);
const
__vector
float
c4
=
vec_splats
(
0.009618129107628477
f
);
const
__vector
float
c5
=
vec_splats
(
0.0013333558146428443
f
);
for
(
int
i
=
0
;
i
<
2
;
i
++
)
{
__vector
float
x
=
reg
.
val
[
i
];
x
=
vec_max
(
x
,
min_x
);
x
=
vec_min
(
x
,
max_x
);
__vector
float
y
=
vec_mul
(
x
,
log2e
);
__vector
float
kf
=
vec_floor
(
y
);
__vector
float
r
=
vec_sub
(
y
,
kf
);
__vector
signed
int
k
=
vec_signed
(
kf
);
const
__vector
signed
int
min_k
=
vec_splats
((
signed
int
)
-
126
);
const
__vector
signed
int
max_k
=
vec_splats
((
signed
int
)
127
);
k
=
vec_min
(
vec_max
(
k
,
min_k
),
max_k
);
// Build 2^k from exponent bits
__vector
signed
int
exp_int
=
vec_add
(
k
,
vec_splats
((
signed
int
)
127
));
__vector
unsigned
int
bits
=
(
__vector
unsigned
int
)
exp_int
;
bits
=
vec_sl
(
bits
,
vec_splats
((
unsigned
int
)
23
));
__vector
float
pow2k
=
(
__vector
float
)
bits
;
// Improved minimax polynomial
__vector
float
poly
=
vec_madd
(
c5
,
r
,
c4
);
poly
=
vec_madd
(
poly
,
r
,
c3
);
poly
=
vec_madd
(
poly
,
r
,
c2
);
poly
=
vec_madd
(
poly
,
r
,
c1
);
poly
=
vec_madd
(
poly
,
r
,
one
);
out
.
val
[
i
]
=
vec_mul
(
pow2k
,
poly
);
}
return
FP32Vec8
(
out
);
}
FP32Vec8
tanh
()
const
{
// TODO: Vectorize this
AliasReg
ar
;
ar
.
reg
=
reg
;
f32x4x4_t
ret
;
ret
.
val
[
0
][
0
]
=
std
::
tanh
(
ar
.
values
[
0
]);
ret
.
val
[
0
][
1
]
=
std
::
tanh
(
ar
.
values
[
1
]);
ret
.
val
[
0
][
2
]
=
std
::
tanh
(
ar
.
values
[
2
]);
ret
.
val
[
0
][
3
]
=
std
::
tanh
(
ar
.
values
[
3
]);
ret
.
val
[
1
][
0
]
=
std
::
tanh
(
ar
.
values
[
4
]);
ret
.
val
[
1
][
1
]
=
std
::
tanh
(
ar
.
values
[
5
]);
ret
.
val
[
1
][
2
]
=
std
::
tanh
(
ar
.
values
[
6
]);
ret
.
val
[
1
][
3
]
=
std
::
tanh
(
ar
.
values
[
7
]);
return
FP32Vec8
(
f32x4x2_t
({
ret
.
val
[
0
],
ret
.
val
[
1
]}));
// tanh(x) = (exp(2x) - 1) / (exp(2x) + 1)
const
__vector
float
one
=
vec_splats
(
1.0
f
);
const
__vector
float
two
=
vec_splats
(
2.0
f
);
const
__vector
float
zero
=
vec_splats
(
0.0
f
);
const
__vector
float
sat
=
vec_splats
(
9.0
f
);
// beyond this, tanh(x) ~ sign(x)
f32x4x2_t
out
;
for
(
int
i
=
0
;
i
<
2
;
i
++
)
{
__vector
float
x
=
reg
.
val
[
i
];
__vector
float
ax
=
vec_abs
(
x
);
// sign(x): +1 or -1
__vector
float
sign
=
vec_sel
(
vec_splats
(
-
1.0
f
),
one
,
vec_cmpgt
(
x
,
zero
));
// saturation mask: |x| > sat
__vector
__bool
int
saturated
=
vec_cmpgt
(
ax
,
sat
);
// 2x
__vector
float
two_x
=
vec_mul
(
x
,
two
);
// Build a temporary FP32Vec8 with both lanes = 2x, reuse exp()
f32x4x2_t
tmp
;
tmp
.
val
[
0
]
=
two_x
;
tmp
.
val
[
1
]
=
two_x
;
FP32Vec8
exp_2x_vec
(
tmp
);
FP32Vec8
e2x
=
exp_2x_vec
.
exp
();
__vector
float
e
=
e2x
.
reg
.
val
[
i
];
// tanh(x) = (e - 1) / (e + 1)
__vector
float
num
=
vec_sub
(
e
,
one
);
__vector
float
den
=
vec_add
(
e
,
one
);
__vector
float
t
=
vec_div
(
num
,
den
);
// For large |x|, clamp to sign(x)
out
.
val
[
i
]
=
vec_sel
(
t
,
sign
,
saturated
);
}
return
FP32Vec8
(
out
);
}
FP32Vec8
er
()
const
{
// TODO: Vectorize this
AliasReg
ar
;
ar
.
reg
=
reg
;
f32x4x4_t
ret
;
ret
.
val
[
0
][
0
]
=
std
::
erf
(
ar
.
values
[
0
]);
ret
.
val
[
0
][
1
]
=
std
::
erf
(
ar
.
values
[
1
]);
ret
.
val
[
0
][
2
]
=
std
::
erf
(
ar
.
values
[
2
]);
ret
.
val
[
0
][
3
]
=
std
::
erf
(
ar
.
values
[
3
]);
ret
.
val
[
1
][
0
]
=
std
::
erf
(
ar
.
values
[
4
]);
ret
.
val
[
1
][
1
]
=
std
::
erf
(
ar
.
values
[
5
]);
ret
.
val
[
1
][
2
]
=
std
::
erf
(
ar
.
values
[
6
]);
ret
.
val
[
1
][
3
]
=
std
::
erf
(
ar
.
values
[
7
]);
return
FP32Vec8
(
f32x4x2_t
({
ret
.
val
[
0
],
ret
.
val
[
1
]}));
// A&S 7.1.26 approximation:
// erf(x) = sign(x) * (1 - ((((a5*t + a4)*t + a3)*t + a2)*t + a1) * t *
// exp(-x^2)) t = 1 / (1 + p*|x|), p = 0.3275911
const
__vector
float
one
=
vec_splats
(
1.0
f
);
const
__vector
float
zero
=
vec_splats
(
0.0
f
);
const
__vector
float
p
=
vec_splats
(
0.3275911
f
);
// Polynomial coeffs
const
__vector
float
a1
=
vec_splats
(
0.254829592
f
);
const
__vector
float
a2
=
vec_splats
(
-
0.284496736
f
);
const
__vector
float
a3
=
vec_splats
(
1.421413741
f
);
const
__vector
float
a4
=
vec_splats
(
-
1.453152027
f
);
const
__vector
float
a5
=
vec_splats
(
1.061405429
f
);
// Threshold where erf(x) ~ sign(x)
const
__vector
float
sat
=
vec_splats
(
6.0
f
);
f32x4x2_t
out
;
for
(
int
lane
=
0
;
lane
<
2
;
lane
++
)
{
__vector
float
x
=
reg
.
val
[
lane
];
__vector
float
ax
=
vec_abs
(
x
);
// sign(x)
__vector
float
sign
=
vec_sel
(
vec_splats
(
-
1.0
f
),
one
,
vec_cmpgt
(
x
,
zero
));
// |x| > 6 → erf(x) = ±1
__vector
__bool
int
saturated
=
vec_cmpgt
(
ax
,
sat
);
// t = 1 / (1 + p * |x|)
__vector
float
t
=
vec_madd
(
p
,
ax
,
one
);
t
=
vec_div
(
one
,
t
);
// poly = a5
__vector
float
poly
=
a5
;
poly
=
vec_madd
(
poly
,
t
,
a4
);
poly
=
vec_madd
(
poly
,
t
,
a3
);
poly
=
vec_madd
(
poly
,
t
,
a2
);
poly
=
vec_madd
(
poly
,
t
,
a1
);
// full polynomial: poly = poly * t
poly
=
vec_mul
(
poly
,
t
);
// Compute exp(-x^2)
__vector
float
x2
=
vec_mul
(
x
,
x
);
__vector
float
neg_x2
=
vec_neg
(
x2
);
f32x4x2_t
tmp
;
tmp
.
val
[
0
]
=
neg_x2
;
tmp
.
val
[
1
]
=
neg_x2
;
FP32Vec8
exp_neg_x2
(
tmp
);
FP32Vec8
e
=
exp_neg_x2
.
exp
();
__vector
float
ex
=
e
.
reg
.
val
[
lane
];
// erf(x) = sign * (1 - poly * exp(-x^2))
__vector
float
term
=
vec_mul
(
poly
,
ex
);
__vector
float
y
=
vec_sub
(
one
,
term
);
y
=
vec_mul
(
y
,
sign
);
// saturated → ±1
__vector
float
sat_val
=
vec_mul
(
sign
,
one
);
out
.
val
[
lane
]
=
vec_sel
(
y
,
sat_val
,
saturated
);
}
return
FP32Vec8
(
out
);
}
// Elementwise sigmoid(x) = 1 / (1 + exp(-x))
FP32Vec8
sigmoid
()
const
{
const
__vector
float
one
=
vec_splats
(
1.0
f
);
f32x4x2_t
neg
;
for
(
int
i
=
0
;
i
<
2
;
++
i
)
{
neg
.
val
[
i
]
=
vec_neg
(
reg
.
val
[
i
]);
}
FP32Vec8
neg_x
(
neg
);
FP32Vec8
e
=
neg_x
.
exp
();
// exp(-x)
f32x4x2_t
denom
;
for
(
int
i
=
0
;
i
<
2
;
++
i
)
{
denom
.
val
[
i
]
=
vec_add
(
one
,
e
.
reg
.
val
[
i
]);
}
FP32Vec8
denom_vec
(
denom
);
FP32Vec8
one_vec
(
1.0
f
);
return
one_vec
/
denom_vec
;
}
// Tanh-based GELU:
// gelu(x) = 0.5 * x * (1 + tanh(√(2/π) * (x + 0.044715 * x^3)))
FP32Vec8
gelu_tanh
()
const
{
const
__vector
float
k_s2pi
=
vec_splats
(
0.7978845608028654
f
);
// √(2/π)
const
__vector
float
k_0_0447
=
vec_splats
(
0.044715
f
);
f32x4x2_t
x2
,
x3
,
inner
;
for
(
int
i
=
0
;
i
<
2
;
++
i
)
{
__vector
float
x
=
reg
.
val
[
i
];
x2
.
val
[
i
]
=
vec_mul
(
x
,
x
);
// x^2
x3
.
val
[
i
]
=
vec_mul
(
x2
.
val
[
i
],
x
);
// x^3
__vector
float
t
=
vec_madd
(
k_0_0447
,
x3
.
val
[
i
],
x
);
// x + 0.044715*x^3
inner
.
val
[
i
]
=
vec_mul
(
k_s2pi
,
t
);
// √(2/π)*(...)
}
FP32Vec8
inner_vec
(
inner
);
FP32Vec8
t
=
inner_vec
.
tanh
();
// tanh part
FP32Vec8
one_vec
(
1.0
f
);
FP32Vec8
half_vec
(
0.5
f
);
FP32Vec8
x_vec
(
*
this
);
return
x_vec
*
half_vec
*
(
one_vec
+
t
);
}
// Erf-based GELU:
// gelu(x) = 0.5 * x * (1 + erf(x / √2))
FP32Vec8
gelu_erf
()
const
{
const
__vector
float
inv_sqrt2
=
vec_splats
(
0.7071067811865476
f
);
// 1/√2
FP32Vec8
x_vec
(
*
this
);
f32x4x2_t
scaled
;
for
(
int
i
=
0
;
i
<
2
;
++
i
)
{
scaled
.
val
[
i
]
=
vec_mul
(
reg
.
val
[
i
],
inv_sqrt2
);
}
FP32Vec8
x_scaled
(
scaled
);
FP32Vec8
erf_x
=
x_scaled
.
er
();
FP32Vec8
one_vec
(
1.0
f
);
FP32Vec8
half_vec
(
0.5
f
);
return
x_vec
*
half_vec
*
(
one_vec
+
erf_x
);
}
// Elementwise reciprocal: 1/x (scalar per lane, for correctness)
FP32Vec8
rcp
()
const
{
AliasReg
in
,
out
;
in
.
reg
=
reg
;
for
(
int
i
=
0
;
i
<
VEC_ELEM_NUM
;
++
i
)
{
out
.
values
[
i
]
=
1.0
f
/
in
.
values
[
i
];
}
return
FP32Vec8
(
out
.
reg
);
}
// Elementwise rsqrt(x) = 1 / sqrt(x) (scalar per lane, for correctness)
FP32Vec8
rsqrt
()
const
{
AliasReg
in
,
out
;
in
.
reg
=
reg
;
for
(
int
i
=
0
;
i
<
VEC_ELEM_NUM
;
++
i
)
{
out
.
values
[
i
]
=
1.0
f
/
std
::
sqrt
(
in
.
values
[
i
]);
}
return
FP32Vec8
(
out
.
reg
);
}
FP32Vec8
operator
*
(
const
FP32Vec8
&
b
)
const
{
...
...
@@ -316,10 +524,11 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
}
explicit
FP32Vec16
(
const
BF16Vec16
&
v
)
{
reg
.
val
[
0
]
=
(
__vector
float
)
vec_mergeh
(
zero
,
v
.
reg
.
val
[
0
]);
reg
.
val
[
1
]
=
(
__vector
float
)
vec_mergel
(
zero
,
v
.
reg
.
val
[
0
]);
reg
.
val
[
2
]
=
(
__vector
float
)
vec_mergeh
(
zero
,
v
.
reg
.
val
[
1
]);
reg
.
val
[
3
]
=
(
__vector
float
)
vec_mergel
(
zero
,
v
.
reg
.
val
[
1
]);
// On big-endian s390x, place BF16 first to get correct byte order
reg
.
val
[
0
]
=
(
__vector
float
)
vec_mergeh
(
v
.
reg
.
val
[
0
],
zero
);
reg
.
val
[
1
]
=
(
__vector
float
)
vec_mergel
(
v
.
reg
.
val
[
0
],
zero
);
reg
.
val
[
2
]
=
(
__vector
float
)
vec_mergeh
(
v
.
reg
.
val
[
1
],
zero
);
reg
.
val
[
3
]
=
(
__vector
float
)
vec_mergel
(
v
.
reg
.
val
[
1
],
zero
);
}
explicit
FP32Vec16
(
const
BF16Vec8
&
v
)
:
FP32Vec16
(
FP32Vec8
(
v
))
{}
...
...
@@ -376,6 +585,23 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
return
result
;
}
FP32Vec16
max
(
const
FP32Vec16
&
b
)
const
{
return
FP32Vec16
(
f32x4x4_t
({
vec_max
(
reg
.
val
[
0
],
b
.
reg
.
val
[
0
]),
vec_max
(
reg
.
val
[
1
],
b
.
reg
.
val
[
1
]),
vec_max
(
reg
.
val
[
2
],
b
.
reg
.
val
[
2
]),
vec_max
(
reg
.
val
[
3
],
b
.
reg
.
val
[
3
])}));
}
float
reduce_max
()
const
{
AliasReg
ar
;
ar
.
reg
=
reg
;
float
result
=
ar
.
values
[
0
];
unroll_loop
<
int
,
VEC_ELEM_NUM
>
([
&
result
,
&
ar
](
int
i
)
{
if
(
ar
.
values
[
i
]
>
result
)
result
=
ar
.
values
[
i
];
});
return
result
;
}
void
save
(
float
*
ptr
)
const
{
vec_xst
(
reg
.
val
[
0
],
0
,
ptr
);
vec_xst
(
reg
.
val
[
1
],
16
,
ptr
);
...
...
@@ -402,15 +628,14 @@ struct VecType<c10::BFloat16> {
using
vec_type
=
BF16Vec8
;
};
// On s390x, FP16 (Half) is not natively supported, use FP32 vectors instead
using
FP16Vec16
=
FP32Vec16
;
template
<
typename
T
>
void
storeFP32
(
float
v
,
T
*
ptr
)
{
*
ptr
=
v
;
}
inline
void
fma
(
FP32Vec16
&
acc
,
FP32Vec16
&
a
,
FP32Vec16
&
b
)
{
acc
=
acc
+
a
*
b
;
}
namespace
c10
{
struct
BFloat16
{
uint16_t
value
;
// Assume BFloat16 is defined as a struct containing a 16-bit
...
...
@@ -429,6 +654,79 @@ inline void storeFP32<c10::BFloat16>(float v, c10::BFloat16* ptr) {
#define __VEC_CLASS_FP_NAN (1 << 6)
#endif
// Optimized FMA (Fused Multiply-Add) implementations using IBM Z vector
// intrinsics
// FP32Vec4 FMA: acc = acc + (a * b) or equivalently acc = fma(a, b, acc)
FORCE_INLINE
void
fma
(
FP32Vec4
&
acc
,
const
FP32Vec4
&
a
,
const
FP32Vec4
&
b
)
{
acc
.
reg
=
vec_madd
(
a
.
reg
,
b
.
reg
,
acc
.
reg
);
}
// FP32Vec8 FMA: acc = acc + (a * b)
FORCE_INLINE
void
fma
(
FP32Vec8
&
acc
,
const
FP32Vec8
&
a
,
const
FP32Vec8
&
b
)
{
acc
.
reg
.
val
[
0
]
=
vec_madd
(
a
.
reg
.
val
[
0
],
b
.
reg
.
val
[
0
],
acc
.
reg
.
val
[
0
]);
acc
.
reg
.
val
[
1
]
=
vec_madd
(
a
.
reg
.
val
[
1
],
b
.
reg
.
val
[
1
],
acc
.
reg
.
val
[
1
]);
}
// FP32Vec16 FMA: acc = acc + (a * b)
FORCE_INLINE
void
fma
(
FP32Vec16
&
acc
,
const
FP32Vec16
&
a
,
const
FP32Vec16
&
b
)
{
acc
.
reg
.
val
[
0
]
=
vec_madd
(
a
.
reg
.
val
[
0
],
b
.
reg
.
val
[
0
],
acc
.
reg
.
val
[
0
]);
acc
.
reg
.
val
[
1
]
=
vec_madd
(
a
.
reg
.
val
[
1
],
b
.
reg
.
val
[
1
],
acc
.
reg
.
val
[
1
]);
acc
.
reg
.
val
[
2
]
=
vec_madd
(
a
.
reg
.
val
[
2
],
b
.
reg
.
val
[
2
],
acc
.
reg
.
val
[
2
]);
acc
.
reg
.
val
[
3
]
=
vec_madd
(
a
.
reg
.
val
[
3
],
b
.
reg
.
val
[
3
],
acc
.
reg
.
val
[
3
]);
}
// Multiply-Subtract: acc = acc - (a * b)
FORCE_INLINE
void
fms
(
FP32Vec4
&
acc
,
const
FP32Vec4
&
a
,
const
FP32Vec4
&
b
)
{
acc
.
reg
=
vec_msub
(
a
.
reg
,
b
.
reg
,
acc
.
reg
);
}
FORCE_INLINE
void
fms
(
FP32Vec8
&
acc
,
const
FP32Vec8
&
a
,
const
FP32Vec8
&
b
)
{
acc
.
reg
.
val
[
0
]
=
vec_msub
(
a
.
reg
.
val
[
0
],
b
.
reg
.
val
[
0
],
acc
.
reg
.
val
[
0
]);
acc
.
reg
.
val
[
1
]
=
vec_msub
(
a
.
reg
.
val
[
1
],
b
.
reg
.
val
[
1
],
acc
.
reg
.
val
[
1
]);
}
FORCE_INLINE
void
fms
(
FP32Vec16
&
acc
,
const
FP32Vec16
&
a
,
const
FP32Vec16
&
b
)
{
acc
.
reg
.
val
[
0
]
=
vec_msub
(
a
.
reg
.
val
[
0
],
b
.
reg
.
val
[
0
],
acc
.
reg
.
val
[
0
]);
acc
.
reg
.
val
[
1
]
=
vec_msub
(
a
.
reg
.
val
[
1
],
b
.
reg
.
val
[
1
],
acc
.
reg
.
val
[
1
]);
acc
.
reg
.
val
[
2
]
=
vec_msub
(
a
.
reg
.
val
[
2
],
b
.
reg
.
val
[
2
],
acc
.
reg
.
val
[
2
]);
acc
.
reg
.
val
[
3
]
=
vec_msub
(
a
.
reg
.
val
[
3
],
b
.
reg
.
val
[
3
],
acc
.
reg
.
val
[
3
]);
}
// Negative Multiply-Add: acc = -(a * b) + acc
FORCE_INLINE
void
nfma
(
FP32Vec4
&
acc
,
const
FP32Vec4
&
a
,
const
FP32Vec4
&
b
)
{
acc
.
reg
=
vec_nmadd
(
a
.
reg
,
b
.
reg
,
acc
.
reg
);
}
FORCE_INLINE
void
nfma
(
FP32Vec8
&
acc
,
const
FP32Vec8
&
a
,
const
FP32Vec8
&
b
)
{
acc
.
reg
.
val
[
0
]
=
vec_nmadd
(
a
.
reg
.
val
[
0
],
b
.
reg
.
val
[
0
],
acc
.
reg
.
val
[
0
]);
acc
.
reg
.
val
[
1
]
=
vec_nmadd
(
a
.
reg
.
val
[
1
],
b
.
reg
.
val
[
1
],
acc
.
reg
.
val
[
1
]);
}
FORCE_INLINE
void
nfma
(
FP32Vec16
&
acc
,
const
FP32Vec16
&
a
,
const
FP32Vec16
&
b
)
{
acc
.
reg
.
val
[
0
]
=
vec_nmadd
(
a
.
reg
.
val
[
0
],
b
.
reg
.
val
[
0
],
acc
.
reg
.
val
[
0
]);
acc
.
reg
.
val
[
1
]
=
vec_nmadd
(
a
.
reg
.
val
[
1
],
b
.
reg
.
val
[
1
],
acc
.
reg
.
val
[
1
]);
acc
.
reg
.
val
[
2
]
=
vec_nmadd
(
a
.
reg
.
val
[
2
],
b
.
reg
.
val
[
2
],
acc
.
reg
.
val
[
2
]);
acc
.
reg
.
val
[
3
]
=
vec_nmadd
(
a
.
reg
.
val
[
3
],
b
.
reg
.
val
[
3
],
acc
.
reg
.
val
[
3
]);
}
// Negative Multiply-Subtract: acc = -(a * b) - acc
FORCE_INLINE
void
nfms
(
FP32Vec4
&
acc
,
const
FP32Vec4
&
a
,
const
FP32Vec4
&
b
)
{
acc
.
reg
=
vec_nmsub
(
a
.
reg
,
b
.
reg
,
acc
.
reg
);
}
FORCE_INLINE
void
nfms
(
FP32Vec8
&
acc
,
const
FP32Vec8
&
a
,
const
FP32Vec8
&
b
)
{
acc
.
reg
.
val
[
0
]
=
vec_nmsub
(
a
.
reg
.
val
[
0
],
b
.
reg
.
val
[
0
],
acc
.
reg
.
val
[
0
]);
acc
.
reg
.
val
[
1
]
=
vec_nmsub
(
a
.
reg
.
val
[
1
],
b
.
reg
.
val
[
1
],
acc
.
reg
.
val
[
1
]);
}
FORCE_INLINE
void
nfms
(
FP32Vec16
&
acc
,
const
FP32Vec16
&
a
,
const
FP32Vec16
&
b
)
{
acc
.
reg
.
val
[
0
]
=
vec_nmsub
(
a
.
reg
.
val
[
0
],
b
.
reg
.
val
[
0
],
acc
.
reg
.
val
[
0
]);
acc
.
reg
.
val
[
1
]
=
vec_nmsub
(
a
.
reg
.
val
[
1
],
b
.
reg
.
val
[
1
],
acc
.
reg
.
val
[
1
]);
acc
.
reg
.
val
[
2
]
=
vec_nmsub
(
a
.
reg
.
val
[
2
],
b
.
reg
.
val
[
2
],
acc
.
reg
.
val
[
2
]);
acc
.
reg
.
val
[
3
]
=
vec_nmsub
(
a
.
reg
.
val
[
3
],
b
.
reg
.
val
[
3
],
acc
.
reg
.
val
[
3
]);
}
const
static
__vector
unsigned
char
omask
=
{
2
,
3
,
6
,
7
,
10
,
11
,
14
,
15
,
18
,
19
,
22
,
23
,
26
,
27
,
30
,
31
};
const
static
__vector
unsigned
int
bias
=
{
0x00007fff
,
0x00007fff
,
0x00007fff
,
...
...
@@ -441,13 +739,24 @@ const static __vector unsigned int one = {1, 1, 1, 1};
inline
BF16Vec8
::
BF16Vec8
(
const
FP32Vec8
&
v
)
{
__vector
unsigned
int
inp0
=
(
__vector
unsigned
int
)(
v
.
reg
.
val
[
0
]);
__vector
unsigned
int
inp1
=
(
__vector
unsigned
int
)(
v
.
reg
.
val
[
1
]);
__vector
unsigned
int
lsb0
=
inp0
>>
sh16
;
__vector
unsigned
int
lsb1
=
inp1
>>
sh16
;
lsb0
=
lsb0
&
one
;
lsb1
=
lsb1
&
one
;
__vector
unsigned
int
rnd0
=
lsb0
+
bias
;
__vector
unsigned
int
rnd1
=
lsb1
+
bias
;
inp0
=
inp0
+
rnd0
;
inp1
=
inp1
+
rnd1
;
int
cc
;
__vector
__bool
int
sel0
=
vec_fp_test_data_class
(
v
.
reg
.
val
[
0
],
__VEC_CLASS_FP_NAN
,
&
cc
);
__vector
__bool
int
sel1
=
vec_fp_test_data_class
(
v
.
reg
.
val
[
1
],
__VEC_CLASS_FP_NAN
,
&
cc
);
inp0
=
vec_sel
(
inp0
,
nan
,
sel0
)
>>
sh16
;
inp1
=
vec_sel
(
inp1
,
nan
,
sel1
)
>>
sh16
;
inp0
=
vec_sel
(
inp0
,
nan
,
sel0
);
inp1
=
vec_sel
(
inp1
,
nan
,
sel1
);
inp0
=
inp0
>>
sh16
;
inp1
=
inp1
>>
sh16
;
reg
=
(
__vector
signed
short
)
vec_perm
(
inp0
,
inp1
,
omask
);
}
...
...
@@ -456,6 +765,22 @@ inline BF16Vec16::BF16Vec16(const FP32Vec16& v) {
__vector
unsigned
int
inp1
=
(
__vector
unsigned
int
)(
v
.
reg
.
val
[
1
]);
__vector
unsigned
int
inp2
=
(
__vector
unsigned
int
)(
v
.
reg
.
val
[
2
]);
__vector
unsigned
int
inp3
=
(
__vector
unsigned
int
)(
v
.
reg
.
val
[
3
]);
__vector
unsigned
int
lsb0
=
inp0
>>
sh16
;
__vector
unsigned
int
lsb1
=
inp1
>>
sh16
;
__vector
unsigned
int
lsb2
=
inp2
>>
sh16
;
__vector
unsigned
int
lsb3
=
inp3
>>
sh16
;
lsb0
=
lsb0
&
one
;
lsb1
=
lsb1
&
one
;
lsb2
=
lsb2
&
one
;
lsb3
=
lsb3
&
one
;
__vector
unsigned
int
rnd0
=
lsb0
+
bias
;
__vector
unsigned
int
rnd1
=
lsb1
+
bias
;
__vector
unsigned
int
rnd2
=
lsb2
+
bias
;
__vector
unsigned
int
rnd3
=
lsb3
+
bias
;
inp0
=
inp0
+
rnd0
;
inp1
=
inp1
+
rnd1
;
inp2
=
inp2
+
rnd2
;
inp3
=
inp3
+
rnd3
;
int
cc
;
__vector
__bool
int
sel0
=
vec_fp_test_data_class
(
v
.
reg
.
val
[
0
],
__VEC_CLASS_FP_NAN
,
&
cc
);
...
...
@@ -465,15 +790,164 @@ inline BF16Vec16::BF16Vec16(const FP32Vec16& v) {
vec_fp_test_data_class
(
v
.
reg
.
val
[
2
],
__VEC_CLASS_FP_NAN
,
&
cc
);
__vector
__bool
int
sel3
=
vec_fp_test_data_class
(
v
.
reg
.
val
[
3
],
__VEC_CLASS_FP_NAN
,
&
cc
);
inp0
=
vec_sel
(
inp0
,
nan
,
sel0
)
>>
sh16
;
inp1
=
vec_sel
(
inp1
,
nan
,
sel1
)
>>
sh16
;
inp2
=
vec_sel
(
inp2
,
nan
,
sel2
)
>>
sh16
;
inp3
=
vec_sel
(
inp3
,
nan
,
sel3
)
>>
sh16
;
inp0
=
vec_sel
(
inp0
,
nan
,
sel0
);
inp1
=
vec_sel
(
inp1
,
nan
,
sel1
);
inp2
=
vec_sel
(
inp2
,
nan
,
sel2
);
inp3
=
vec_sel
(
inp3
,
nan
,
sel3
);
inp0
=
inp0
>>
sh16
;
inp1
=
inp1
>>
sh16
;
inp2
=
inp2
>>
sh16
;
inp3
=
inp3
>>
sh16
;
reg
.
val
[
0
]
=
(
__vector
signed
short
)
vec_perm
(
inp0
,
inp1
,
omask
);
reg
.
val
[
1
]
=
(
__vector
signed
short
)
vec_perm
(
inp2
,
inp3
,
omask
);
}
inline
void
prefetch
(
const
void
*
addr
)
{
void
__dcbt
(
const
void
*
addr
);
}
// 1D softmax over `n` elements in `input`, writes result to `output`.
// Uses FP32Vec8 for main body, scalar tail handling.
// Requirement: n > 0
FORCE_INLINE
void
softmax_fp32vec8
(
float
*
output
,
const
float
*
input
,
int
n
)
{
if
(
n
<=
0
)
return
;
// ---------- Pass 1: find max ----------
float
max_val
=
-
std
::
numeric_limits
<
float
>::
infinity
();
int
i
=
0
;
for
(;
i
+
FP32Vec8
::
VEC_ELEM_NUM
<=
n
;
i
+=
FP32Vec8
::
VEC_ELEM_NUM
)
{
FP32Vec8
v
(
input
+
i
);
FP32Vec8
::
AliasReg
ar
;
ar
.
reg
=
v
.
reg
;
for
(
int
j
=
0
;
j
<
FP32Vec8
::
VEC_ELEM_NUM
;
++
j
)
{
if
(
ar
.
values
[
j
]
>
max_val
)
max_val
=
ar
.
values
[
j
];
}
}
for
(;
i
<
n
;
++
i
)
{
if
(
input
[
i
]
>
max_val
)
max_val
=
input
[
i
];
}
// ---------- Pass 2: compute exp(x - max) and sum ----------
float
sum
=
0.0
f
;
i
=
0
;
for
(;
i
+
FP32Vec8
::
VEC_ELEM_NUM
<=
n
;
i
+=
FP32Vec8
::
VEC_ELEM_NUM
)
{
float
tmp
[
FP32Vec8
::
VEC_ELEM_NUM
];
for
(
int
j
=
0
;
j
<
FP32Vec8
::
VEC_ELEM_NUM
;
++
j
)
{
tmp
[
j
]
=
input
[
i
+
j
]
-
max_val
;
}
FP32Vec8
v
(
tmp
);
FP32Vec8
e
=
v
.
exp
();
FP32Vec8
::
AliasReg
ar
;
ar
.
reg
=
e
.
reg
;
for
(
int
j
=
0
;
j
<
FP32Vec8
::
VEC_ELEM_NUM
;
++
j
)
{
output
[
i
+
j
]
=
ar
.
values
[
j
];
sum
+=
ar
.
values
[
j
];
}
}
// Tail
for
(;
i
<
n
;
++
i
)
{
float
x
=
input
[
i
]
-
max_val
;
float
ex
=
std
::
exp
(
x
);
// scalar tail
output
[
i
]
=
ex
;
sum
+=
ex
;
}
// ---------- Pass 3: normalize ----------
float
inv_sum
=
1.0
f
/
sum
;
i
=
0
;
for
(;
i
+
FP32Vec8
::
VEC_ELEM_NUM
<=
n
;
i
+=
FP32Vec8
::
VEC_ELEM_NUM
)
{
float
tmp
[
FP32Vec8
::
VEC_ELEM_NUM
];
for
(
int
j
=
0
;
j
<
FP32Vec8
::
VEC_ELEM_NUM
;
++
j
)
{
tmp
[
j
]
=
output
[
i
+
j
]
*
inv_sum
;
}
FP32Vec8
v
(
tmp
);
v
.
save
(
output
+
i
);
}
for
(;
i
<
n
;
++
i
)
{
output
[
i
]
*=
inv_sum
;
}
}
// 1D RMSNorm kernel:
// input: x[0..n-1]
// weight: w[0..n-1] (gamma), may be nullptr
// output: y[i] = x[i] * inv_rms * (weight[i] if weight != nullptr else 1)
// eps: small epsilon for numerical stability
FORCE_INLINE
void
rmsnorm_fp32vec8
(
float
*
output
,
const
float
*
input
,
const
float
*
weight
,
int
n
,
float
eps
)
{
if
(
n
<=
0
)
return
;
// ---------- Pass 1: compute sum of squares ----------
float
sum_sq
=
0.0
f
;
int
i
=
0
;
for
(;
i
+
FP32Vec8
::
VEC_ELEM_NUM
<=
n
;
i
+=
FP32Vec8
::
VEC_ELEM_NUM
)
{
FP32Vec8
x_vec
(
input
+
i
);
FP32Vec8
sq
=
x_vec
*
x_vec
;
FP32Vec8
::
AliasReg
ar
;
ar
.
reg
=
sq
.
reg
;
for
(
int
j
=
0
;
j
<
FP32Vec8
::
VEC_ELEM_NUM
;
++
j
)
{
sum_sq
+=
ar
.
values
[
j
];
}
}
// Tail
for
(;
i
<
n
;
++
i
)
{
float
v
=
input
[
i
];
sum_sq
+=
v
*
v
;
}
float
mean_sq
=
sum_sq
/
static_cast
<
float
>
(
n
);
float
inv_rms
=
1.0
f
/
std
::
sqrt
(
mean_sq
+
eps
);
// ---------- Pass 2: scale (and apply weight if given) ----------
const
float
inv_rms_f
=
inv_rms
;
i
=
0
;
if
(
weight
)
{
// with gamma
for
(;
i
+
FP32Vec8
::
VEC_ELEM_NUM
<=
n
;
i
+=
FP32Vec8
::
VEC_ELEM_NUM
)
{
FP32Vec8
x_vec
(
input
+
i
);
float
wtmp
[
FP32Vec8
::
VEC_ELEM_NUM
];
for
(
int
j
=
0
;
j
<
FP32Vec8
::
VEC_ELEM_NUM
;
++
j
)
{
wtmp
[
j
]
=
weight
[
i
+
j
];
}
FP32Vec8
w_vec
(
wtmp
);
FP32Vec8
scale_vec
(
inv_rms_f
);
FP32Vec8
y
=
x_vec
*
scale_vec
*
w_vec
;
y
.
save
(
output
+
i
);
}
for
(;
i
<
n
;
++
i
)
{
output
[
i
]
=
input
[
i
]
*
inv_rms_f
*
weight
[
i
];
}
}
else
{
// without gamma
for
(;
i
+
FP32Vec8
::
VEC_ELEM_NUM
<=
n
;
i
+=
FP32Vec8
::
VEC_ELEM_NUM
)
{
FP32Vec8
x_vec
(
input
+
i
);
FP32Vec8
scale_vec
(
inv_rms_f
);
FP32Vec8
y
=
x_vec
*
scale_vec
;
y
.
save
(
output
+
i
);
}
for
(;
i
<
n
;
++
i
)
{
output
[
i
]
=
input
[
i
]
*
inv_rms_f
;
}
}
}
// Prefetch data to cache for better memory access performance
FORCE_INLINE
void
prefetch
(
const
void
*
addr
)
{
__builtin_prefetch
(
addr
,
0
,
3
);
// 0=read, 3=high temporal locality
}
};
// namespace vec_op
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment