Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e69c990c
Unverified
Commit
e69c990c
authored
Feb 03, 2026
by
Radu Salavat
Committed by
GitHub
Feb 02, 2026
Browse files
[Feature][CPU Backend]: Optimize ARM vectorization backend (#30329)
Signed-off-by:
Radu Salavat
<
radu.salavat@arm.com
>
parent
5eac9a1b
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
552 additions
and
597 deletions
+552
-597
csrc/cpu/cpu_attn_impl.hpp
csrc/cpu/cpu_attn_impl.hpp
+0
-11
csrc/cpu/cpu_types_arm.hpp
csrc/cpu/cpu_types_arm.hpp
+551
-579
csrc/cpu/dnnl_kernels.cpp
csrc/cpu/dnnl_kernels.cpp
+0
-2
csrc/cpu/mla_decode.cpp
csrc/cpu/mla_decode.cpp
+1
-3
csrc/cpu/utils.hpp
csrc/cpu/utils.hpp
+0
-2
No files found.
csrc/cpu/cpu_attn_impl.hpp
View file @
e69c990c
...
@@ -816,14 +816,10 @@ struct VecTypeTrait<float> {
...
@@ -816,14 +816,10 @@ struct VecTypeTrait<float> {
using
vec_t
=
vec_op
::
FP32Vec16
;
using
vec_t
=
vec_op
::
FP32Vec16
;
};
};
// ARM only supports BF16 with ARMv8.6-A extension
#if (defined(__aarch64__) && !defined(ARM_BF16_SUPPORT))
#else
template
<
>
template
<
>
struct
VecTypeTrait
<
c10
::
BFloat16
>
{
struct
VecTypeTrait
<
c10
::
BFloat16
>
{
using
vec_t
=
vec_op
::
BF16Vec16
;
using
vec_t
=
vec_op
::
BF16Vec16
;
};
};
#endif
#if !defined(__powerpc__) && !defined(__s390x__)
#if !defined(__powerpc__) && !defined(__s390x__)
template
<
>
template
<
>
...
@@ -1585,17 +1581,10 @@ class AttentionMainLoop {
...
@@ -1585,17 +1581,10 @@ class AttentionMainLoop {
if
(
use_sink
)
{
if
(
use_sink
)
{
alignas
(
64
)
float
s_aux_fp32
[
16
];
alignas
(
64
)
float
s_aux_fp32
[
16
];
#if defined(__aarch64__) && !defined(ARM_BF16_SUPPORT)
// ARM without native BF16 support: manual conversion
for
(
int
i
=
0
;
i
<
16
;
++
i
)
{
s_aux_fp32
[
i
]
=
static_cast
<
float
>
(
curr_s_aux
[
i
]);
}
#else
// All other platforms have BF16Vec16 available
// All other platforms have BF16Vec16 available
vec_op
::
BF16Vec16
vec_bf16
(
curr_s_aux
);
vec_op
::
BF16Vec16
vec_bf16
(
curr_s_aux
);
vec_op
::
FP32Vec16
vec_fp32
(
vec_bf16
);
vec_op
::
FP32Vec16
vec_fp32
(
vec_bf16
);
vec_fp32
.
save
(
s_aux_fp32
);
vec_fp32
.
save
(
s_aux_fp32
);
#endif
float
*
__restrict__
curr_sum_buffer
=
sum_buffer
;
float
*
__restrict__
curr_sum_buffer
=
sum_buffer
;
float
*
__restrict__
curr_max_buffer
=
max_buffer
;
float
*
__restrict__
curr_max_buffer
=
max_buffer
;
...
...
csrc/cpu/cpu_types_arm.hpp
View file @
e69c990c
#include <cmath>
#include <type_traits>
#include <arm_neon.h>
#include <arm_neon.h>
#include <torch/all.h>
#include <torch/all.h>
#include <cmath>
#include <ATen/cpu/vec/functional.h>
#include <ATen/cpu/vec/vec.h>
#if defined(__APPLE__)
#if defined(__APPLE__)
#include "omp.h"
#include "omp.h"
#endif
#endif
using
namespace
at
::
vec
;
namespace
vec_op
{
namespace
vec_op
{
#ifdef ARM_BF16_SUPPORT
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
#else
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)
#endif
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
...
@@ -45,667 +46,632 @@ constexpr void unroll_loop_item(std::integer_sequence<T, indexes...>, F&& f) {
...
@@ -45,667 +46,632 @@ constexpr void unroll_loop_item(std::integer_sequence<T, indexes...>, F&& f) {
template
<
typename
T
,
T
count
,
typename
F
,
template
<
typename
T
,
T
count
,
typename
F
,
typename
=
std
::
enable_if_t
<
std
::
is_invocable_v
<
F
,
T
>
>>
typename
=
std
::
enable_if_t
<
std
::
is_invocable_v
<
F
,
T
>
>>
constexpr
void
unroll_loop
(
F
&&
f
)
{
inline
constexpr
void
unroll_loop
(
F
&&
f
)
{
unroll_loop_item
(
std
::
make_integer_sequence
<
T
,
count
>
{},
std
::
forward
<
F
>
(
f
));
unroll_loop_item
(
std
::
make_integer_sequence
<
T
,
count
>
{},
std
::
forward
<
F
>
(
f
));
}
}
template
<
typename
T
>
template
<
typename
T
,
typename
...
Ts
>
struct
Vec
{
struct
is_one_of
:
std
::
bool_constant
<
(
std
::
is_same_v
<
T
,
Ts
>
||
...)
>
{};
constexpr
static
int
get_elem_num
()
{
return
T
::
VEC_ELEM_NUM
;
};
};
struct
FP32Vec8
;
template
<
typename
T
,
typename
...
Ts
>
struct
FP32Vec16
;
inline
constexpr
bool
is_one_of_v
=
is_one_of
<
T
,
Ts
...
>::
value
;
struct
FP16Vec8
:
public
Vec
<
FP16Vec8
>
{
struct
uninit_t
{
constexpr
static
int
VEC_ELEM_NUM
=
8
;
explicit
constexpr
uninit_t
()
=
default
;
};
inline
constexpr
uninit_t
uninit
{};
float16x8_t
reg
;
template
<
typename
NxVectorizedTVecReg
,
typename
T
,
int
VEC_ELEM_NUM
>
union
AliasReg
{
NxVectorizedTVecReg
reg
;
T
values
[
VEC_ELEM_NUM
];
};
explicit
FP16Vec8
(
const
void
*
ptr
)
// Template over at::vec::Vectorized<T> to support
:
reg
(
vld1q_f16
(
static_cast
<
const
__fp16
*>
(
ptr
)))
{};
// multiple vectorised registers into 1 of length VEC_REG_NUM val
template
<
int
N
,
typename
T
>
struct
NxVectorizedTVecReg
{
using
value_t
=
T
;
using
VectorizedT
=
Vectorized
<
T
>
;
explicit
FP16Vec8
(
const
FP32Vec8
&
)
;
VectorizedT
val
[
N
]
;
void
save
(
void
*
ptr
)
const
{
vst1q_f16
(
static_cast
<
__fp16
*>
(
ptr
),
reg
);
}
NxVectorizedTVecReg
()
=
default
;
};
NxVectorizedTVecReg
(
const
NxVectorizedTVecReg
&
)
=
default
;
NxVectorizedTVecReg
(
NxVectorizedTVecReg
&&
)
=
default
;
NxVectorizedTVecReg
&
operator
=
(
const
NxVectorizedTVecReg
&
)
=
default
;
NxVectorizedTVecReg
&
operator
=
(
NxVectorizedTVecReg
&&
)
=
default
;
struct
FP16Vec16
:
public
Vec
<
FP16Vec16
>
{
explicit
NxVectorizedTVecReg
(
uninit_t
)
noexcept
{};
constexpr
static
int
VEC_ELEM_NUM
=
16
;
float16x8x2_t
reg
;
FORCE_INLINE
explicit
NxVectorizedTVecReg
(
const
VectorizedT
&
vec_t
)
{
unroll_loop
<
int
,
N
>
([
&
](
int
i
)
{
val
[
i
]
=
vec_t
;
});
};
explicit
FP16Vec16
(
const
void
*
ptr
)
{
FORCE_INLINE
explicit
NxVectorizedTVecReg
(
T
v
)
noexcept
{
reg
.
val
[
0
]
=
vld1q_f16
(
reinterpret_cast
<
const
__fp16
*>
(
ptr
)
);
VectorizedT
vv
(
v
);
reg
.
val
[
1
]
=
vld1q_f16
(
reinterpret_cast
<
const
__fp16
*>
(
ptr
)
+
8
);
unroll_loop
<
int
,
N
>
([
&
](
int
i
)
{
val
[
i
]
=
vv
;
}
);
}
}
// ASIMD does not support non-temporal loads
FORCE_INLINE
explicit
NxVectorizedTVecReg
(
const
void
*
ptr
)
{
load
(
ptr
);
}
explicit
FP16Vec16
(
bool
,
const
void
*
ptr
)
:
FP16Vec16
(
ptr
)
{}
explicit
NxVectorizedTVecReg
(
const
void
*
ptr
,
const
int
elem_num
)
{
load
(
ptr
,
elem_num
);
}
explicit
FP16Vec16
(
const
FP32Vec16
&
vec
);
static
constexpr
int
size
()
noexcept
{
return
N
*
VectorizedT
::
size
();
}
void
save
(
void
*
ptr
)
const
{
vst1q_f16
(
reinterpret_cast
<
__fp16
*>
(
ptr
),
reg
.
val
[
0
]);
FORCE_INLINE
void
save
(
void
*
ptr
)
const
{
vst1q_f16
(
reinterpret_cast
<
__fp16
*>
(
ptr
)
+
8
,
reg
.
val
[
1
]);
value_t
*
base
=
reinterpret_cast
<
value_t
*>
(
ptr
);
unroll_loop
<
int
,
N
>
(
[
&
](
int
i
)
{
val
[
i
].
store
(
base
+
i
*
VectorizedT
::
size
());
});
}
FORCE_INLINE
void
load
(
const
void
*
ptr
)
{
const
value_t
*
base
=
reinterpret_cast
<
const
value_t
*>
(
ptr
);
unroll_loop
<
int
,
N
>
([
&
](
int
i
)
{
val
[
i
]
=
VectorizedT
::
loadu
(
base
+
i
*
VectorizedT
::
size
());
});
}
}
void
save
(
void
*
ptr
,
const
int
elem_num
)
const
{
FORCE_INLINE
void
save
(
void
*
ptr
,
const
int
elem_num
)
const
{
int
full_blocks
=
elem_num
/
NUM_ELEMENTS_REG
(
reg
.
val
[
0
]);
value_t
*
base
=
reinterpret_cast
<
value_t
*>
(
ptr
);
int
remainder
=
elem_num
%
NUM_ELEMENTS_REG
(
reg
.
val
[
0
]);
save_partial
(
base
,
elem_num
);
}
if
(
full_blocks
>
0
)
{
FORCE_INLINE
void
load
(
const
void
*
ptr
,
const
int
elem_num
)
{
vst1q_f16
(
reinterpret_cast
<
__fp16
*>
(
ptr
),
reg
.
val
[
0
]);
const
value_t
*
base
=
reinterpret_cast
<
const
value_t
*>
(
ptr
);
if
(
full_blocks
>
1
)
{
load_partial
(
base
,
elem_num
);
vst1q_f16
(
reinterpret_cast
<
__fp16
*>
(
ptr
)
+
8
,
reg
.
val
[
1
]);
}
}
}
// Note: below is the unrolled version of the following code:
FORCE_INLINE
void
save_partial
(
value_t
*
base
,
int
elem_num
)
const
{
//
const
int
w
=
VectorizedT
::
size
();
// for (int i = 0; i < remainder; ++i) {
int
full
=
elem_num
/
w
;
// reinterpret_cast<__fp16*>(ptr)[full_blocks * 8 + i] =
int
rem
=
elem_num
%
w
;
// vgetq_lane_f16(temp, i);
for
(
int
i
=
0
;
i
<
full
;
i
++
)
val
[
i
].
store
(
base
+
i
*
w
);
// }
if
(
rem
)
val
[
full
].
store
(
base
+
full
*
w
,
rem
);
//
}
// For macOS build (Clang), the arm/neon intrinsics function
// `vgetq_lane_f16` needs the parameter `i` to be constant at compile
// time.
if
(
remainder
>
0
)
{
FORCE_INLINE
void
load_partial
(
const
value_t
*
base
,
int
elem_num
)
{
float16x8_t
temp
=
reg
.
val
[
full_blocks
];
const
int
w
=
VectorizedT
::
size
();
__fp16
*
fp16_ptr
=
reinterpret_cast
<
__fp16
*>
(
ptr
);
int
full
=
elem_num
/
w
;
switch
(
remainder
)
{
int
rem
=
elem_num
%
w
;
case
1
:
for
(
int
i
=
0
;
i
<
full
;
i
++
)
val
[
i
]
=
VectorizedT
::
loadu
(
base
+
i
*
w
);
fp16_ptr
[
full_blocks
*
8
+
0
]
=
vgetq_lane_f16
(
temp
,
0
);
if
(
rem
)
val
[
full
]
=
VectorizedT
::
loadu
(
base
+
full
*
w
,
rem
);
break
;
}
case
2
:
fp16_ptr
[
full_blocks
*
8
+
0
]
=
vgetq_lane_f16
(
temp
,
0
);
fp16_ptr
[
full_blocks
*
8
+
1
]
=
vgetq_lane_f16
(
temp
,
1
);
break
;
case
3
:
fp16_ptr
[
full_blocks
*
8
+
0
]
=
vgetq_lane_f16
(
temp
,
0
);
fp16_ptr
[
full_blocks
*
8
+
1
]
=
vgetq_lane_f16
(
temp
,
1
);
fp16_ptr
[
full_blocks
*
8
+
2
]
=
vgetq_lane_f16
(
temp
,
2
);
break
;
case
4
:
fp16_ptr
[
full_blocks
*
8
+
0
]
=
vgetq_lane_f16
(
temp
,
0
);
fp16_ptr
[
full_blocks
*
8
+
1
]
=
vgetq_lane_f16
(
temp
,
1
);
fp16_ptr
[
full_blocks
*
8
+
2
]
=
vgetq_lane_f16
(
temp
,
2
);
fp16_ptr
[
full_blocks
*
8
+
3
]
=
vgetq_lane_f16
(
temp
,
3
);
break
;
case
5
:
fp16_ptr
[
full_blocks
*
8
+
0
]
=
vgetq_lane_f16
(
temp
,
0
);
fp16_ptr
[
full_blocks
*
8
+
1
]
=
vgetq_lane_f16
(
temp
,
1
);
fp16_ptr
[
full_blocks
*
8
+
2
]
=
vgetq_lane_f16
(
temp
,
2
);
fp16_ptr
[
full_blocks
*
8
+
3
]
=
vgetq_lane_f16
(
temp
,
3
);
fp16_ptr
[
full_blocks
*
8
+
4
]
=
vgetq_lane_f16
(
temp
,
4
);
break
;
case
6
:
fp16_ptr
[
full_blocks
*
8
+
0
]
=
vgetq_lane_f16
(
temp
,
0
);
fp16_ptr
[
full_blocks
*
8
+
1
]
=
vgetq_lane_f16
(
temp
,
1
);
fp16_ptr
[
full_blocks
*
8
+
2
]
=
vgetq_lane_f16
(
temp
,
2
);
fp16_ptr
[
full_blocks
*
8
+
3
]
=
vgetq_lane_f16
(
temp
,
3
);
fp16_ptr
[
full_blocks
*
8
+
4
]
=
vgetq_lane_f16
(
temp
,
4
);
fp16_ptr
[
full_blocks
*
8
+
5
]
=
vgetq_lane_f16
(
temp
,
5
);
break
;
case
7
:
fp16_ptr
[
full_blocks
*
8
+
0
]
=
vgetq_lane_f16
(
temp
,
0
);
fp16_ptr
[
full_blocks
*
8
+
1
]
=
vgetq_lane_f16
(
temp
,
1
);
fp16_ptr
[
full_blocks
*
8
+
2
]
=
vgetq_lane_f16
(
temp
,
2
);
fp16_ptr
[
full_blocks
*
8
+
3
]
=
vgetq_lane_f16
(
temp
,
3
);
fp16_ptr
[
full_blocks
*
8
+
4
]
=
vgetq_lane_f16
(
temp
,
4
);
fp16_ptr
[
full_blocks
*
8
+
5
]
=
vgetq_lane_f16
(
temp
,
5
);
fp16_ptr
[
full_blocks
*
8
+
6
]
=
vgetq_lane_f16
(
temp
,
6
);
break
;
default:
template
<
VectorizedT
(
VectorizedT
::*
torch_vec_func
)()
const
,
break
;
value_t
(
*
std_func
)(
value_t
)>
FORCE_INLINE
NxVectorizedTVecReg
opt_vec_func_impl
()
const
{
NxVectorizedTVecReg
result
;
if
constexpr
(
torch_vec_func
!=
nullptr
)
{
unroll_loop
<
int
,
N
>
(
[
&
](
int
i
)
{
result
.
val
[
i
]
=
(
val
[
i
].
*
torch_vec_func
)();
});
}
else
{
for
(
int
i
=
0
;
i
<
N
;
i
++
)
{
alignas
(
64
)
value_t
buf
[
VectorizedT
::
size
()];
val
[
i
].
store
(
buf
);
for
(
int
j
=
0
;
j
<
VectorizedT
::
size
();
++
j
)
{
buf
[
j
]
=
std_func
(
buf
[
j
]);
}
result
.
val
[
i
]
=
VectorizedT
::
loadu
(
buf
);
}
}
}
}
return
result
;
}
}
};
};
#ifdef ARM_BF16_SUPPORT
template
<
typename
DerivedClassT
,
int
N
,
typename
T
>
struct
BF16Vec8
:
public
Vec
<
BF16Vec8
>
{
struct
VectorizedRegWrapper
{
constexpr
static
int
VEC_ELEM_NUM
=
8
;
using
ScalarT
=
T
;
using
VectorizedT
=
Vectorized
<
T
>
;
using
NxVectorizedTArray
=
NxVectorizedTVecReg
<
N
,
T
>
;
constexpr
static
int
VEC_REG_NUM
=
N
;
constexpr
static
int
VEC_ELEM_NUM
=
VEC_REG_NUM
*
VectorizedT
::
size
();
constexpr
static
int
get_elem_num
()
{
return
VEC_ELEM_NUM
;
};
NxVectorizedTArray
reg
;
VectorizedRegWrapper
()
noexcept
=
default
;
explicit
VectorizedRegWrapper
(
uninit_t
)
noexcept
:
reg
{
uninit
}
{};
explicit
VectorizedRegWrapper
(
T
v
)
:
reg
(
v
)
{};
explicit
VectorizedRegWrapper
(
const
void
*
ptr
)
:
reg
(
ptr
)
{};
explicit
VectorizedRegWrapper
(
const
void
*
ptr
,
const
int
elem_num
)
:
reg
(
ptr
,
elem_num
)
{};
explicit
VectorizedRegWrapper
(
const
VectorizedT
&
r
)
:
reg
(
r
)
{};
explicit
VectorizedRegWrapper
(
const
NxVectorizedTArray
&
r
)
:
reg
(
r
)
{};
VectorizedRegWrapper
(
const
VectorizedRegWrapper
&
)
=
default
;
VectorizedRegWrapper
(
VectorizedRegWrapper
&&
)
=
default
;
VectorizedRegWrapper
&
operator
=
(
VectorizedRegWrapper
&&
)
=
default
;
VectorizedRegWrapper
&
operator
=
(
const
VectorizedRegWrapper
&
)
=
default
;
FORCE_INLINE
void
save
(
void
*
ptr
)
const
{
reg
.
save
(
ptr
);
}
void
save
(
void
*
ptr
,
const
int
elem_num
)
const
{
reg
.
save
(
ptr
,
elem_num
);
}
// Define optimized functions using at::vec::Vectorized<T> where possible
// Fallback to std:: functions when not available
#define OPT_TORCH_IMPL(FUNC_NAME, STD_FUNC_NAME, TORCH_FUNC_NAME, ...) \
FORCE_INLINE DerivedClassT FUNC_NAME() const { \
if constexpr (is_one_of_v<T, __VA_ARGS__>) { \
return DerivedClassT{ \
reg.template opt_vec_func_impl<&VectorizedT::TORCH_FUNC_NAME, \
std::STD_FUNC_NAME>()}; \
} else { \
return DerivedClassT{reg.template opt_vec_func_impl< \
nullptr, static_cast<ScalarT (*)(ScalarT)>(&std::STD_FUNC_NAME)>()}; \
} \
}
bfloat16x8_t
reg
;
// Define optimized functions for datatypes passed in __VA_ARGS__
OPT_TORCH_IMPL
(
abs
,
abs
,
abs
,
c10
::
Half
,
float
)
OPT_TORCH_IMPL
(
er
,
erf
,
erf
,
float
)
OPT_TORCH_IMPL
(
exp
,
exp
,
fexp_u20
,
float
)
OPT_TORCH_IMPL
(
exp_u20
,
exp
,
exp_u20
,
float
)
OPT_TORCH_IMPL
(
sin
,
sin
,
sin
,
float
)
OPT_TORCH_IMPL
(
sinh
,
sinh
,
sinh
,
float
)
OPT_TORCH_IMPL
(
cos
,
cos
,
cos
,
float
)
OPT_TORCH_IMPL
(
cosh
,
cosh
,
cosh
,
float
)
OPT_TORCH_IMPL
(
log
,
log
,
log
,
float
)
OPT_TORCH_IMPL
(
log10
,
log10
,
log10
,
float
)
OPT_TORCH_IMPL
(
sqrt
,
sqrt
,
sqrt
,
c10
::
Half
,
float
)
OPT_TORCH_IMPL
(
tan
,
tan
,
tan
,
float
)
OPT_TORCH_IMPL
(
tanh
,
tanh
,
tanh
,
float
)
#undef OPT_TORCH_IMPL
};
explicit
BF16Vec8
(
const
void
*
ptr
)
// forward declare vectorised dtypes
:
reg
(
*
reinterpret_cast
<
const
bfloat16x8_t
*>
(
ptr
))
{};
struct
FP32Vec8
;
struct
FP32Vec16
;
struct
FP16Vec8
;
struct
FP16Vec16
;
struct
BF16Vec8
;
struct
BF16Vec16
;
explicit
BF16Vec8
(
bfloat16x8_t
data
)
:
reg
(
data
)
{};
struct
INT8Vec16
;
struct
INT32Vec16
;
explicit
BF16Vec8
(
const
FP32Vec8
&
);
template
<
typename
T
>
struct
VecType
{
using
vec_type
=
void
;
};
explicit
BF16Vec8
(
float32x4x2_t
v
)
template
<
typename
T
>
:
reg
(
vcvtq_high_bf16_f32
(
vcvtq_low_bf16_f32
(
v
.
val
[
0
]),
v
.
val
[
1
]))
{}
;
using
vec_t
=
typename
VecType
<
T
>::
vec_type
;
void
save
(
void
*
ptr
)
const
{
*
reinterpret_cast
<
bfloat16x8_t
*>
(
ptr
)
=
reg
;
}
template
<
>
struct
VecType
<
float
>
{
using
vec_type
=
FP32Vec8
;
};
};
struct
BF16Vec16
:
public
Vec
<
BF16Vec16
>
{
template
<
>
constexpr
static
int
VEC_ELEM_NUM
=
16
;
struct
VecType
<
c10
::
Half
>
{
using
vec_type
=
FP16Vec8
;
bfloat16x8x2_t
reg
;
}
;
explicit
BF16Vec16
(
const
void
*
ptr
)
template
<
>
:
reg
(
*
reinterpret_cast
<
const
bfloat16x8x2_t
*>
(
ptr
))
{};
struct
VecType
<
c10
::
BFloat16
>
{
using
vec_type
=
BF16Vec8
;
};
// ASIMD does not support non-temporal loads
struct
FP16Vec8
:
public
VectorizedRegWrapper
<
FP16Vec8
,
1
,
c10
::
Half
>
{
explicit
BF16Vec16
(
bool
,
const
void
*
ptr
)
:
BF16Vec16
(
ptr
)
{}
using
Base
=
VectorizedRegWrapper
<
FP16Vec8
,
1
,
c10
::
Half
>
;
using
Base
::
Base
;
using
Base
::
get_elem_num
;
using
Base
::
VEC_ELEM_NUM
;
explicit
BF16Vec16
(
bfloat16x8x2_t
data
)
:
reg
(
data
)
{};
explicit
FP16Vec8
(
const
FP32Vec8
&
);
};
explicit
BF16Vec16
(
const
FP32Vec16
&
);
struct
FP16Vec16
:
public
VectorizedRegWrapper
<
FP16Vec16
,
2
,
c10
::
Half
>
{
using
Base
=
VectorizedRegWrapper
<
FP16Vec16
,
2
,
c10
::
Half
>
;
using
Base
::
Base
;
using
Base
::
get_elem_num
;
using
Base
::
VEC_ELEM_NUM
;
explicit
BF16Vec16
(
float32x4x4_t
v
)
// ASIMD does not support non-temporal loads
:
reg
({
vcvtq_high_bf16_f32
(
vcvtq_low_bf16_f32
(
v
.
val
[
0
]),
v
.
val
[
1
]),
explicit
FP16Vec16
(
bool
,
const
void
*
ptr
)
:
Base
(
ptr
)
{}
vcvtq_high_bf16_f32
(
vcvtq_low_bf16_f32
(
v
.
val
[
2
]),
v
.
val
[
3
])})
{};
void
save
(
void
*
ptr
)
const
{
*
reinterpret_cast
<
bfloat16x8x2_t
*>
(
ptr
)
=
reg
;
};
explicit
FP16Vec16
(
const
FP32Vec16
&
vec
);
void
save
(
void
*
ptr
,
const
int
elem_num
)
const
{
int
full_blocks
=
elem_num
/
NUM_ELEMENTS_REG
(
reg
.
val
[
0
]);
int
remainder
=
elem_num
%
NUM_ELEMENTS_REG
(
reg
.
val
[
0
]);
for
(
int
i
=
0
;
i
<
full_blocks
;
i
++
)
vst1q_bf16
(
reinterpret_cast
<
__bf16
*>
(
ptr
)
+
NUM_ELEMENTS_REG
(
reg
.
val
[
0
])
*
i
,
reg
.
val
[
i
]);
if
(
remainder
>
0
)
{
bfloat16x8_t
temp
=
reg
.
val
[
full_blocks
];
bfloat16_t
*
base
=
reinterpret_cast
<
bfloat16_t
*>
(
ptr
)
+
full_blocks
*
8
;
if
(
remainder
>
0
)
base
[
0
]
=
vgetq_lane_bf16
(
temp
,
0
);
if
(
remainder
>
1
)
base
[
1
]
=
vgetq_lane_bf16
(
temp
,
1
);
if
(
remainder
>
2
)
base
[
2
]
=
vgetq_lane_bf16
(
temp
,
2
);
if
(
remainder
>
3
)
base
[
3
]
=
vgetq_lane_bf16
(
temp
,
3
);
if
(
remainder
>
4
)
base
[
4
]
=
vgetq_lane_bf16
(
temp
,
4
);
if
(
remainder
>
5
)
base
[
5
]
=
vgetq_lane_bf16
(
temp
,
5
);
if
(
remainder
>
6
)
base
[
6
]
=
vgetq_lane_bf16
(
temp
,
6
);
}
};
};
};
struct
BF16Vec32
:
public
Vec
<
BF16Vec32
>
{
struct
BF16Vec8
:
public
VectorizedRegWrapper
<
BF16Vec8
,
1
,
c10
::
BFloat16
>
{
constexpr
static
int
VEC_ELEM_NUM
=
32
;
using
Base
=
VectorizedRegWrapper
<
BF16Vec8
,
1
,
c10
::
BFloat16
>
;
using
VectorizedT
=
typename
Base
::
VectorizedT
;
using
Base
::
Base
;
using
Base
::
get_elem_num
;
using
Base
::
VEC_ELEM_NUM
;
bfloat16x8
x4
_t
reg
;
explicit
BF16Vec8
(
at_
bfloat16x8_t
data
)
:
Base
(
VectorizedT
(
data
))
{}
;
explicit
BF16Vec32
(
const
void
*
ptr
)
explicit
BF16Vec8
(
float32x4x2_t
v
)
{
:
reg
(
*
reinterpret_cast
<
const
bfloat16x8x4_t
*>
(
ptr
))
{};
reg
.
val
[
0
]
=
convert_float_bfloat16
(
v
.
val
[
0
],
v
.
val
[
1
]);
};
explicit
BF16Vec32
(
bfloat16x8x4_t
data
)
:
reg
(
data
)
{};
explicit
BF16Vec8
(
const
FP32Vec8
&
);
};
explicit
BF16Vec32
(
const
BF16Vec8
&
vec8_data
)
struct
BF16Vec16
:
public
VectorizedRegWrapper
<
BF16Vec16
,
2
,
c10
::
BFloat16
>
{
:
reg
({
vec8_data
.
reg
,
vec8_data
.
reg
,
vec8_data
.
reg
,
vec8_data
.
reg
})
{};
using
Base
=
VectorizedRegWrapper
<
BF16Vec16
,
2
,
c10
::
BFloat16
>
;
using
VectorizedT
=
typename
Base
::
VectorizedT
;
using
Base
::
Base
;
using
Base
::
get_elem_num
;
using
Base
::
VEC_ELEM_NUM
;
void
save
(
void
*
ptr
)
const
{
*
reinterpret_cast
<
bfloat16x8x4_t
*>
(
ptr
)
=
reg
;
};
// ASIMD does not support non-temporal loads
void
save
(
void
*
ptr
,
const
int
elem_num
)
const
{
explicit
BF16Vec16
(
bool
,
const
void
*
ptr
)
:
Base
(
ptr
)
{}
int
full_blocks
=
elem_num
/
NUM_ELEMENTS_REG
(
reg
.
val
[
0
]);
int
remainder
=
elem_num
%
NUM_ELEMENTS_REG
(
reg
.
val
[
0
]);
explicit
BF16Vec16
(
float32x4x4_t
v
)
{
for
(
int
i
=
0
;
i
<
full_blocks
;
i
++
)
reg
.
val
[
0
]
=
convert_float_bfloat16
(
v
.
val
[
0
],
v
.
val
[
1
]);
vst1q_bf16
(
reg
.
val
[
1
]
=
convert_float_bfloat16
(
v
.
val
[
2
],
v
.
val
[
3
]);
reinterpret_cast
<
__bf16
*>
(
ptr
)
+
NUM_ELEMENTS_REG
(
reg
.
val
[
0
])
*
i
,
reg
.
val
[
i
]);
if
(
remainder
>
0
)
{
bfloat16x8_t
temp
=
reg
.
val
[
full_blocks
];
bfloat16_t
*
base
=
reinterpret_cast
<
bfloat16_t
*>
(
ptr
)
+
full_blocks
*
8
;
base
[
0
]
=
vgetq_lane_bf16
(
temp
,
0
);
if
(
remainder
>
1
)
base
[
1
]
=
vgetq_lane_bf16
(
temp
,
1
);
if
(
remainder
>
2
)
base
[
2
]
=
vgetq_lane_bf16
(
temp
,
2
);
if
(
remainder
>
3
)
base
[
3
]
=
vgetq_lane_bf16
(
temp
,
3
);
if
(
remainder
>
4
)
base
[
4
]
=
vgetq_lane_bf16
(
temp
,
4
);
if
(
remainder
>
5
)
base
[
5
]
=
vgetq_lane_bf16
(
temp
,
5
);
if
(
remainder
>
6
)
base
[
6
]
=
vgetq_lane_bf16
(
temp
,
6
);
}
};
};
};
#endif
struct
FP32Vec4
:
public
Vec
<
FP32Vec
4
>
{
explicit
BF16Vec16
(
const
FP32Vec
16
&
);
constexpr
static
int
VEC_ELEM_NUM
=
4
;
}
;
union
AliasReg
{
struct
BF16Vec32
:
public
VectorizedRegWrapper
<
BF16Vec32
,
4
,
c10
::
BFloat16
>
{
float32x4_t
reg
;
using
Base
=
VectorizedRegWrapper
<
BF16Vec32
,
4
,
c10
::
BFloat16
>
;
float
values
[
VEC_ELEM_NUM
];
using
Base
::
Base
;
using
Base
::
get_elem_num
;
using
Base
::
VEC_ELEM_NUM
;
explicit
BF16Vec32
(
const
BF16Vec8
&
vec8_data
)
{
reg
.
val
[
0
]
=
vec8_data
.
reg
.
val
[
0
];
reg
.
val
[
1
]
=
vec8_data
.
reg
.
val
[
0
];
reg
.
val
[
2
]
=
vec8_data
.
reg
.
val
[
0
];
reg
.
val
[
3
]
=
vec8_data
.
reg
.
val
[
0
];
};
};
};
float32x4_t
reg
;
struct
FP32Vec4
:
public
VectorizedRegWrapper
<
FP32Vec4
,
1
,
float
>
{
using
Base
=
VectorizedRegWrapper
<
FP32Vec4
,
1
,
float
>
;
explicit
FP32Vec4
(
float
v
)
:
reg
(
vdupq_n_f32
(
v
))
{};
using
Base
::
Base
;
using
Base
::
get_elem_num
;
using
Base
::
VEC_ELEM_NUM
;
explicit
FP32Vec4
()
:
reg
(
vdupq_n_f32
(
0.0
f
))
{};
using
VectorizedT
=
typename
Base
::
VectorizedT
;
using
Vectorized1x4f
=
typename
Base
::
NxVectorizedTArray
;
explicit
FP32Vec4
(
const
float
*
ptr
)
:
reg
(
vld1q_f32
(
ptr
))
{};
FP32Vec4
()
:
Base
()
{};
explicit
FP32Vec4
(
float
v
)
:
Base
(
v
)
{};
explicit
FP32Vec4
(
float32x4_t
data
)
:
reg
(
data
)
{};
explicit
FP32Vec4
(
float32x4_t
data
)
:
Base
(
VectorizedT
(
data
)
)
{};
explicit
FP32Vec4
(
const
FP32Vec4
&
data
)
:
reg
(
data
.
reg
)
{};
explicit
FP32Vec4
(
const
FP32Vec4
&
data
)
:
Base
(
data
)
{};
};
};
struct
FP32Vec8
:
public
Vec
<
FP32Vec8
>
{
struct
FP32Vec8
:
public
VectorizedRegWrapper
<
FP32Vec8
,
2
,
float
>
{
constexpr
static
int
VEC_ELEM_NUM
=
8
;
using
Base
=
VectorizedRegWrapper
<
FP32Vec8
,
2
,
float
>
;
union
AliasReg
{
using
Base
::
Base
;
float32x4x2_t
reg
;
using
Base
::
get_elem_num
;
float
values
[
VEC_ELEM_NUM
];
using
Base
::
VEC_ELEM_NUM
;
};
using
Base
::
VEC_REG_NUM
;
float32x4x2_t
reg
;
explicit
FP32Vec8
(
float
v
)
:
reg
({
vmovq_n_f32
(
v
),
vmovq_n_f32
(
v
)})
{};
using
VectorizedT
=
typename
Base
::
VectorizedT
;
using
Vectorized2x4f
=
typename
Base
::
NxVectorizedTArray
;
explicit
FP32Vec8
()
:
reg
({
vmovq_n_f32
(
0.0
),
vmovq_n_f32
(
0.0
)})
{};
FP32Vec8
()
:
Base
()
{};
FP32Vec8
(
const
FP32Vec8
&
data
)
:
Base
(
data
)
{};
explicit
FP32Vec8
(
float
v
)
:
Base
(
v
)
{};
explicit
FP32Vec8
(
const
float
*
ptr
)
explicit
FP32Vec8
(
const
float
*
ptr
)
:
reg
({
vld1q_f32
(
ptr
),
vld1q_f32
(
ptr
+
4
)})
{};
:
Base
(
reinterpret_cast
<
const
void
*>
(
ptr
))
{};
explicit
FP32Vec8
(
const
float
*
ptr
,
const
int
elem_num
)
:
Base
(
reinterpret_cast
<
const
void
*>
(
ptr
),
elem_num
)
{};
explicit
FP32Vec8
(
float32x4x2_t
data
)
:
reg
(
data
)
{};
explicit
FP32Vec8
(
const
Vectorized2x4f
&
data
)
{
reg
.
val
[
0
]
=
data
.
val
[
0
];
explicit
FP32Vec8
(
const
FP32Vec8
&
data
)
:
reg
(
data
.
reg
)
{};
reg
.
val
[
1
]
=
data
.
val
[
1
];
};
explicit
FP32Vec8
(
const
BF16Vec8
&
v
)
{
std
::
tie
(
reg
.
val
[
0
],
reg
.
val
[
1
])
=
convert_bfloat16_float
(
v
.
reg
.
val
[
0
]);
};
explicit
FP32Vec8
(
const
FP16Vec8
&
v
)
{
explicit
FP32Vec8
(
const
FP16Vec8
&
v
)
{
reg
.
val
[
0
]
=
vcvt_f32_f16
(
vget_low_f16
(
v
.
reg
));
reg
.
val
[
0
]
=
Vectorized
<
float
>
(
vcvt_f32_f16
(
vget_low_f16
(
v
.
reg
.
val
[
0
])
));
reg
.
val
[
1
]
=
vcvt_f32_f16
(
vget_high_f16
(
v
.
reg
));
reg
.
val
[
1
]
=
Vectorized
<
float
>
(
vcvt_f32_f16
(
vget_high_f16
(
v
.
reg
.
val
[
0
])
));
};
};
explicit
FP32Vec8
(
float16x8_t
v
)
{
reg
.
val
[
0
]
=
Vectorized
<
float
>
(
vcvt_f32_f16
(
vget_low_f16
(
v
)));
reg
.
val
[
1
]
=
Vectorized
<
float
>
(
vcvt_f32_f16
(
vget_high_f16
(
v
)));
};
explicit
FP32Vec8
(
at_bfloat16x8_t
v
)
{
std
::
tie
(
reg
.
val
[
0
],
reg
.
val
[
1
])
=
convert_bfloat16_float
(
Vectorized
<
c10
::
BFloat16
>
(
v
));
};
explicit
FP32Vec8
(
float32x4x2_t
data
)
{
reg
.
val
[
0
]
=
Vectorized
<
float
>
(
data
.
val
[
0
]);
reg
.
val
[
1
]
=
Vectorized
<
float
>
(
data
.
val
[
1
]);
}
explicit
FP32Vec8
(
float16x8_t
v
)
FORCE_INLINE
float
reduce_sum
()
const
noexcept
{
:
reg
({
vcvt_f32_f16
(
vget_low_f16
(
v
)),
vcvt_f32_f16
(
vget_high_f16
(
v
))})
{};
#ifdef ARM_BF16_SUPPORT
explicit
FP32Vec8
(
bfloat16x8_t
v
)
:
reg
({
vcvtq_low_f32_bf16
(
v
),
vcvtq_high_f32_bf16
(
v
)})
{};
explicit
FP32Vec8
(
const
BF16Vec8
&
v
)
:
reg
({
vcvtq_low_f32_bf16
(
v
.
reg
),
vcvtq_high_f32_bf16
(
v
.
reg
)})
{};
#endif
float
reduce_sum
()
const
{
AliasReg
ar
;
ar
.
reg
=
reg
;
float
answer
=
0
;
float
answer
=
0
;
unroll_loop
<
int
,
VEC_ELEM_NUM
>
(
std
::
plus
<
VectorizedT
>
add
;
[
&
answer
,
&
ar
](
int
i
)
{
answer
+=
ar
.
values
[
i
];
});
unroll_loop
<
int
,
VEC_REG_NUM
>
([
&
](
int
i
)
{
answer
+=
at
::
vec
::
vec_reduce_all
<
float
,
std
::
plus
<
VectorizedT
>>
(
add
,
reg
.
val
[
i
]);
});
return
answer
;
return
answer
;
}
}
FP32Vec8
exp
()
const
{
FORCE_INLINE
FP32Vec8
operator
+
(
const
FP32Vec8
&
b
)
const
noexcept
{
AliasReg
ar
;
FP32Vec8
r
(
uninit
);
ar
.
reg
=
reg
;
r
.
reg
.
val
[
0
]
=
reg
.
val
[
0
]
+
b
.
reg
.
val
[
0
];
r
.
reg
.
val
[
1
]
=
reg
.
val
[
1
]
+
b
.
reg
.
val
[
1
];
float32x2_t
exp_vec0
=
{
expf
(
ar
.
values
[
0
]),
expf
(
ar
.
values
[
1
])};
return
r
;
float32x2_t
exp_vec1
=
{
expf
(
ar
.
values
[
2
]),
expf
(
ar
.
values
[
3
])};
float32x2_t
exp_vec2
=
{
expf
(
ar
.
values
[
4
]),
expf
(
ar
.
values
[
5
])};
float32x2_t
exp_vec3
=
{
expf
(
ar
.
values
[
6
]),
expf
(
ar
.
values
[
7
])};
float32x4_t
result0
=
vcombine_f32
(
exp_vec0
,
exp_vec1
);
float32x4_t
result1
=
vcombine_f32
(
exp_vec2
,
exp_vec3
);
float32x4x2_t
result
;
result
.
val
[
0
]
=
result0
;
result
.
val
[
1
]
=
result1
;
return
FP32Vec8
(
result
);
}
}
FP32Vec8
tanh
()
const
{
FORCE_INLINE
FP32Vec8
operator
-
(
const
FP32Vec8
&
b
)
const
noexcept
{
AliasReg
ar
;
FP32Vec8
r
(
uninit
);
ar
.
reg
=
reg
;
r
.
reg
.
val
[
0
]
=
reg
.
val
[
0
]
-
b
.
reg
.
val
[
0
];
r
.
reg
.
val
[
1
]
=
reg
.
val
[
1
]
-
b
.
reg
.
val
[
1
];
float32x2_t
tanh_vec0
=
{
tanhf
(
ar
.
values
[
0
]),
tanhf
(
ar
.
values
[
1
])};
return
r
;
float32x2_t
tanh_vec1
=
{
tanhf
(
ar
.
values
[
2
]),
tanhf
(
ar
.
values
[
3
])};
float32x2_t
tanh_vec2
=
{
tanhf
(
ar
.
values
[
4
]),
tanhf
(
ar
.
values
[
5
])};
float32x2_t
tanh_vec3
=
{
tanhf
(
ar
.
values
[
6
]),
tanhf
(
ar
.
values
[
7
])};
float32x4_t
result0
=
vcombine_f32
(
tanh_vec0
,
tanh_vec1
);
float32x4_t
result1
=
vcombine_f32
(
tanh_vec2
,
tanh_vec3
);
float32x4x2_t
result
;
result
.
val
[
0
]
=
result0
;
result
.
val
[
1
]
=
result1
;
return
FP32Vec8
(
result
);
}
}
FP32Vec8
er
()
const
{
FORCE_INLINE
FP32Vec8
operator
*
(
const
FP32Vec8
&
b
)
const
noexcept
{
AliasReg
ar
;
FP32Vec8
r
(
uninit
);
ar
.
reg
=
reg
;
r
.
reg
.
val
[
0
]
=
reg
.
val
[
0
]
*
b
.
reg
.
val
[
0
];
r
.
reg
.
val
[
1
]
=
reg
.
val
[
1
]
*
b
.
reg
.
val
[
1
];
float32x2_t
er_vec0
=
{
static_cast
<
float32_t
>
(
erf
(
ar
.
values
[
0
])),
return
r
;
static_cast
<
float32_t
>
(
erf
(
ar
.
values
[
1
]))};
float32x2_t
er_vec1
=
{
static_cast
<
float32_t
>
(
erf
(
ar
.
values
[
2
])),
static_cast
<
float32_t
>
(
erf
(
ar
.
values
[
3
]))};
float32x2_t
er_vec2
=
{
static_cast
<
float32_t
>
(
erf
(
ar
.
values
[
4
])),
static_cast
<
float32_t
>
(
erf
(
ar
.
values
[
5
]))};
float32x2_t
er_vec3
=
{
static_cast
<
float32_t
>
(
erf
(
ar
.
values
[
6
])),
static_cast
<
float32_t
>
(
erf
(
ar
.
values
[
7
]))};
float32x4_t
result0
=
vcombine_f32
(
er_vec0
,
er_vec1
);
float32x4_t
result1
=
vcombine_f32
(
er_vec2
,
er_vec3
);
float32x4x2_t
result
;
result
.
val
[
0
]
=
result0
;
result
.
val
[
1
]
=
result1
;
return
FP32Vec8
(
result
);
}
FP32Vec8
operator
*
(
const
FP32Vec8
&
b
)
const
{
return
FP32Vec8
(
float32x4x2_t
({
vmulq_f32
(
reg
.
val
[
0
],
b
.
reg
.
val
[
0
]),
vmulq_f32
(
reg
.
val
[
1
],
b
.
reg
.
val
[
1
])}));
}
FP32Vec8
operator
+
(
const
FP32Vec8
&
b
)
const
{
return
FP32Vec8
(
float32x4x2_t
({
vaddq_f32
(
reg
.
val
[
0
],
b
.
reg
.
val
[
0
]),
vaddq_f32
(
reg
.
val
[
1
],
b
.
reg
.
val
[
1
])}));
}
FP32Vec8
operator
-
(
const
FP32Vec8
&
b
)
const
{
return
FP32Vec8
(
float32x4x2_t
({
vsubq_f32
(
reg
.
val
[
0
],
b
.
reg
.
val
[
0
]),
vsubq_f32
(
reg
.
val
[
1
],
b
.
reg
.
val
[
1
])}));
}
FP32Vec8
operator
/
(
const
FP32Vec8
&
b
)
const
{
return
FP32Vec8
(
float32x4x2_t
({
vdivq_f32
(
reg
.
val
[
0
],
b
.
reg
.
val
[
0
]),
vdivq_f32
(
reg
.
val
[
1
],
b
.
reg
.
val
[
1
])}));
}
}
void
save
(
float
*
ptr
)
const
{
FORCE_INLINE
FP32Vec8
operator
/
(
const
FP32Vec8
&
b
)
const
noexcept
{
vst1q_f32
(
ptr
,
reg
.
val
[
0
]);
FP32Vec8
r
(
uninit
);
vst1q_f32
(
ptr
+
4
,
reg
.
val
[
1
]);
r
.
reg
.
val
[
0
]
=
reg
.
val
[
0
]
/
b
.
reg
.
val
[
0
];
r
.
reg
.
val
[
1
]
=
reg
.
val
[
1
]
/
b
.
reg
.
val
[
1
];
return
r
;
}
}
};
};
struct
INT32Vec16
:
public
Vec
<
INT32Vec16
>
{
struct
FP32Vec16
:
public
VectorizedRegWrapper
<
FP32Vec16
,
4
,
float
>
{
constexpr
static
int
VEC_ELEM_NUM
=
16
;
using
Base
=
VectorizedRegWrapper
<
FP32Vec16
,
4
,
float
>
;
union
AliasReg
{
using
Base
::
Base
;
int32x4x4_t
reg
;
using
Base
::
get_elem_num
;
int32_t
values
[
VEC_ELEM_NUM
];
using
Base
::
VEC_ELEM_NUM
;
};
int32x4x4_t
reg
;
explicit
INT32Vec16
(
const
void
*
ptr
)
{
using
ScalarT
=
typename
Base
::
ScalarT
;
reg
.
val
[
0
]
=
vld1q_s32
(
reinterpret_cast
<
const
int32_t
*>
(
ptr
));
using
VectorizedT
=
typename
Base
::
VectorizedT
;
reg
.
val
[
1
]
=
vld1q_s32
(
reinterpret_cast
<
const
int32_t
*>
(
ptr
)
+
4
);
using
Vectorized4x4f
=
typename
Base
::
NxVectorizedTArray
;
reg
.
val
[
2
]
=
vld1q_s32
(
reinterpret_cast
<
const
int32_t
*>
(
ptr
)
+
8
);
reg
.
val
[
3
]
=
vld1q_s32
(
reinterpret_cast
<
const
int32_t
*>
(
ptr
)
+
12
);
}
void
save
(
int32_t
*
ptr
)
const
{
FP32Vec16
()
:
Base
()
{};
vst1q_s32
(
ptr
,
reg
.
val
[
0
]);
FP32Vec16
(
const
FP32Vec16
&
data
)
:
Base
(
data
)
{};
vst1q_s32
(
ptr
+
4
,
reg
.
val
[
1
]);
explicit
FP32Vec16
(
float
v
)
:
Base
(
v
)
{};
vst1q_s32
(
ptr
+
8
,
reg
.
val
[
2
]);
explicit
FP32Vec16
(
const
float
*
ptr
)
vst1q_s32
(
ptr
+
12
,
reg
.
val
[
3
]);
:
Base
(
reinterpret_cast
<
const
void
*>
(
ptr
))
{};
explicit
FP32Vec16
(
const
float
*
ptr
,
const
int
elem_num
)
:
Base
(
reinterpret_cast
<
const
void
*>
(
ptr
),
elem_num
)
{};
explicit
FP32Vec16
(
const
Vectorized4x4f
&
data
)
{
reg
.
val
[
0
]
=
data
.
val
[
0
];
reg
.
val
[
1
]
=
data
.
val
[
1
];
reg
.
val
[
2
]
=
data
.
val
[
2
];
reg
.
val
[
3
]
=
data
.
val
[
3
];
};
};
void
save
(
int32_t
*
ptr
,
const
int
elem_num
)
const
{
// ASIMD does not support non-temporal loads
int
full_blocks
=
elem_num
/
NUM_ELEMENTS_REG
(
reg
.
val
[
0
]);
explicit
FP32Vec16
(
bool
,
const
float
*
ptr
)
:
Base
(
ptr
)
{}
int
remainder
=
elem_num
%
NUM_ELEMENTS_REG
(
reg
.
val
[
0
]);
for
(
int
i
=
0
;
i
<
full_blocks
;
i
++
)
vst1q_s32
(
reinterpret_cast
<
__int32_t
*>
(
ptr
)
+
NUM_ELEMENTS_REG
(
reg
.
val
[
0
])
*
i
,
reg
.
val
[
i
]);
if
(
remainder
>
0
)
{
int32x4_t
temp
=
reg
.
val
[
full_blocks
];
int32_t
*
base
=
reinterpret_cast
<
int32_t
*>
(
ptr
)
+
full_blocks
*
4
;
if
(
remainder
>
0
)
base
[
0
]
=
vgetq_lane_s32
(
temp
,
0
);
if
(
remainder
>
1
)
base
[
1
]
=
vgetq_lane_s32
(
temp
,
1
);
if
(
remainder
>
2
)
base
[
2
]
=
vgetq_lane_s32
(
temp
,
2
);
if
(
remainder
>
3
)
base
[
3
]
=
vgetq_lane_s32
(
temp
,
3
);
}
}
};
struct
FP32Vec16
:
public
Vec
<
FP32Vec16
>
{
explicit
FP32Vec16
(
float32x4x4_t
data
)
{
constexpr
static
int
VEC_ELEM_NUM
=
16
;
reg
.
val
[
0
]
=
data
.
val
[
0
]
;
union
AliasReg
{
reg
.
val
[
1
]
=
data
.
val
[
1
];
float32x4x4_t
reg
;
reg
.
val
[
2
]
=
data
.
val
[
2
]
;
float
values
[
VEC_ELEM_NUM
];
reg
.
val
[
3
]
=
data
.
val
[
3
];
};
};
float32x4x4_t
reg
;
explicit
FP32Vec16
(
const
FP32Vec4
&
data
)
{
reg
.
val
[
0
]
=
data
.
reg
.
val
[
0
];
explicit
FP32Vec16
(
float
v
)
reg
.
val
[
1
]
=
data
.
reg
.
val
[
0
];
:
reg
({
vmovq_n_f32
(
v
),
vmovq_n_f32
(
v
),
vmovq_n_f32
(
v
),
vmovq_n_f32
(
v
)})
{}
reg
.
val
[
2
]
=
data
.
reg
.
val
[
0
];
reg
.
val
[
3
]
=
data
.
reg
.
val
[
0
];
explicit
FP32Vec16
()
};
:
reg
({
vmovq_n_f32
(
0.0
),
vmovq_n_f32
(
0.0
),
vmovq_n_f32
(
0.0
),
vmovq_n_f32
(
0.0
)})
{}
explicit
FP32Vec16
(
const
float
*
ptr
)
:
reg
({
vld1q_f32
(
ptr
),
vld1q_f32
(
ptr
+
4
),
vld1q_f32
(
ptr
+
8
),
vld1q_f32
(
ptr
+
12
)})
{}
// ASIMD does not support non-temporal loads
explicit
FP32Vec16
(
bool
,
const
float
*
ptr
)
:
FP32Vec16
(
ptr
)
{}
explicit
FP32Vec16
(
float32x4x4_t
data
)
:
reg
(
data
)
{}
explicit
FP32Vec16
(
const
FP32Vec8
&
data
)
{
explicit
FP32Vec16
(
const
FP32Vec8
&
data
)
{
reg
.
val
[
0
]
=
data
.
reg
.
val
[
0
];
reg
.
val
[
0
]
=
data
.
reg
.
val
[
0
];
reg
.
val
[
1
]
=
data
.
reg
.
val
[
1
];
reg
.
val
[
1
]
=
data
.
reg
.
val
[
1
];
reg
.
val
[
2
]
=
data
.
reg
.
val
[
0
];
reg
.
val
[
2
]
=
data
.
reg
.
val
[
0
];
reg
.
val
[
3
]
=
data
.
reg
.
val
[
1
];
reg
.
val
[
3
]
=
data
.
reg
.
val
[
1
];
}
explicit
FP32Vec16
(
const
FP32Vec16
&
data
)
:
reg
(
data
.
reg
)
{}
explicit
FP32Vec16
(
const
FP16Vec8
&
v
)
:
FP32Vec16
(
FP32Vec8
(
v
.
reg
))
{}
#ifdef ARM_BF16_SUPPORT
explicit
FP32Vec16
(
bfloat16x8x2_t
v
)
:
reg
({
vcvtq_low_f32_bf16
(
v
.
val
[
0
]),
vcvtq_high_f32_bf16
(
v
.
val
[
0
]),
vcvtq_low_f32_bf16
(
v
.
val
[
1
]),
vcvtq_high_f32_bf16
(
v
.
val
[
1
])})
{};
#endif
explicit
FP32Vec16
(
const
FP32Vec4
&
data
)
{
reg
.
val
[
0
]
=
data
.
reg
;
reg
.
val
[
1
]
=
data
.
reg
;
reg
.
val
[
2
]
=
data
.
reg
;
reg
.
val
[
3
]
=
data
.
reg
;
};
};
#ifdef ARM_BF16_SUPPORT
explicit
FP32Vec16
(
const
BF16Vec16
&
v
)
{
explicit
FP32Vec16
(
const
BF16Vec16
&
v
)
std
::
tie
(
reg
.
val
[
0
],
reg
.
val
[
1
])
=
convert_bfloat16_float
(
v
.
reg
.
val
[
0
]);
:
reg
({
vcvtq_low_f32_bf16
(
v
.
reg
.
val
[
0
]),
std
::
tie
(
reg
.
val
[
2
],
reg
.
val
[
3
])
=
convert_bfloat16_float
(
v
.
reg
.
val
[
1
]);
vcvtq_high_f32_bf16
(
v
.
reg
.
val
[
0
]),
};
vcvtq_low_f32_bf16
(
v
.
reg
.
val
[
1
]),
vcvtq_high_f32_bf16
(
v
.
reg
.
val
[
1
])})
{};
explicit
FP32Vec16
(
const
BF16Vec8
&
v
)
:
FP32Vec16
(
FP32Vec8
(
v
))
{};
explicit
FP32Vec16
(
const
BF16Vec8
&
v
)
:
FP32Vec16
(
FP32Vec8
(
v
))
{};
#endif
explicit
FP32Vec16
(
const
FP16Vec16
&
v
)
{
explicit
FP32Vec16
(
const
FP16Vec16
&
v
)
{
reg
.
val
[
0
]
=
vcvt_f32_f16
(
vget_low_f16
(
v
.
reg
.
val
[
0
]));
reg
.
val
[
0
]
=
Vectorized
<
float
>
(
vcvt_f32_f16
(
vget_low_f16
(
v
.
reg
.
val
[
0
])));
reg
.
val
[
1
]
=
vcvt_f32_f16
(
vget_high_f16
(
v
.
reg
.
val
[
0
]));
reg
.
val
[
1
]
=
Vectorized
<
float
>
(
vcvt_f32_f16
(
vget_high_f16
(
v
.
reg
.
val
[
0
])));
reg
.
val
[
2
]
=
vcvt_f32_f16
(
vget_low_f16
(
v
.
reg
.
val
[
1
]));
reg
.
val
[
2
]
=
Vectorized
<
float
>
(
vcvt_f32_f16
(
vget_low_f16
(
v
.
reg
.
val
[
1
])));
reg
.
val
[
3
]
=
vcvt_f32_f16
(
vget_high_f16
(
v
.
reg
.
val
[
1
]));
reg
.
val
[
3
]
=
Vectorized
<
float
>
(
vcvt_f32_f16
(
vget_high_f16
(
v
.
reg
.
val
[
1
])));
};
explicit
FP32Vec16
(
const
INT32Vec16
&
v
)
{
reg
.
val
[
0
]
=
vcvtq_f32_s32
(
v
.
reg
.
val
[
0
]);
reg
.
val
[
1
]
=
vcvtq_f32_s32
(
v
.
reg
.
val
[
1
]);
reg
.
val
[
2
]
=
vcvtq_f32_s32
(
v
.
reg
.
val
[
2
]);
reg
.
val
[
3
]
=
vcvtq_f32_s32
(
v
.
reg
.
val
[
3
]);
};
FP32Vec16
operator
+
(
const
FP32Vec16
&
b
)
const
{
return
FP32Vec16
(
float32x4x4_t
({
vaddq_f32
(
reg
.
val
[
0
],
b
.
reg
.
val
[
0
]),
vaddq_f32
(
reg
.
val
[
1
],
b
.
reg
.
val
[
1
]),
vaddq_f32
(
reg
.
val
[
2
],
b
.
reg
.
val
[
2
]),
vaddq_f32
(
reg
.
val
[
3
],
b
.
reg
.
val
[
3
])}));
};
};
FP32Vec16
operator
*
(
const
FP32Vec16
&
b
)
const
{
FORCE_INLINE
FP32Vec16
operator
+
(
const
FP32Vec16
&
b
)
const
noexcept
{
return
FP32Vec16
(
float32x4x4_t
({
vmulq_f32
(
reg
.
val
[
0
],
b
.
reg
.
val
[
0
]),
FP32Vec16
r
(
uninit
);
vmulq_f32
(
reg
.
val
[
1
],
b
.
reg
.
val
[
1
]),
r
.
reg
.
val
[
0
]
=
reg
.
val
[
0
]
+
b
.
reg
.
val
[
0
];
vmulq_f32
(
reg
.
val
[
2
],
b
.
reg
.
val
[
2
]),
r
.
reg
.
val
[
1
]
=
reg
.
val
[
1
]
+
b
.
reg
.
val
[
1
];
vmulq_f32
(
reg
.
val
[
3
],
b
.
reg
.
val
[
3
])}));
r
.
reg
.
val
[
2
]
=
reg
.
val
[
2
]
+
b
.
reg
.
val
[
2
];
};
r
.
reg
.
val
[
3
]
=
reg
.
val
[
3
]
+
b
.
reg
.
val
[
3
];
return
r
;
}
FP32Vec16
operator
-
(
const
FP32Vec16
&
b
)
const
{
FORCE_INLINE
FP32Vec16
operator
-
(
const
FP32Vec16
&
b
)
const
noexcept
{
return
FP32Vec16
(
float32x4x4_t
({
vsubq_f32
(
reg
.
val
[
0
],
b
.
reg
.
val
[
0
]),
FP32Vec16
r
(
uninit
);
vsubq_f32
(
reg
.
val
[
1
],
b
.
reg
.
val
[
1
]),
r
.
reg
.
val
[
0
]
=
reg
.
val
[
0
]
-
b
.
reg
.
val
[
0
];
vsubq_f32
(
reg
.
val
[
2
],
b
.
reg
.
val
[
2
]),
r
.
reg
.
val
[
1
]
=
reg
.
val
[
1
]
-
b
.
reg
.
val
[
1
];
vsubq_f32
(
reg
.
val
[
3
],
b
.
reg
.
val
[
3
])}));
r
.
reg
.
val
[
2
]
=
reg
.
val
[
2
]
-
b
.
reg
.
val
[
2
];
};
r
.
reg
.
val
[
3
]
=
reg
.
val
[
3
]
-
b
.
reg
.
val
[
3
];
return
r
;
}
FP32Vec16
operator
/
(
const
FP32Vec16
&
b
)
const
{
FORCE_INLINE
FP32Vec16
operator
*
(
const
FP32Vec16
&
b
)
const
noexcept
{
return
FP32Vec16
(
float32x4x4_t
({
vdivq_f32
(
reg
.
val
[
0
],
b
.
reg
.
val
[
0
]),
FP32Vec16
r
(
uninit
);
vdivq_f32
(
reg
.
val
[
1
],
b
.
reg
.
val
[
1
]),
r
.
reg
.
val
[
0
]
=
reg
.
val
[
0
]
*
b
.
reg
.
val
[
0
];
vdivq_f32
(
reg
.
val
[
2
],
b
.
reg
.
val
[
2
]),
r
.
reg
.
val
[
1
]
=
reg
.
val
[
1
]
*
b
.
reg
.
val
[
1
];
vdivq_f32
(
reg
.
val
[
3
],
b
.
reg
.
val
[
3
])}));
r
.
reg
.
val
[
2
]
=
reg
.
val
[
2
]
*
b
.
reg
.
val
[
2
];
r
.
reg
.
val
[
3
]
=
reg
.
val
[
3
]
*
b
.
reg
.
val
[
3
];
return
r
;
}
FORCE_INLINE
FP32Vec16
operator
/
(
const
FP32Vec16
&
b
)
const
noexcept
{
FP32Vec16
r
(
uninit
);
r
.
reg
.
val
[
0
]
=
reg
.
val
[
0
]
/
b
.
reg
.
val
[
0
];
r
.
reg
.
val
[
1
]
=
reg
.
val
[
1
]
/
b
.
reg
.
val
[
1
];
r
.
reg
.
val
[
2
]
=
reg
.
val
[
2
]
/
b
.
reg
.
val
[
2
];
r
.
reg
.
val
[
3
]
=
reg
.
val
[
3
]
/
b
.
reg
.
val
[
3
];
return
r
;
}
FORCE_INLINE
FP32Vec16
clamp
(
const
FP32Vec16
&
min
,
const
FP32Vec16
&
max
)
const
{
FP32Vec16
r
(
uninit
);
r
.
reg
.
val
[
0
]
=
at
::
vec
::
clamp
(
reg
.
val
[
0
],
min
.
reg
.
val
[
0
],
max
.
reg
.
val
[
0
]);
r
.
reg
.
val
[
1
]
=
at
::
vec
::
clamp
(
reg
.
val
[
1
],
min
.
reg
.
val
[
1
],
max
.
reg
.
val
[
1
]);
r
.
reg
.
val
[
2
]
=
at
::
vec
::
clamp
(
reg
.
val
[
2
],
min
.
reg
.
val
[
2
],
max
.
reg
.
val
[
2
]);
r
.
reg
.
val
[
3
]
=
at
::
vec
::
clamp
(
reg
.
val
[
3
],
min
.
reg
.
val
[
3
],
max
.
reg
.
val
[
3
]);
return
r
;
};
};
FP32Vec16
clamp
(
const
FP32Vec16
&
min
,
const
FP32Vec16
&
max
)
const
{
FORCE_INLINE
FP32Vec16
min
(
const
FP32Vec16
&
b
)
const
{
return
FP32Vec16
(
float32x4x4_t
(
FP32Vec16
r
(
uninit
);
{
vminq_f32
(
max
.
reg
.
val
[
0
],
vmaxq_f32
(
min
.
reg
.
val
[
0
],
reg
.
val
[
0
])),
r
.
reg
.
val
[
0
]
=
minimum
(
b
.
reg
.
val
[
0
],
reg
.
val
[
0
]),
vminq_f32
(
max
.
reg
.
val
[
1
],
vmaxq_f32
(
min
.
reg
.
val
[
1
],
reg
.
val
[
1
])),
r
.
reg
.
val
[
1
]
=
minimum
(
b
.
reg
.
val
[
1
],
reg
.
val
[
1
]);
vminq_f32
(
max
.
reg
.
val
[
2
],
vmaxq_f32
(
min
.
reg
.
val
[
2
],
reg
.
val
[
2
])),
r
.
reg
.
val
[
2
]
=
minimum
(
b
.
reg
.
val
[
2
],
reg
.
val
[
2
]);
vminq_f32
(
max
.
reg
.
val
[
3
],
vmaxq_f32
(
min
.
reg
.
val
[
3
],
reg
.
val
[
3
]))}));
r
.
reg
.
val
[
3
]
=
minimum
(
b
.
reg
.
val
[
3
],
reg
.
val
[
3
]);
return
r
;
};
};
FP32Vec16
max
(
const
FP32Vec16
&
b
)
const
{
FORCE_INLINE
FP32Vec16
max
(
const
FP32Vec16
&
b
)
const
{
return
FP32Vec16
(
float32x4x4_t
({
vmaxq_f32
(
b
.
reg
.
val
[
0
],
reg
.
val
[
0
]),
FP32Vec16
r
(
uninit
);
vmaxq_f32
(
b
.
reg
.
val
[
1
],
reg
.
val
[
1
]),
r
.
reg
.
val
[
0
]
=
maximum
(
b
.
reg
.
val
[
0
],
reg
.
val
[
0
]);
vmaxq_f32
(
b
.
reg
.
val
[
2
],
reg
.
val
[
2
]),
r
.
reg
.
val
[
1
]
=
maximum
(
b
.
reg
.
val
[
1
],
reg
.
val
[
1
]);
vmaxq_f32
(
b
.
reg
.
val
[
3
],
reg
.
val
[
3
])}));
r
.
reg
.
val
[
2
]
=
maximum
(
b
.
reg
.
val
[
2
],
reg
.
val
[
2
]);
r
.
reg
.
val
[
3
]
=
maximum
(
b
.
reg
.
val
[
3
],
reg
.
val
[
3
]);
return
r
;
};
};
FP32Vec16
max
(
const
FP32Vec16
&
b
,
const
int
elem_num
)
const
{
FP32Vec16
min
(
const
FP32Vec16
&
b
,
const
int
elem_num
)
const
{
int
full_blocks
=
elem_num
/
NUM_ELEMENTS_REG
(
reg
.
val
[
0
]);
size_t
num_elements
=
reg
.
val
[
0
].
size
();
int
remainder
=
elem_num
%
NUM_ELEMENTS_REG
(
reg
.
val
[
0
]);
float32x4x4_t
temp
;
if
(
elem_num
==
VEC_ELEM_NUM
)
{
return
FP32Vec16
::
min
(
b
);
}
int
full_blocks
=
elem_num
/
num_elements
;
const
int
remainder
=
elem_num
%
num_elements
;
FP32Vec16
res
(
uninit
);
for
(
int
i
=
0
;
i
<
full_blocks
;
i
++
)
for
(
int
i
=
0
;
i
<
full_blocks
;
i
++
)
temp
.
val
[
i
]
=
vmaxq_f32
(
b
.
reg
.
val
[
i
],
reg
.
val
[
i
]);
res
.
reg
.
val
[
i
]
=
minimum
(
b
.
reg
.
val
[
i
],
reg
.
val
[
i
]);
if
(
remainder
>
0
)
{
if
(
remainder
>
0
)
{
float
m
ax
_v
=
std
::
m
ax
(
vgetq_lane_f32
(
reg
.
val
[
full_blocks
],
0
),
float
m
in
_v
=
std
::
m
in
(
vgetq_lane_f32
(
reg
.
val
[
full_blocks
],
0
),
vgetq_lane_f32
(
b
.
reg
.
val
[
full_blocks
],
0
));
vgetq_lane_f32
(
b
.
reg
.
val
[
full_blocks
],
0
));
temp
.
val
[
full_blocks
]
=
vsetq_lane_f32
(
max_v
,
temp
.
val
[
full_blocks
],
0
);
res
.
reg
.
val
[
full_blocks
]
=
vsetq_lane_f32
(
min_v
,
res
.
reg
.
val
[
full_blocks
],
0
);
}
}
if
(
remainder
>
1
)
{
if
(
remainder
>
1
)
{
float
m
ax
_v
=
std
::
m
ax
(
vgetq_lane_f32
(
reg
.
val
[
full_blocks
],
1
),
float
m
in
_v
=
std
::
m
in
(
vgetq_lane_f32
(
reg
.
val
[
full_blocks
],
1
),
vgetq_lane_f32
(
b
.
reg
.
val
[
full_blocks
],
1
));
vgetq_lane_f32
(
b
.
reg
.
val
[
full_blocks
],
1
));
temp
.
val
[
full_blocks
]
=
vsetq_lane_f32
(
max_v
,
temp
.
val
[
full_blocks
],
1
);
res
.
reg
.
val
[
full_blocks
]
=
vsetq_lane_f32
(
min_v
,
res
.
reg
.
val
[
full_blocks
],
1
);
}
}
if
(
remainder
>
2
)
{
if
(
remainder
>
2
)
{
float
m
ax
_v
=
std
::
m
ax
(
vgetq_lane_f32
(
reg
.
val
[
full_blocks
],
2
),
float
m
in
_v
=
std
::
m
in
(
vgetq_lane_f32
(
reg
.
val
[
full_blocks
],
2
),
vgetq_lane_f32
(
b
.
reg
.
val
[
full_blocks
],
2
));
vgetq_lane_f32
(
b
.
reg
.
val
[
full_blocks
],
2
));
temp
.
val
[
full_blocks
]
=
vsetq_lane_f32
(
max_v
,
temp
.
val
[
full_blocks
],
2
);
res
.
reg
.
val
[
full_blocks
]
=
vsetq_lane_f32
(
min_v
,
res
.
reg
.
val
[
full_blocks
],
2
);
}
}
return
FP32Vec16
(
temp
);
};
FP32Vec16
min
(
const
FP32Vec16
&
b
)
const
{
return
res
;
return
FP32Vec16
(
float32x4x4_t
({
vminq_f32
(
b
.
reg
.
val
[
0
],
reg
.
val
[
0
]),
vminq_f32
(
b
.
reg
.
val
[
1
],
reg
.
val
[
1
]),
vminq_f32
(
b
.
reg
.
val
[
2
],
reg
.
val
[
2
]),
vminq_f32
(
b
.
reg
.
val
[
3
],
reg
.
val
[
3
]),
}));
};
};
FP32Vec16
min
(
const
FP32Vec16
&
b
,
const
int
elem_num
)
const
{
int
full_blocks
=
elem_num
/
NUM_ELEMENTS_REG
(
reg
.
val
[
0
]);
FP32Vec16
max
(
const
FP32Vec16
&
b
,
const
int
elem_num
)
const
{
const
int
remainder
=
elem_num
%
NUM_ELEMENTS_REG
(
reg
.
val
[
0
]);
size_t
num_elements
=
reg
.
val
[
0
].
size
();
float32x4x4_t
temp
;
if
(
elem_num
==
VEC_ELEM_NUM
)
{
return
FP32Vec16
::
max
(
b
);
}
int
full_blocks
=
elem_num
/
num_elements
;
int
remainder
=
elem_num
%
num_elements
;
FP32Vec16
res
(
uninit
);
for
(
int
i
=
0
;
i
<
full_blocks
;
i
++
)
for
(
int
i
=
0
;
i
<
full_blocks
;
i
++
)
temp
.
val
[
i
]
=
vminq_f32
(
b
.
reg
.
val
[
i
],
reg
.
val
[
i
]);
res
.
reg
.
val
[
i
]
=
maximum
(
b
.
reg
.
val
[
i
],
reg
.
val
[
i
]);
if
(
remainder
>
0
)
{
if
(
remainder
>
0
)
{
float
m
in
_v
=
std
::
m
in
(
vgetq_lane_f32
(
reg
.
val
[
full_blocks
],
0
),
float
m
ax
_v
=
std
::
m
ax
(
vgetq_lane_f32
(
reg
.
val
[
full_blocks
],
0
),
vgetq_lane_f32
(
b
.
reg
.
val
[
full_blocks
],
0
));
vgetq_lane_f32
(
b
.
reg
.
val
[
full_blocks
],
0
));
temp
.
val
[
full_blocks
]
=
vsetq_lane_f32
(
min_v
,
temp
.
val
[
full_blocks
],
0
);
res
.
reg
.
val
[
full_blocks
]
=
vsetq_lane_f32
(
max_v
,
res
.
reg
.
val
[
full_blocks
],
0
);
}
}
if
(
remainder
>
1
)
{
if
(
remainder
>
1
)
{
float
m
in
_v
=
std
::
m
in
(
vgetq_lane_f32
(
reg
.
val
[
full_blocks
],
1
),
float
m
ax
_v
=
std
::
m
ax
(
vgetq_lane_f32
(
reg
.
val
[
full_blocks
],
1
),
vgetq_lane_f32
(
b
.
reg
.
val
[
full_blocks
],
1
));
vgetq_lane_f32
(
b
.
reg
.
val
[
full_blocks
],
1
));
temp
.
val
[
full_blocks
]
=
vsetq_lane_f32
(
min_v
,
temp
.
val
[
full_blocks
],
1
);
res
.
reg
.
val
[
full_blocks
]
=
vsetq_lane_f32
(
max_v
,
res
.
reg
.
val
[
full_blocks
],
1
);
}
}
if
(
remainder
>
2
)
{
if
(
remainder
>
2
)
{
float
m
in
_v
=
std
::
m
in
(
vgetq_lane_f32
(
reg
.
val
[
full_blocks
],
2
),
float
m
ax
_v
=
std
::
m
ax
(
vgetq_lane_f32
(
reg
.
val
[
full_blocks
],
2
),
vgetq_lane_f32
(
b
.
reg
.
val
[
full_blocks
],
2
));
vgetq_lane_f32
(
b
.
reg
.
val
[
full_blocks
],
2
));
temp
.
val
[
full_blocks
]
=
vsetq_lane_f32
(
min_v
,
temp
.
val
[
full_blocks
],
2
);
res
.
reg
.
val
[
full_blocks
]
=
vsetq_lane_f32
(
max_v
,
res
.
reg
.
val
[
full_blocks
],
2
);
}
}
return
res
;
return
FP32Vec16
(
temp
);
};
FP32Vec16
abs
()
const
{
return
FP32Vec16
(
float32x4x4_t
({
vabsq_f32
(
reg
.
val
[
0
]),
vabsq_f32
(
reg
.
val
[
1
]),
vabsq_f32
(
reg
.
val
[
2
]),
vabsq_f32
(
reg
.
val
[
3
])}));
}
float
reduce_sum
()
const
{
AliasReg
ar
;
ar
.
reg
=
reg
;
float
answer
=
0
;
unroll_loop
<
int
,
VEC_ELEM_NUM
>
(
[
&
answer
,
&
ar
](
int
i
)
{
answer
+=
ar
.
values
[
i
];
});
return
answer
;
};
};
float
reduce_max
()
const
{
float
reduce_max
()
const
{
AliasReg
ar
;
VectorizedT
max_vec
=
reg
.
val
[
0
]
;
ar
.
reg
=
reg
;
unroll_loop
<
int
,
VEC_REG_NUM
>
([
&
](
int
i
)
{
float
max_v
=
std
::
numeric_limits
<
float
>::
lowest
(
);
if
(
i
>
0
)
max_v
ec
=
maximum
(
max_vec
,
reg
.
val
[
i
]
);
unroll_loop
<
int
,
VEC_ELEM_NUM
>
(
});
[
&
max_v
,
&
ar
](
int
i
)
{
max_v
=
std
::
max
(
max_v
,
ar
.
values
[
i
]);
});
return
max_v
;
return
vmaxvq_f32
(
max_v
ec
)
;
}
}
float
reduce_min
()
const
{
float
reduce_min
()
const
{
AliasReg
ar
;
VectorizedT
min_vec
=
reg
.
val
[
0
]
;
ar
.
reg
=
reg
;
unroll_loop
<
int
,
VEC_REG_NUM
>
([
&
](
int
i
)
{
float
min_v
=
std
::
numeric_limits
<
float
>::
max
(
);
if
(
i
>
0
)
min_v
ec
=
minimum
(
min_vec
,
reg
.
val
[
i
]
);
unroll_loop
<
int
,
VEC_ELEM_NUM
>
(
});
[
&
min_v
,
&
ar
](
int
i
)
{
min_v
=
std
::
min
(
min_v
,
ar
.
values
[
i
]);
});
return
min_v
;
return
vminvq_f32
(
min_v
ec
)
;
}
}
template
<
int
group_size
>
template
<
int
group_size
>
float
reduce_sub_sum
(
int
idx
)
{
float
reduce_sub_sum
(
int
idx
)
{
static_assert
(
VEC_ELEM_NUM
%
group_size
==
0
);
static_assert
(
VEC_ELEM_NUM
%
group_size
==
0
);
AliasReg
ar
;
AliasReg
<
NxVectorizedTArray
,
ScalarT
,
VEC_ELEM_NUM
>
ar
{
reg
};
ar
.
reg
=
reg
;
float
answer
=
0
;
float
answer
=
0
;
const
int
start
=
idx
*
group_size
;
const
int
start
=
idx
*
group_size
;
unroll_loop
<
int
,
group_size
>
(
unroll_loop
<
int
,
group_size
>
(
[
&
answer
,
&
start
,
ar
](
int
i
)
{
answer
+=
ar
.
values
[
start
+
i
];
});
[
&
](
int
i
)
{
answer
+=
ar
.
values
[
start
+
i
];
});
return
answer
;
return
answer
;
};
};
void
save
(
float
*
ptr
)
const
{
float
reduce_sum
()
const
{
vst1q_f32
(
ptr
,
reg
.
val
[
0
]);
float
answer
=
0
;
vst1q_f32
(
ptr
+
4
,
reg
.
val
[
1
]);
std
::
plus
<
VectorizedT
>
add
;
vst1q_f32
(
ptr
+
8
,
reg
.
val
[
2
]);
unroll_loop
<
int
,
VEC_REG_NUM
>
([
&
](
int
i
)
{
vst1q_f32
(
ptr
+
12
,
reg
.
val
[
3
]);
answer
+=
at
::
vec
::
vec_reduce_all
<
float
>
(
add
,
reg
.
val
[
i
]);
};
});
void
save
(
float
*
ptr
,
const
int
elem_num
)
const
{
int
full_blocks
=
elem_num
/
NUM_ELEMENTS_REG
(
reg
.
val
[
0
]);
int
remainder
=
elem_num
%
NUM_ELEMENTS_REG
(
reg
.
val
[
0
]);
for
(
int
i
=
0
;
i
<
full_blocks
;
i
++
)
vst1q_f32
(
reinterpret_cast
<
float32_t
*>
(
ptr
)
+
NUM_ELEMENTS_REG
(
reg
.
val
[
0
])
*
i
,
reg
.
val
[
i
]);
if
(
remainder
>
0
)
{
return
answer
;
float32x4_t
temp
=
reg
.
val
[
full_blocks
];
float
*
base
=
reinterpret_cast
<
float32_t
*>
(
ptr
)
+
full_blocks
*
NUM_ELEMENTS_REG
(
reg
.
val
[
0
]);
if
(
remainder
>
0
)
base
[
0
]
=
vgetq_lane_f32
(
temp
,
0
);
if
(
remainder
>
1
)
base
[
1
]
=
vgetq_lane_f32
(
temp
,
1
);
if
(
remainder
>
2
)
base
[
2
]
=
vgetq_lane_f32
(
temp
,
2
);
}
}
}
};
};
// Only used for int types for now could be replaced when
// int8/32 vectorised ops are added in ATen
template
<
typename
T
>
struct
Vec
{
constexpr
static
int
get_elem_num
()
{
return
T
::
VEC_ELEM_NUM
;
};
};
struct
INT8Vec16
:
public
Vec
<
INT8Vec16
>
{
struct
INT8Vec16
:
public
Vec
<
INT8Vec16
>
{
constexpr
static
int
VEC_ELEM_NUM
=
16
;
constexpr
static
int
VEC_ELEM_NUM
=
16
;
union
AliasReg
{
union
AliasReg
{
...
@@ -854,30 +820,47 @@ struct INT8Vec64 : public Vec<INT8Vec64> {
...
@@ -854,30 +820,47 @@ struct INT8Vec64 : public Vec<INT8Vec64> {
void
nt_save
(
int8_t
*
ptr
)
const
{
save
(
ptr
);
}
void
nt_save
(
int8_t
*
ptr
)
const
{
save
(
ptr
);
}
};
// INT8Vec64
};
// INT8Vec64
template
<
typename
T
>
struct
INT32Vec16
:
public
Vec
<
INT32Vec16
>
{
struct
VecType
{
constexpr
static
int
VEC_ELEM_NUM
=
16
;
using
vec_type
=
void
;
union
AliasReg
{
};
int32x4x4_t
reg
;
int32_t
values
[
VEC_ELEM_NUM
];
};
int32x4x4_t
reg
;
template
<
typename
T
>
explicit
INT32Vec16
(
const
void
*
ptr
)
{
using
vec_t
=
typename
VecType
<
T
>::
vec_type
;
reg
.
val
[
0
]
=
vld1q_s32
(
reinterpret_cast
<
const
int32_t
*>
(
ptr
));
reg
.
val
[
1
]
=
vld1q_s32
(
reinterpret_cast
<
const
int32_t
*>
(
ptr
)
+
4
);
reg
.
val
[
2
]
=
vld1q_s32
(
reinterpret_cast
<
const
int32_t
*>
(
ptr
)
+
8
);
reg
.
val
[
3
]
=
vld1q_s32
(
reinterpret_cast
<
const
int32_t
*>
(
ptr
)
+
12
);
}
template
<
>
void
save
(
int32_t
*
ptr
)
const
{
struct
VecType
<
float
>
{
vst1q_s32
(
ptr
,
reg
.
val
[
0
]);
using
vec_type
=
FP32Vec8
;
vst1q_s32
(
ptr
+
4
,
reg
.
val
[
1
]);
};
vst1q_s32
(
ptr
+
8
,
reg
.
val
[
2
]);
vst1q_s32
(
ptr
+
12
,
reg
.
val
[
3
]);
};
template
<
>
void
save
(
int32_t
*
ptr
,
const
int
elem_num
)
const
{
struct
VecType
<
c10
::
Half
>
{
int
full_blocks
=
elem_num
/
NUM_ELEMENTS_REG
(
reg
.
val
[
0
]);
using
vec_type
=
FP16Vec8
;
int
remainder
=
elem_num
%
NUM_ELEMENTS_REG
(
reg
.
val
[
0
]);
};
#ifdef ARM_BF16_SUPPORT
for
(
int
i
=
0
;
i
<
full_blocks
;
i
++
)
template
<
>
vst1q_s32
(
struct
VecType
<
c10
::
BFloat16
>
{
reinterpret_cast
<
__int32_t
*>
(
ptr
)
+
NUM_ELEMENTS_REG
(
reg
.
val
[
0
])
*
i
,
using
vec_type
=
BF16Vec8
;
reg
.
val
[
i
]);
if
(
remainder
>
0
)
{
int32x4_t
temp
=
reg
.
val
[
full_blocks
];
int32_t
*
base
=
reinterpret_cast
<
int32_t
*>
(
ptr
)
+
full_blocks
*
4
;
if
(
remainder
>
0
)
base
[
0
]
=
vgetq_lane_s32
(
temp
,
0
);
if
(
remainder
>
1
)
base
[
1
]
=
vgetq_lane_s32
(
temp
,
1
);
if
(
remainder
>
2
)
base
[
2
]
=
vgetq_lane_s32
(
temp
,
2
);
if
(
remainder
>
3
)
base
[
3
]
=
vgetq_lane_s32
(
temp
,
3
);
}
}
};
};
#endif
template
<
typename
T
>
template
<
typename
T
>
void
storeFP32
(
float
v
,
T
*
ptr
)
{
void
storeFP32
(
float
v
,
T
*
ptr
)
{
...
@@ -889,66 +872,55 @@ inline void storeFP32<c10::Half>(float v, c10::Half* ptr) {
...
@@ -889,66 +872,55 @@ inline void storeFP32<c10::Half>(float v, c10::Half* ptr) {
*
reinterpret_cast
<
__fp16
*>
(
ptr
)
=
v
;
*
reinterpret_cast
<
__fp16
*>
(
ptr
)
=
v
;
}
}
inline
FP16Vec16
::
FP16Vec16
(
const
FP32Vec16
&
v
)
{
inline
FP16Vec8
::
FP16Vec8
(
const
FP32Vec8
&
v
)
{
float16x4_t
low_0
=
vcvt_f16_f32
(
v
.
reg
.
val
[
0
]);
reg
.
val
[
0
]
=
convert_float_half
(
v
.
reg
.
val
[
0
],
v
.
reg
.
val
[
1
]);
float16x4_t
high_0
=
vcvt_f16_f32
(
v
.
reg
.
val
[
1
]);
};
float16x4_t
low_1
=
vcvt_f16_f32
(
v
.
reg
.
val
[
2
]);
float16x4_t
high_1
=
vcvt_f16_f32
(
v
.
reg
.
val
[
3
]);
reg
.
val
[
0
]
=
vcombine_f16
(
low_0
,
high_0
);
inline
FP16Vec16
::
FP16Vec16
(
const
FP32Vec16
&
v
)
{
reg
.
val
[
1
]
=
vcombine_f16
(
low_1
,
high_1
);
reg
.
val
[
0
]
=
convert_float_half
(
v
.
reg
.
val
[
0
],
v
.
reg
.
val
[
1
]);
reg
.
val
[
1
]
=
convert_float_half
(
v
.
reg
.
val
[
2
],
v
.
reg
.
val
[
3
]);
};
};
inline
FP16Vec8
::
FP16Vec8
(
const
FP32Vec8
&
v
)
{
inline
void
fma
(
FP32Vec16
&
acc
,
FP32Vec16
&
a
,
FP32Vec16
&
b
)
{
float16x4_t
lower_half
=
vcvt_f16_f32
(
v
.
reg
.
val
[
0
]);
fmadd
(
acc
.
reg
.
val
[
0
],
a
.
reg
.
val
[
0
],
b
.
reg
.
val
[
0
]);
float16x4_t
upper_half
=
vcvt_f16_f32
(
v
.
reg
.
val
[
1
]);
fmadd
(
acc
.
reg
.
val
[
1
],
a
.
reg
.
val
[
1
],
b
.
reg
.
val
[
1
]);
fmadd
(
acc
.
reg
.
val
[
2
],
a
.
reg
.
val
[
2
],
b
.
reg
.
val
[
2
]);
fmadd
(
acc
.
reg
.
val
[
3
],
a
.
reg
.
val
[
3
],
b
.
reg
.
val
[
3
]);
};
reg
=
vcombine_f16
(
lower_half
,
upper_half
);
inline
BF16Vec8
::
BF16Vec8
(
const
FP32Vec8
&
v
)
{
reg
.
val
[
0
]
=
convert_float_bfloat16
(
v
.
reg
.
val
[
0
],
v
.
reg
.
val
[
1
]);
};
};
inline
void
fma
(
FP32Vec16
&
acc
,
FP32Vec16
&
a
,
FP32Vec16
&
b
)
{
inline
BF16Vec16
::
BF16Vec16
(
const
FP32Vec16
&
v
)
{
acc
.
reg
.
val
[
0
]
=
vfmaq_f32
(
acc
.
reg
.
val
[
0
],
a
.
reg
.
val
[
0
],
b
.
reg
.
val
[
0
]);
reg
.
val
[
0
]
=
convert_float_bfloat16
(
v
.
reg
.
val
[
0
],
v
.
reg
.
val
[
1
]);
acc
.
reg
.
val
[
1
]
=
vfmaq_f32
(
acc
.
reg
.
val
[
1
],
a
.
reg
.
val
[
1
],
b
.
reg
.
val
[
1
]);
reg
.
val
[
1
]
=
convert_float_bfloat16
(
v
.
reg
.
val
[
2
],
v
.
reg
.
val
[
3
]);
acc
.
reg
.
val
[
2
]
=
vfmaq_f32
(
acc
.
reg
.
val
[
2
],
a
.
reg
.
val
[
2
],
b
.
reg
.
val
[
2
]);
acc
.
reg
.
val
[
3
]
=
vfmaq_f32
(
acc
.
reg
.
val
[
3
],
a
.
reg
.
val
[
3
],
b
.
reg
.
val
[
3
]);
};
};
#ifdef ARM_BF16_SUPPORT
inline
void
fma
(
FP32Vec16
&
acc
,
BF16Vec32
&
a
,
BF16Vec32
&
b
)
{
inline
void
fma
(
FP32Vec16
&
acc
,
BF16Vec32
&
a
,
BF16Vec32
&
b
)
{
float32x4_t
a0_low
=
vcvt_f32_bf16
(
vget_low_bf16
(
a
.
reg
.
val
[
0
]));
Vectorized
<
float
>
a0_low
,
a0_high
,
a1_low
,
a1_high
,
b0_low
,
b0_high
,
b1_low
,
float32x4_t
a0_high
=
vcvt_f32_bf16
(
vget_high_bf16
(
a
.
reg
.
val
[
0
]));
b1_high
;
float32x4_t
a1_low
=
vcvt_f32_bf16
(
vget_low_bf16
(
a
.
reg
.
val
[
1
]));
float32x4_t
a1_high
=
vcvt_f32_bf16
(
vget_high_bf16
(
a
.
reg
.
val
[
1
]));
std
::
tie
(
a0_low
,
a0_high
)
=
convert_bfloat16_float
(
a
.
reg
.
val
[
0
]);
std
::
tie
(
a1_low
,
a1_high
)
=
convert_bfloat16_float
(
a
.
reg
.
val
[
1
]);
float32x4_t
b0_low
=
vcvt_f32_bf16
(
vget_low_bf16
(
b
.
reg
.
val
[
0
]));
std
::
tie
(
b0_low
,
b0_high
)
=
convert_bfloat16_float
(
b
.
reg
.
val
[
0
]);
float32x4_t
b0_high
=
vcvt_f32_bf16
(
vget_high_bf16
(
b
.
reg
.
val
[
0
]));
std
::
tie
(
b1_low
,
b1_high
)
=
convert_bfloat16_float
(
b
.
reg
.
val
[
1
]);
float32x4_t
b1_low
=
vcvt_f32_bf16
(
vget_low_bf16
(
b
.
reg
.
val
[
1
]));
float32x4_t
b1_high
=
vcvt_f32_bf16
(
vget_high_bf16
(
b
.
reg
.
val
[
1
]));
fmadd
(
acc
.
reg
.
val
[
0
],
a0_low
,
b0_low
);
fmadd
(
acc
.
reg
.
val
[
1
],
a0_high
,
b0_high
);
acc
.
reg
.
val
[
0
]
=
vfmaq_f32
(
acc
.
reg
.
val
[
0
],
a0_low
,
b0_low
);
fmadd
(
acc
.
reg
.
val
[
2
],
a1_low
,
b1_low
);
acc
.
reg
.
val
[
1
]
=
vfmaq_f32
(
acc
.
reg
.
val
[
1
],
a0_high
,
b0_high
);
fmadd
(
acc
.
reg
.
val
[
3
],
a1_high
,
b1_high
);
acc
.
reg
.
val
[
2
]
=
vfmaq_f32
(
acc
.
reg
.
val
[
2
],
a1_low
,
b1_low
);
acc
.
reg
.
val
[
3
]
=
vfmaq_f32
(
acc
.
reg
.
val
[
3
],
a1_high
,
b1_high
);
};
};
#endif
template
<
>
inline
void
storeFP32
<
c10
::
BFloat16
>
(
float
v
,
c10
::
BFloat16
*
ptr
)
{
#ifdef ARM_BF16_SUPPORT
#ifdef ARM_BF16_SUPPORT
inline
BF16Vec8
::
BF16Vec8
(
const
FP32Vec8
&
v
)
*
reinterpret_cast
<
__bf16
*>
(
ptr
)
=
vcvth_bf16_f32
(
v
);
:
reg
(
vcvtq_high_bf16_f32
(
vcvtq_low_bf16_f32
(
v
.
reg
.
val
[
0
]),
v
.
reg
.
val
[
1
]))
{
#else
};
*
ptr
=
static_cast
<
c10
::
BFloat16
>
(
v
);
inline
BF16Vec16
::
BF16Vec16
(
const
FP32Vec16
&
v
)
:
reg
({
vcvtq_high_bf16_f32
(
vcvtq_low_bf16_f32
(
v
.
reg
.
val
[
0
]),
v
.
reg
.
val
[
1
]),
vcvtq_high_bf16_f32
(
vcvtq_low_bf16_f32
(
v
.
reg
.
val
[
2
]),
v
.
reg
.
val
[
3
])})
{};
#endif
#endif
};
inline
void
prefetch
(
const
void
*
addr
)
{
__builtin_prefetch
(
addr
,
0
,
1
);
};
inline
void
prefetch
(
const
void
*
addr
)
{
__builtin_prefetch
(
addr
,
0
,
1
);
};
#ifdef ARM_BF16_SUPPORT
template
<
>
inline
void
storeFP32
<
c10
::
BFloat16
>
(
float
v
,
c10
::
BFloat16
*
ptr
)
{
*
reinterpret_cast
<
__bf16
*>
(
ptr
)
=
vcvth_bf16_f32
(
v
);
};
#endif
};
// namespace vec_op
};
// namespace vec_op
\ No newline at end of file
csrc/cpu/dnnl_kernels.cpp
View file @
e69c990c
...
@@ -14,13 +14,11 @@ struct KernelVecType<float> {
...
@@ -14,13 +14,11 @@ struct KernelVecType<float> {
using
cvt_vec_type
=
vec_op
::
FP32Vec16
;
using
cvt_vec_type
=
vec_op
::
FP32Vec16
;
};
};
#if !defined(__aarch64__) || defined(ARM_BF16_SUPPORT)
template
<
>
template
<
>
struct
KernelVecType
<
c10
::
BFloat16
>
{
struct
KernelVecType
<
c10
::
BFloat16
>
{
using
load_vec_type
=
vec_op
::
BF16Vec16
;
using
load_vec_type
=
vec_op
::
BF16Vec16
;
using
cvt_vec_type
=
vec_op
::
FP32Vec16
;
using
cvt_vec_type
=
vec_op
::
FP32Vec16
;
};
};
#endif
template
<
>
template
<
>
struct
KernelVecType
<
c10
::
Half
>
{
struct
KernelVecType
<
c10
::
Half
>
{
...
...
csrc/cpu/mla_decode.cpp
View file @
e69c990c
...
@@ -38,9 +38,7 @@ struct KernelVecType<c10::BFloat16> {
...
@@ -38,9 +38,7 @@ struct KernelVecType<c10::BFloat16> {
using
qk_vec_type
=
vec_op
::
BF16Vec32
;
using
qk_vec_type
=
vec_op
::
BF16Vec32
;
using
v_load_vec_type
=
vec_op
::
BF16Vec16
;
using
v_load_vec_type
=
vec_op
::
BF16Vec16
;
};
};
#elif defined(__aarch64__) && !defined(ARM_BF16_SUPPORT)
#elif defined(__aarch64__)
// pass
#else
template
<
>
template
<
>
struct
KernelVecType
<
c10
::
BFloat16
>
{
struct
KernelVecType
<
c10
::
BFloat16
>
{
using
qk_load_vec_type
=
vec_op
::
BF16Vec16
;
using
qk_load_vec_type
=
vec_op
::
BF16Vec16
;
...
...
csrc/cpu/utils.hpp
View file @
e69c990c
...
@@ -30,12 +30,10 @@ struct VecTypeTrait<float> {
...
@@ -30,12 +30,10 @@ struct VecTypeTrait<float> {
using
vec_t
=
vec_op
::
FP32Vec16
;
using
vec_t
=
vec_op
::
FP32Vec16
;
};
};
#if !defined(__aarch64__) || defined(ARM_BF16_SUPPORT)
template
<
>
template
<
>
struct
VecTypeTrait
<
c10
::
BFloat16
>
{
struct
VecTypeTrait
<
c10
::
BFloat16
>
{
using
vec_t
=
vec_op
::
BF16Vec16
;
using
vec_t
=
vec_op
::
BF16Vec16
;
};
};
#endif
#if !defined(__powerpc__)
#if !defined(__powerpc__)
template
<
>
template
<
>
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment