norm / vllm · Commits

Commit a10e9cee, authored May 29, 2024 by zhuwenwen

support bf16b infer

parent 675c0abe
Showing 1 changed file with 48 additions and 48 deletions.
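As the hunks below show, each bf16 helper in this file was wrapped in a compute-capability guard: compiled for a target below sm_80 (Ampere), the function body reduced to `assert(false)`. The commit comments out those guards (48 lines swapped for their `//`-prefixed twins), leaving the intrinsic path unconditional. A minimal sketch of the before/after pattern, with hypothetical function names; whether the result actually works on pre-Ampere parts then depends on the float-based emulation that recent `cuda_bf16.h` versions provide:

```cpp
#include <cassert>
#include <cuda_bf16.h>

// Before the commit: builds targeting sm < 80 compile, but the helper
// traps at run time instead of touching bf16 intrinsics.
inline __device__ __nv_bfloat16 add_before(__nv_bfloat16 a, __nv_bfloat16 b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);           // native bf16 arithmetic needs sm_80+
  return __nv_bfloat16();  // unreachable; satisfies the return type
#else
  return a + b;
#endif
}

// After the commit: the guard survives only as comments, so the bf16
// path is emitted for every target; pre-sm_80 builds rely on the
// header's emulated operator+ rather than asserting.
inline __device__ __nv_bfloat16 add_after(__nv_bfloat16 a, __nv_bfloat16 b) {
// #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
//   assert(false);
// #else
  return a + b;
// #endif
}
```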
csrc/attention/dtype_bfloat16.cuh (view file @ a10e9cee)
@@ -87,40 +87,40 @@ struct FloatVec<bf16_8_t> {

 // Utility functions for type conversions.
 inline __device__ float2 bf1622float2(const __nv_bfloat162 val) {
-#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
-  assert(false);
-#else
+// #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+//   assert(false);
+// #else
   return __bfloat1622float2(val);
-#endif
+// #endif
 }

 inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val) {
-#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
-  assert(false);
-#else
+// #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+//   assert(false);
+// #else
   return __bfloat162bfloat162(val);
-#endif
+// #endif
 }

 // Vector addition.
 inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) {
-#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
-  assert(false);
-#else
+// #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+//   assert(false);
+// #else
 #ifndef USE_ROCM
   return a + b;
 #else
   return __hadd(a, b);
 #endif
-#endif
+// #endif
 }

 inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b) {
-#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
-  assert(false);
-#else
+// #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+//   assert(false);
+// #else
   return __hadd2(a, b);
-#endif
+// #endif
 }

 inline __device__ bf16_4_t add(bf16_4_t a, bf16_4_t b) {
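For orientation, the converted helpers behave like ordinary scalar operations inside vLLM's attention kernels. A minimal sketch exercising the conversion and addition helpers from this hunk (the kernel name `bf16_add_demo` is hypothetical, and the header is assumed to be on the include path):

```cpp
#include "dtype_bfloat16.cuh"  // assumed include path

// Hypothetical demo kernel: broadcast a scalar bf16 into both lanes,
// add it to each thread's packed pair, and widen the sum to float2.
__global__ void bf16_add_demo(const __nv_bfloat162* in, float2* out,
                              __nv_bfloat16 bias) {
  const int i = threadIdx.x;
  __nv_bfloat162 b2 = vllm::bf162bf162(bias);  // {bias, bias}
  __nv_bfloat162 s  = vllm::add(in[i], b2);    // lane-wise __hadd2
  out[i] = vllm::bf1622float2(s);              // packed bf16x2 -> float2
}
```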
@@ -163,20 +163,20 @@ inline __device__ Float8_ add(bf16_8_t a, Float8_ fb) {

 // Vector multiplication.
 template<>
 inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b) {
-#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
-  assert(false);
-#else
+// #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+//   assert(false);
+// #else
   return __hmul(a, b);
-#endif
+// #endif
 }

 template<>
 inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b) {
-#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
-  assert(false);
-#else
+// #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+//   assert(false);
+// #else
   return __hmul2(a, b);
-#endif
+// #endif
 }

 template<>
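One structural detail: unlike `add`, these `mul` overloads are explicit specializations of a primary template (which in vLLM's layout is declared in `attention_generic.cuh`) whose first parameter is the accumulator, i.e. return, type, so call sites spell that type out. A hedged call-site sketch with a hypothetical function name:

```cpp
// The first template argument selects the return (accumulator) type;
// the operand types can be deduced from the arguments.
__device__ __nv_bfloat162 lane_product(__nv_bfloat162 a, __nv_bfloat162 b) {
  return vllm::mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a, b);
}
```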
@@ -281,19 +281,19 @@ inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b) {

 // Vector fused multiply-add.
 inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) {
-#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
-  assert(false);
-#else
+// #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+//   assert(false);
+// #else
   return __hfma2(a, b, c);
-#endif
+// #endif
 }

 inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b, __nv_bfloat162 c) {
-#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
-  assert(false);
-#else
+// #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+//   assert(false);
+// #else
   return __hfma2(bf162bf162(a), b, c);
-#endif
+// #endif
 }

 inline __device__ bf16_4_t fma(bf16_4_t a, bf16_4_t b, bf16_4_t c) {
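The mixed-operand overload is the one place in this hunk that does more than wrap an intrinsic: it broadcasts the scalar `a` into both lanes via `bf162bf162` and then issues a single fused `__hfma2`, computing `{a*b.x + c.x, a*b.y + c.y}`. A small usage sketch (the function name is hypothetical):

```cpp
// Scale-and-shift a packed bf16 pair by a scalar in one fused op.
__device__ __nv_bfloat162 scale_shift(__nv_bfloat16 s, __nv_bfloat162 v,
                                      __nv_bfloat162 c) {
  return vllm::fma(s, v, c);  // == __hfma2(bf162bf162(s), v, c)
}
```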
@@ -406,31 +406,31 @@ inline __device__ void from_float(__nv_bfloat16& dst, float src) {
 }

 inline __device__ void from_float(__nv_bfloat162& dst, float2 src) {
-#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
-  assert(false);
-#else
+// #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+//   assert(false);
+// #else
   dst = __float22bfloat162_rn(src);
-#endif
+// #endif
 }

 inline __device__ void from_float(bf16_4_t& dst, Float4_ src) {
-#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
-  assert(false);
-#else
+// #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+//   assert(false);
+// #else
   dst.x = __float22bfloat162_rn(src.x);
   dst.y = __float22bfloat162_rn(src.y);
-#endif
+// #endif
 }

 inline __device__ void from_float(bf16_8_t& dst, Float8_ src) {
-#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
-  assert(false);
-#else
+// #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+//   assert(false);
+// #else
   dst.x = __float22bfloat162_rn(src.x);
   dst.y = __float22bfloat162_rn(src.y);
   dst.z = __float22bfloat162_rn(src.z);
   dst.w = __float22bfloat162_rn(src.w);
-#endif
+// #endif
 }

 // From bfloat16 to float32.
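`__float22bfloat162_rn` narrows a `float2` into a packed bf16 pair with round-to-nearest-even; the wider overloads simply apply it per component, so `bf16_4_t` (two packed pairs) and `bf16_8_t` (four) take two and four conversions respectively. A sketch, assuming the `Float4_` and `bf16_4_t` definitions from this file and its companion float32 type header:

```cpp
// Narrow a four-float accumulator (two float2 halves) to four bf16 lanes.
__device__ bf16_4_t pack_four(Float4_ f) {
  bf16_4_t v;
  vllm::from_float(v, f);  // v.x = rn(f.x), v.y = rn(f.y)
  return v;
}
```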
@@ -440,12 +440,12 @@ inline __device__ float to_float(__nv_bfloat16 u) {

 // Zero-out a variable.
 inline __device__ void zero(__nv_bfloat16& dst) {
-#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
-  assert(false);
-#else
+// #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+//   assert(false);
+// #else
   // Same as CUDART_ZERO_BF16 introduced in CUDA 12.2.
   dst = __ushort_as_bfloat16((unsigned short)0x0000U);
-#endif
+// #endif
 }

 } // namespace vllm
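Taken together, the touched helpers cover a full float to bf16 and back round trip. A hypothetical single-thread demo, using only functions visible in the hunks or hunk headers above:

```cpp
// Accumulate n floats in bf16, then widen the running sum back to float.
__global__ void bf16_roundtrip(const float* in, float* out, int n) {
  if (threadIdx.x != 0 || blockIdx.x != 0) return;  // single-thread demo
  __nv_bfloat16 acc;
  vllm::zero(acc);               // bitwise-zero bf16 (0x0000)
  for (int i = 0; i < n; ++i) {
    __nv_bfloat16 v;
    vllm::from_float(v, in[i]);  // float -> bf16, round-to-nearest
    acc = vllm::add(acc, v);     // bf16 add (a + b, or __hadd on ROCm)
  }
  *out = vllm::to_float(acc);    // bf16 -> float
}
```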