norm / vllm · Commit d7211684
"src/vscode:/vscode.git/clone" did not exist on "a54b16a2d1cdd022c8b52ae403e22d6c7dd7518f"
Unverified commit, authored May 27, 2023 by Woosuk Kwon; committed by GitHub on May 27, 2023.
Improve setup script & Add a guard for bfloat16 kernels (#130)
Parent: 4a151dd4

Showing 4 changed files with 90 additions and 16 deletions:
  csrc/attention/attention_dtypes.h    +0   -3
  csrc/attention/attention_kernels.cu  +0   -2
  csrc/attention/dtype_bfloat16.cuh    +44  -0
  setup.py                             +46  -11
csrc/attention/attention_dtypes.h
@@ -3,7 +3,4 @@
 #include "attention_generic.cuh"
 #include "dtype_float16.cuh"
 #include "dtype_float32.cuh"
-
-#ifdef ENABLE_BF16
 #include "dtype_bfloat16.cuh"
-#endif // ENABLE_BF16
csrc/attention/attention_kernels.cu
@@ -458,10 +458,8 @@ void single_query_cached_kv_attention(
   // TODO(woosuk): Support FP32.
   if (query.dtype() == at::ScalarType::Half) {
     CALL_KERNEL_LAUNCHER_BLOCK_SIZE(uint16_t);
-#ifdef ENABLE_BF16
   } else if (query.dtype() == at::ScalarType::BFloat16) {
     CALL_KERNEL_LAUNCHER_BLOCK_SIZE(__nv_bfloat16);
-#endif
   } else {
     TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
   }
...
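With the #ifdef ENABLE_BF16 guard removed here, the BFloat16 branch is always compiled in, and the kernel choice is made purely at runtime from the query's dtype. As a rough Python illustration of that control flow only (the launcher names below are hypothetical stand-ins, not the extension's actual API):

import torch

# Hypothetical stand-ins for the templated C++ kernel launchers.
def launch_half_kernel(query: torch.Tensor) -> None: ...
def launch_bfloat16_kernel(query: torch.Tensor) -> None: ...

def dispatch_attention(query: torch.Tensor) -> None:
    # Mirrors the if/else chain above: pick a kernel by dtype, fail loudly otherwise.
    if query.dtype == torch.float16:
        launch_half_kernel(query)
    elif query.dtype == torch.bfloat16:
        launch_bfloat16_kernel(query)
    else:
        raise RuntimeError(f"Unsupported data type: {query.dtype}")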
csrc/attention/dtype_bfloat16.cuh
@@ -78,20 +78,36 @@ struct FloatVec<bf16_8_t> {
 
 // Utility functions for type conversions.
 inline __device__ float2 bf1622float2(const __nv_bfloat162 val) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+  assert(false);
+#else
   return __bfloat1622float2(val);
+#endif
 }
 
 inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+  assert(false);
+#else
   return __bfloat162bfloat162(val);
+#endif
 }
 
 // Vector addition.
 inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+  assert(false);
+#else
   return a + b;
+#endif
 }
 
 inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+  assert(false);
+#else
   return __hadd2(a, b);
+#endif
 }
 
 inline __device__ bf16_4_t add(bf16_4_t a, bf16_4_t b) {
@@ -134,12 +150,20 @@ inline __device__ Float8_ add(bf16_8_t a, Float8_ fb) {
 // Vector multiplication.
 template<>
 inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+  assert(false);
+#else
   return __hmul(a, b);
+#endif
 }
 
 template<>
 inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+  assert(false);
+#else
   return __hmul2(a, b);
+#endif
 }
 
 template<>
@@ -244,11 +268,19 @@ inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b) {
 
 // Vector fused multiply-add.
 inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+  assert(false);
+#else
   return __hfma2(a, b, c);
+#endif
 }
 
 inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b, __nv_bfloat162 c) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+  assert(false);
+#else
   return __hfma2(bf162bf162(a), b, c);
+#endif
 }
 
 inline __device__ bf16_4_t fma(bf16_4_t a, bf16_4_t b, bf16_4_t c) {
@@ -361,19 +393,31 @@ inline __device__ void from_float(__nv_bfloat16& dst, float src) {
 }
 
 inline __device__ void from_float(__nv_bfloat162& dst, float2 src) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+  assert(false);
+#else
   dst = __float22bfloat162_rn(src);
+#endif
 }
 
 inline __device__ void from_float(bf16_4_t& dst, Float4_ src) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+  assert(false);
+#else
   dst.x = __float22bfloat162_rn(src.x);
   dst.y = __float22bfloat162_rn(src.y);
+#endif
 }
 
 inline __device__ void from_float(bf16_8_t& dst, Float8_ src) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+  assert(false);
+#else
   dst.x = __float22bfloat162_rn(src.x);
   dst.y = __float22bfloat162_rn(src.y);
   dst.z = __float22bfloat162_rn(src.z);
   dst.w = __float22bfloat162_rn(src.w);
+#endif
 }
 
 } // namespace cacheflow
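The guards added above let these bfloat16 helpers compile for pre-Ampere targets while trapping with assert(false) if they are ever executed there, since the __nv_bfloat16 arithmetic intrinsics in the #else branches require compute capability 8.0. On the Python side, a caller can stay clear of that trap by checking the device capability before opting into bfloat16; a minimal sketch using standard torch calls (the helper name is ours, not part of this repository):

import torch

def pick_attention_dtype() -> torch.dtype:
    """Prefer bfloat16 only on GPUs that can run the bf16 kernels (sm_80 and newer)."""
    if not torch.cuda.is_available():
        return torch.float32
    major, _minor = torch.cuda.get_device_capability()
    # The guarded intrinsics above need compute capability >= 8.0.
    return torch.bfloat16 if major >= 8 else torch.float16

print(pick_attention_dtype())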
setup.py
-from typing import List
+import subprocess
+from typing import List, Set
+
+from packaging.version import parse, Version
 import setuptools
 import torch
 from torch.utils.cpp_extension import BuildExtension, CUDAExtension
 from torch.utils.cpp_extension import CUDA_HOME
 
-# Build custom operators.
-CXX_FLAGS = ["-g"]
+# Compiler flags.
+CXX_FLAGS = ["-g", "-O2"]
 # TODO(woosuk): Should we use -O3?
 NVCC_FLAGS = ["-O2"]
 
 if not torch.cuda.is_available():
     raise RuntimeError(
         f"Cannot find CUDA at CUDA_HOME: {CUDA_HOME}. "
         "CUDA must be available in order to build the package.")
 
-# FIXME(woosuk): Consider the case where the machine has multiple GPUs with
-# different compute capabilities.
-compute_capability = torch.cuda.get_device_capability()
-major, minor = compute_capability
-# Enable bfloat16 support if the compute capability is >= 8.0.
-if major >= 8:
-    NVCC_FLAGS.append("-DENABLE_BF16")
+
+def get_nvcc_cuda_version(cuda_dir: str) -> Version:
+    """Get the CUDA version from nvcc.
+
+    Adapted from https://github.com/NVIDIA/apex/blob/8b7a1ff183741dd8f9b87e7bafd04cfde99cea28/setup.py
+    """
+    nvcc_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"],
+                                          universal_newlines=True)
+    output = nvcc_output.split()
+    release_idx = output.index("release") + 1
+    nvcc_cuda_version = parse(output[release_idx].split(",")[0])
+    return nvcc_cuda_version
+
+
+# Collect the compute capabilities of all available GPUs.
+device_count = torch.cuda.device_count()
+compute_capabilities: Set[int] = set()
+for i in range(device_count):
+    major, minor = torch.cuda.get_device_capability(i)
+    if major < 7:
+        raise RuntimeError(
+            "GPUs with compute capability less than 7.0 are not supported.")
+    compute_capabilities.add(major * 10 + minor)
+# If no GPU is available, add all supported compute capabilities.
+if not compute_capabilities:
+    compute_capabilities = {70, 75, 80, 86, 90}
+# Add target compute capabilities to NVCC flags.
+for capability in compute_capabilities:
+    NVCC_FLAGS += ["-gencode", f"arch=compute_{capability},code=sm_{capability}"]
+
+# Validate the NVCC CUDA version.
+nvcc_cuda_version = get_nvcc_cuda_version(CUDA_HOME)
+if nvcc_cuda_version < Version("11.0"):
+    raise RuntimeError("CUDA 11.0 or higher is required to build the package.")
+if 86 in compute_capabilities and nvcc_cuda_version < Version("11.1"):
+    raise RuntimeError(
+        "CUDA 11.1 or higher is required for GPUs with compute capability 8.6.")
+if 90 in compute_capabilities and nvcc_cuda_version < Version("11.8"):
+    raise RuntimeError(
+        "CUDA 11.8 or higher is required for GPUs with compute capability 9.0.")
 
 ext_modules = []
...
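The new compute-capability handling can be exercised on its own: expand a set of capabilities into the same -gencode flags the updated setup.py appends to NVCC_FLAGS, gated on the nvcc version. A standalone sketch under those assumptions (gencode_flags is our illustrative helper; capabilities 80 and 86 correspond to A100- and RTX 3090-class GPUs):

from typing import List, Set

from packaging.version import Version, parse

def gencode_flags(compute_capabilities: Set[int]) -> List[str]:
    # Same "arch=compute_XY,code=sm_XY" mapping the updated setup.py applies.
    flags: List[str] = []
    for capability in sorted(compute_capabilities):
        flags += ["-gencode", f"arch=compute_{capability},code=sm_{capability}"]
    return flags

# Example: capabilities {80, 86} built with CUDA 11.8.
capabilities = {80, 86}
nvcc_cuda_version = parse("11.8")
assert nvcc_cuda_version >= Version("11.1"), "sm_86 needs CUDA 11.1 or newer"
print(gencode_flags(capabilities))
# ['-gencode', 'arch=compute_80,code=sm_80', '-gencode', 'arch=compute_86,code=sm_86']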