norm / vllm · Commit d7211684 (Unverified)
Authored May 27, 2023 by Woosuk Kwon; committed by GitHub on May 27, 2023

Improve setup script & Add a guard for bfloat16 kernels (#130)

Parent: 4a151dd4
Showing 4 changed files with 90 additions and 16 deletions (+90 −16)
csrc/attention/attention_dtypes.h    +0  −3
csrc/attention/attention_kernels.cu  +0  −2
csrc/attention/dtype_bfloat16.cuh    +44 −0
setup.py                             +46 −11
csrc/attention/attention_dtypes.h

@@ -3,7 +3,4 @@
 #include "attention_generic.cuh"
 #include "dtype_float16.cuh"
 #include "dtype_float32.cuh"
-
-#ifdef ENABLE_BF16
 #include "dtype_bfloat16.cuh"
-#endif // ENABLE_BF16
csrc/attention/attention_kernels.cu

@@ -458,10 +458,8 @@ void single_query_cached_kv_attention(
   // TODO(woosuk): Support FP32.
   if (query.dtype() == at::ScalarType::Half) {
     CALL_KERNEL_LAUNCHER_BLOCK_SIZE(uint16_t);
-#ifdef ENABLE_BF16
   } else if (query.dtype() == at::ScalarType::BFloat16) {
     CALL_KERNEL_LAUNCHER_BLOCK_SIZE(__nv_bfloat16);
-#endif
   } else {
     TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
   }
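Note (added for clarity, not part of the commit): with the #ifdef ENABLE_BF16 guard removed, the BFloat16 branch above is always compiled in, and any other query dtype falls through to the TORCH_CHECK failure. A minimal, hypothetical Python-side sketch of the same dispatch contract (names are illustrative, not vLLM's API):

import torch

# Hypothetical mapping that mirrors the C++ dispatch above: half maps to the
# uint16_t instantiation, bfloat16 to the __nv_bfloat16 one.
_KERNEL_TYPE_FOR_DTYPE = {
    torch.float16: "uint16_t",
    torch.bfloat16: "__nv_bfloat16",
}

def select_kernel_type(query: torch.Tensor) -> str:
    try:
        return _KERNEL_TYPE_FOR_DTYPE[query.dtype]
    except KeyError:
        # Mirrors TORCH_CHECK(false, "Unsupported data type: ", query.dtype()).
        raise TypeError(f"Unsupported data type: {query.dtype}") from None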
csrc/attention/dtype_bfloat16.cuh

@@ -78,20 +78,36 @@ struct FloatVec<bf16_8_t> {
 // Utility functions for type conversions.
 inline __device__ float2 bf1622float2(const __nv_bfloat162 val) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+  assert(false);
+#else
   return __bfloat1622float2(val);
+#endif
 }
 
 inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+  assert(false);
+#else
   return __bfloat162bfloat162(val);
+#endif
 }
 
 // Vector addition.
 inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+  assert(false);
+#else
   return a + b;
+#endif
 }
 
 inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+  assert(false);
+#else
   return __hadd2(a, b);
+#endif
 }
 
 inline __device__ bf16_4_t add(bf16_4_t a, bf16_4_t b) {

@@ -134,12 +150,20 @@ inline __device__ Float8_ add(bf16_8_t a, Float8_ fb) {
 // Vector multiplication.
 template<>
 inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+  assert(false);
+#else
   return __hmul(a, b);
+#endif
 }
 
 template<>
 inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+  assert(false);
+#else
   return __hmul2(a, b);
+#endif
 }
 
 template<>

@@ -244,11 +268,19 @@ inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b) {
 // Vector fused multiply-add.
 inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+  assert(false);
+#else
   return __hfma2(a, b, c);
+#endif
 }
 
 inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b, __nv_bfloat162 c) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+  assert(false);
+#else
   return __hfma2(bf162bf162(a), b, c);
+#endif
 }
 
 inline __device__ bf16_4_t fma(bf16_4_t a, bf16_4_t b, bf16_4_t c) {

@@ -361,19 +393,31 @@ inline __device__ void from_float(__nv_bfloat16& dst, float src) {
 }
 
 inline __device__ void from_float(__nv_bfloat162& dst, float2 src) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+  assert(false);
+#else
   dst = __float22bfloat162_rn(src);
+#endif
 }
 
 inline __device__ void from_float(bf16_4_t& dst, Float4_ src) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+  assert(false);
+#else
   dst.x = __float22bfloat162_rn(src.x);
   dst.y = __float22bfloat162_rn(src.y);
+#endif
 }
 
 inline __device__ void from_float(bf16_8_t& dst, Float8_ src) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+  assert(false);
+#else
   dst.x = __float22bfloat162_rn(src.x);
   dst.y = __float22bfloat162_rn(src.y);
   dst.z = __float22bfloat162_rn(src.z);
   dst.w = __float22bfloat162_rn(src.w);
+#endif
 }
 
 } // namespace cacheflow
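Note (added for clarity, not part of the commit): the #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 guards let these bfloat16 helpers compile for every target architecture while trapping with assert(false) if they are ever executed on a GPU older than compute capability 8.0, the first architecture with native bfloat16 arithmetic. A minimal sketch, with an assumed helper name, of the capability check a caller could make before opting into bfloat16:

import torch

def pick_half_precision_dtype(device: str = "cuda") -> torch.dtype:
    # Illustrative helper (not part of the commit): bfloat16 needs compute
    # capability >= 8.0, matching the __CUDA_ARCH__ < 800 guard above;
    # fall back to float16 on older GPUs.
    major, _minor = torch.cuda.get_device_capability(device)
    return torch.bfloat16 if major >= 8 else torch.float16

# Example usage: torch.randn(1, 4096, dtype=pick_half_precision_dtype(), device="cuda")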
setup.py

-from typing import List
+import subprocess
+from typing import List, Set
 
+from packaging.version import parse, Version
 import setuptools
 import torch
 from torch.utils.cpp_extension import BuildExtension, CUDAExtension
 from torch.utils.cpp_extension import CUDA_HOME
 
-# Compiler flags.
-CXX_FLAGS = ["-g", "-O2"]
+# Build custom operators.
+CXX_FLAGS = ["-g"]
 # TODO(woosuk): Should we use -O3?
 NVCC_FLAGS = ["-O2"]
 
 if not torch.cuda.is_available():
     raise RuntimeError(
         f"Cannot find CUDA at CUDA_HOME: {CUDA_HOME}. "
         "CUDA must be available in order to build the package.")
 
-# FIXME(woosuk): Consider the case where the machine has multiple GPUs with
-# different compute capabilities.
-compute_capability = torch.cuda.get_device_capability()
-major, minor = compute_capability
-# Enable bfloat16 support if the compute capability is >= 8.0.
-if major >= 8:
-    NVCC_FLAGS.append("-DENABLE_BF16")
+
+def get_nvcc_cuda_version(cuda_dir: str) -> Version:
+    """Get the CUDA version from nvcc.
+
+    Adapted from https://github.com/NVIDIA/apex/blob/8b7a1ff183741dd8f9b87e7bafd04cfde99cea28/setup.py
+    """
+    nvcc_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"],
+                                          universal_newlines=True)
+    output = nvcc_output.split()
+    release_idx = output.index("release") + 1
+    nvcc_cuda_version = parse(output[release_idx].split(",")[0])
+    return nvcc_cuda_version
+
+# Collect the compute capabilities of all available GPUs.
+device_count = torch.cuda.device_count()
+compute_capabilities: Set[int] = set()
+for i in range(device_count):
+    major, minor = torch.cuda.get_device_capability(i)
+    if major < 7:
+        raise RuntimeError(
+            "GPUs with compute capability less than 7.0 are not supported.")
+    compute_capabilities.add(major * 10 + minor)
+# If no GPU is available, add all supported compute capabilities.
+if not compute_capabilities:
+    compute_capabilities = {70, 75, 80, 86, 90}
+# Add target compute capabilities to NVCC flags.
+for capability in compute_capabilities:
+    NVCC_FLAGS += ["-gencode", f"arch=compute_{capability},code=sm_{capability}"]
+
+# Validate the NVCC CUDA version.
+nvcc_cuda_version = get_nvcc_cuda_version(CUDA_HOME)
+if nvcc_cuda_version < Version("11.0"):
+    raise RuntimeError(
+        "CUDA 11.0 or higher is required to build the package.")
+if 86 in compute_capabilities and nvcc_cuda_version < Version("11.1"):
+    raise RuntimeError(
+        "CUDA 11.1 or higher is required for GPUs with compute capability 8.6.")
+if 90 in compute_capabilities and nvcc_cuda_version < Version("11.8"):
+    raise RuntimeError(
+        "CUDA 11.8 or higher is required for GPUs with compute capability 9.0.")
 
 ext_modules = []
...
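Note (added for clarity, not part of the diff): get_nvcc_cuda_version simply splits the text printed by `nvcc -V` and takes the token after "release", and the -gencode loop emits one architecture flag pair per detected compute capability. A small self-contained sketch of both steps, using a representative nvcc banner (assumed for illustration; the exact version string differs per machine, and the helper names are hypothetical):

from packaging.version import Version, parse

# Representative `nvcc -V` output (assumed for illustration).
SAMPLE_NVCC_OUTPUT = (
    "nvcc: NVIDIA (R) Cuda compiler driver\n"
    "Cuda compilation tools, release 11.8, V11.8.89\n"
)

def parse_nvcc_version(nvcc_output: str) -> Version:
    # Same token-based parsing as get_nvcc_cuda_version in the diff above.
    output = nvcc_output.split()
    release_idx = output.index("release") + 1
    return parse(output[release_idx].split(",")[0])

def gencode_flags(compute_capabilities: set) -> list:
    # Mirrors the loop in the diff: one -gencode pair per target architecture.
    flags = []
    for capability in sorted(compute_capabilities):
        flags += ["-gencode", f"arch=compute_{capability},code=sm_{capability}"]
    return flags

print(parse_nvcc_version(SAMPLE_NVCC_OUTPUT))  # 11.8
print(gencode_flags({70, 80}))
# ['-gencode', 'arch=compute_70,code=sm_70', '-gencode', 'arch=compute_80,code=sm_80']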