norm / vllm
Commit e38074b1 (unverified)

Support FP32 (#141)

Authored Jun 07, 2023 by Woosuk Kwon; committed via GitHub on Jun 07, 2023
Parent commit: 376725ce

Showing 8 changed files with 65 additions and 54 deletions
cacheflow/config.py                             +3   -4
cacheflow/entrypoints/llm.py                    +5   -4
cacheflow/model_executor/layers/attention.py    +3   -5
cacheflow/server/arg_utils.py                   +4   -4
csrc/attention/attention_kernels.cu             +37  -32
docs/source/getting_started/installation.rst    +3   -0
setup.py                                        +5   -0
tests/kernels/test_attention.py                 +5   -5
cacheflow/config.py
...
@@ -164,7 +164,7 @@ def _get_and_verify_dtype(
         config_dtype = torch.float32

     dtype = dtype.lower()
-    if dtype == "default":
+    if dtype == "auto":
         if config_dtype == torch.float32:
             # Following the common practice, we use float16 for float32 models.
             torch_dtype = torch.float16
...
@@ -184,9 +184,8 @@ def _get_and_verify_dtype(
             # Downcasting from float32 to float16 or bfloat16 is allowed.
             pass
         else:
-            # Casting between float16 and bfloat16 is not allowed.
-            raise ValueError(
-                f"Cannot use {torch_dtype} for {config_dtype} model.")
+            # Casting between float16 and bfloat16 is allowed with a warning.
+            logger.warn(f"Casting {config_dtype} to {torch_dtype}.")

     # Check if the GPU supports the dtype.
     if torch_dtype == torch.bfloat16:
...
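For readers following the dtype change, here is a minimal standalone sketch of the resolution behaviour these hunks implement. The function name and the string-to-dtype table below are illustrative, not the module's actual API.

```python
import logging
import torch

logger = logging.getLogger(__name__)

# Illustrative string-to-dtype table; the real mapping lives in cacheflow/config.py.
_STR_TO_DTYPE = {
    "half": torch.float16, "float16": torch.float16,
    "bfloat16": torch.bfloat16,
    "float": torch.float32, "float32": torch.float32,
}

def resolve_dtype(config_dtype: torch.dtype, dtype: str) -> torch.dtype:
    """Simplified sketch of _get_and_verify_dtype after this commit."""
    dtype = dtype.lower()
    if dtype == "auto":
        # Follow the model config, but use float16 for float32 checkpoints.
        return torch.float16 if config_dtype == torch.float32 else config_dtype
    torch_dtype = _STR_TO_DTYPE[dtype]
    if torch_dtype != config_dtype and torch.float32 not in (torch_dtype, config_dtype):
        # float16 <-> bfloat16 casts are now allowed, with a warning instead of an error.
        logger.warning("Casting %s to %s.", config_dtype, torch_dtype)
    return torch_dtype
```

In this sketch, `resolve_dtype(torch.float32, "auto")` returns `torch.float16`, while an explicit `resolve_dtype(torch.float32, "float32")` keeps full precision.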
cacheflow/entrypoints/llm.py
...
@@ -28,9 +28,10 @@ class LLM:
         tensor_parallel_size: The number of GPUs to use for distributed
             execution with tensor parallelism.
         dtype: The data type for the model weights and activations. Currently,
-            we support `float16` and `bfloat16`. If `default`, we use the
-            `torch_dtype` attribute of the model config. If the `torch_dtype`
-            is `float32`, we use `float16` instead.
+            we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
+            the `torch_dtype` attribute specified in the model config file.
+            However, if the `torch_dtype` in the config is `float32`, we will
+            use `float16` instead.
         seed: The seed to initialize the random number generator for sampling.
     """

...
@@ -38,7 +39,7 @@ class LLM:
         self,
         model: str,
         tensor_parallel_size: int = 1,
-        dtype: str = "default",
+        dtype: str = "auto",
         seed: int = 0,
         **kwargs,
     ) -> None:
...
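A hedged usage sketch of the updated constructor follows. The import path and model name are assumptions for illustration; per the new docstring, an explicit `float32` keeps full precision while `auto` downcasts FP32 checkpoints to FP16.

```python
from cacheflow import LLM  # import path assumed; adjust to the actual package layout

# dtype="auto" follows the checkpoint's torch_dtype (FP32 checkpoints -> FP16);
# dtype="float32" requests full precision, which this commit enables.
llm = LLM(model="facebook/opt-125m", dtype="float32", seed=0)
```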
cacheflow/model_executor/layers/attention.py
...
@@ -10,7 +10,7 @@ from cacheflow import cache_ops
 from cacheflow import pos_encoding_ops
 from cacheflow.model_executor.input_metadata import InputMetadata

-_SUPPORTED_HEAD_SIZES = [32, 64, 80, 96, 128, 160, 192, 256]
+_SUPPORTED_HEAD_SIZES = [64, 80, 96, 128]


 class GPTCacheFlowAttention(nn.Module):
...
@@ -49,10 +49,8 @@ class GPTCacheFlowAttention(nn.Module):
         self.attn_op = xops.fmha.cutlass.FwOp()

         if self.head_size not in _SUPPORTED_HEAD_SIZES:
-            raise ValueError(f'head_size ({self.head_size}) is not supported by '
-                             'the single_query_cached_kv_attention kernel. '
-                             'Use one of the following head sizes: '
-                             f'{_SUPPORTED_HEAD_SIZES}.')
+            raise ValueError(f"head_size ({self.head_size}) is not supported. "
+                             f"Supported head sizes: {_SUPPORTED_HEAD_SIZES}.")

     def multi_query_kv_attention(
         self,
...
cacheflow/server/arg_utils.py
...
@@ -13,7 +13,7 @@ class ServerArgs:
     download_dir: Optional[str] = None
     use_np_weights: bool = False
     use_dummy_weights: bool = False
-    dtype: str = "default"
+    dtype: str = "auto"
     seed: int = 0
     worker_use_ray: bool = False
     pipeline_parallel_size: int = 1
...
@@ -49,9 +49,9 @@ class ServerArgs:
                             help='use dummy values for model weights')
         # TODO(woosuk): Support FP32.
         parser.add_argument('--dtype', type=str, default=ServerArgs.dtype,
-                            choices=['default', 'half', 'bfloat16'],
+                            choices=['auto', 'half', 'bfloat16', 'float'],
                             help='data type for model weights and activations. '
-                                 'The "default" option will use FP16 precision '
+                                 'The "auto" option will use FP16 precision '
                                  'for FP32 and FP16 models, and BF16 precision '
                                  'for BF16 models.')
         # Parallel arguments
...
@@ -67,7 +67,7 @@ class ServerArgs:
         # KV cache arguments
         parser.add_argument('--block-size', type=int,
                             default=ServerArgs.block_size,
-                            choices=[1, 2, 4, 8, 16, 32, 64, 128, 256],
+                            choices=[8, 16, 32],
                             help='token block size')
         # TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
         parser.add_argument('--seed', type=int, default=ServerArgs.seed,
...
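A self-contained argparse sketch mirroring the updated choices shows what the CLI now accepts; the parser below is standalone and its defaults are illustrative, not `ServerArgs` itself.

```python
import argparse

# Standalone mirror of the changed arguments, for illustration only.
parser = argparse.ArgumentParser()
parser.add_argument('--dtype', type=str, default='auto',
                    choices=['auto', 'half', 'bfloat16', 'float'],
                    help='data type for model weights and activations')
parser.add_argument('--block-size', type=int, default=16,  # default is illustrative
                    choices=[8, 16, 32],
                    help='token block size')

args = parser.parse_args(['--dtype', 'float'])  # 'float' (FP32) is now a valid choice
print(args.dtype, args.block_size)              # -> float 16
```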
csrc/attention/attention_kernels.cu
...
@@ -370,9 +370,11 @@ void single_query_cached_kv_attention_launcher(
   dim3 block(NUM_THREADS);
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   switch (head_size) {
-    case 32:
-      LAUNCH_ATTENTION_KERNEL(T, 32, BLOCK_SIZE, NUM_THREADS);
-      break;
+    // NOTE(woosuk): To reduce the compilation time, we omitted head sizes
+    // 32, 160, 192, 256.
+    // case 32:
+    //   LAUNCH_ATTENTION_KERNEL(T, 32, BLOCK_SIZE, NUM_THREADS);
+    //   break;
     case 64:
       LAUNCH_ATTENTION_KERNEL(T, 64, BLOCK_SIZE, NUM_THREADS);
       break;
...
@@ -385,15 +387,15 @@ void single_query_cached_kv_attention_launcher(
     case 128:
       LAUNCH_ATTENTION_KERNEL(T, 128, BLOCK_SIZE, NUM_THREADS);
       break;
-    case 160:
-      LAUNCH_ATTENTION_KERNEL(T, 160, BLOCK_SIZE, NUM_THREADS);
-      break;
-    case 192:
-      LAUNCH_ATTENTION_KERNEL(T, 192, BLOCK_SIZE, NUM_THREADS);
-      break;
-    case 256:
-      LAUNCH_ATTENTION_KERNEL(T, 256, BLOCK_SIZE, NUM_THREADS);
-      break;
+    // case 160:
+    //   LAUNCH_ATTENTION_KERNEL(T, 160, BLOCK_SIZE, NUM_THREADS);
+    //   break;
+    // case 192:
+    //   LAUNCH_ATTENTION_KERNEL(T, 192, BLOCK_SIZE, NUM_THREADS);
+    //   break;
+    // case 256:
+    //   LAUNCH_ATTENTION_KERNEL(T, 256, BLOCK_SIZE, NUM_THREADS);
+    //   break;
     default:
       TORCH_CHECK(false, "Unsupported head size: ", head_size);
       break;
...
@@ -411,17 +413,19 @@ void single_query_cached_kv_attention_launcher(
       context_lens,                                                   \
       max_context_len);

+// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
+// 1, 2, 4, 64, 128, 256.
 #define CALL_KERNEL_LAUNCHER_BLOCK_SIZE(T)                            \
   switch (block_size) {                                               \
-    case 1:                                                           \
-      CALL_KERNEL_LAUNCHER(T, 1);                                     \
-      break;                                                          \
-    case 2:                                                           \
-      CALL_KERNEL_LAUNCHER(T, 2);                                     \
-      break;                                                          \
-    case 4:                                                           \
-      CALL_KERNEL_LAUNCHER(T, 4);                                     \
-      break;                                                          \
+    /* case 1:                        */                              \
+    /*   CALL_KERNEL_LAUNCHER(T, 1);  */                              \
+    /*   break;                       */                              \
+    /* case 2:                        */                              \
+    /*   CALL_KERNEL_LAUNCHER(T, 2);  */                              \
+    /*   break;                       */                              \
+    /* case 4:                        */                              \
+    /*   CALL_KERNEL_LAUNCHER(T, 4);  */                              \
+    /*   break;                       */                              \
     case 8:                                                           \
       CALL_KERNEL_LAUNCHER(T, 8);                                     \
       break;                                                          \
...
@@ -431,15 +435,15 @@ void single_query_cached_kv_attention_launcher(
     case 32:                                                          \
       CALL_KERNEL_LAUNCHER(T, 32);                                    \
       break;                                                          \
-    case 64:                                                          \
-      CALL_KERNEL_LAUNCHER(T, 64);                                    \
-      break;                                                          \
-    case 128:                                                         \
-      CALL_KERNEL_LAUNCHER(T, 128);                                   \
-      break;                                                          \
-    case 256:                                                         \
-      CALL_KERNEL_LAUNCHER(T, 256);                                   \
-      break;                                                          \
+    /* case 64:                        */                             \
+    /*   CALL_KERNEL_LAUNCHER(T, 64);  */                             \
+    /*   break;                        */                             \
+    /* case 128:                       */                             \
+    /*   CALL_KERNEL_LAUNCHER(T, 128); */                             \
+    /*   break;                        */                             \
+    /* case 256:                       */                             \
+    /*   CALL_KERNEL_LAUNCHER(T, 256); */                             \
+    /*   break;                        */                             \
     default:                                                          \
       TORCH_CHECK(false, "Unsupported block size: ", block_size);     \
       break;                                                          \
...
@@ -455,8 +459,9 @@ void single_query_cached_kv_attention(
   torch::Tensor& context_lens,  // [num_seqs]
   int block_size,
   int max_context_len) {
-  // TODO(woosuk): Support FP32.
-  if (query.dtype() == at::ScalarType::Half) {
+  if (query.dtype() == at::ScalarType::Float) {
+    CALL_KERNEL_LAUNCHER_BLOCK_SIZE(float);
+  } else if (query.dtype() == at::ScalarType::Half) {
     CALL_KERNEL_LAUNCHER_BLOCK_SIZE(uint16_t);
   } else if (query.dtype() == at::ScalarType::BFloat16) {
     CALL_KERNEL_LAUNCHER_BLOCK_SIZE(__nv_bfloat16);
...
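To summarize the dispatch after these hunks, the launcher now covers three dtypes but a reduced set of head and block sizes. The snippet below is a Python-side restatement of the C++ switch statements above, for reference only.

```python
# Reference summary of the updated C++ dispatch (not executable kernel code).
KERNEL_TEMPLATE_FOR_DTYPE = {
    "float32": "float",            # newly dispatched by this commit
    "float16": "uint16_t",
    "bfloat16": "__nv_bfloat16",
}
SUPPORTED_HEAD_SIZES = [64, 80, 96, 128]   # 32, 160, 192, 256 commented out to cut compile time
SUPPORTED_BLOCK_SIZES = [8, 16, 32]        # 1, 2, 4, 64, 128, 256 commented out to cut compile time
```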
docs/source/getting_started/installation.rst
...
@@ -18,8 +18,11 @@ CacheFlow can run on systems that meet the following requirements:
 .. code-block:: console

+    $ # Pull the Docker image with CUDA 11.8.
     $ docker run --gpus all -it --rm --shm-size=8g nvcr.io/nvidia/pytorch:22.12-py3

+Inside the Docker container, please execute :code:`pip uninstall torch` before installing CacheFlow.
+
 Install with pip
 ----------------
...
setup.py
...
@@ -66,6 +66,11 @@ if 90 in compute_capabilities and nvcc_cuda_version < Version("11.8"):
     raise RuntimeError(
         "CUDA 11.8 or higher is required for GPUs with compute capability 9.0.")

+# Use NVCC threads to parallelize the build.
+if nvcc_cuda_version >= Version("11.2"):
+    num_threads = min(os.cpu_count(), 8)
+    NVCC_FLAGS += ["--threads", str(num_threads)]
+
 ext_modules = []

 # Cache operations.
...
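The added build logic uses NVCC's `--threads` option (available since CUDA 11.2) to compile device code in parallel, capped at 8 threads. A small hedged sketch of the same capping logic in isolation, with a guard for `os.cpu_count()` returning None:

```python
import os

# Mirrors the added setup.py logic: cap NVCC's compilation threads at 8.
num_threads = min(os.cpu_count() or 1, 8)
nvcc_flags = ["--threads", str(num_threads)]  # requires nvcc >= 11.2
print(nvcc_flags)
```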
tests/kernels/test_attention.py
...
@@ -270,9 +270,9 @@ def run_multi_query_kv_attention(
 def test_single_query_cached_kv_attention() -> None:
     torch.random.manual_seed(TEST_SEED)
     torch.cuda.manual_seed(TEST_SEED)
-    for dtype in [torch.half, torch.bfloat16]:
-        for block_size in [8, 16, 32, 64]:
-            for head_size in [32, 64, 80, 96, 128, 160, 192, 256]:
+    for dtype in [torch.half, torch.bfloat16, torch.float]:
+        for block_size in [8, 16, 32]:
+            for head_size in [64, 80, 96, 128]:
                 print(f'Testing single_query_cached_kv_attention with '
                       f'dtype={dtype}, block_size={block_size}, '
                       f'head_size={head_size}')
...
@@ -289,8 +289,8 @@ def test_single_query_cached_kv_attention() -> None:
 def test_multi_query_kv_attention() -> None:
     torch.random.manual_seed(TEST_SEED)
     torch.cuda.manual_seed(TEST_SEED)
-    for dtype in [torch.half, torch.bfloat16]:
-        for head_size in [32, 64, 80, 96, 128, 160, 192, 256]:
+    for dtype in [torch.half, torch.bfloat16, torch.float]:
+        for head_size in [64, 80, 96, 128]:
         print(f'Testing multi_query_kv_attention with dtype={dtype}, '
               f'head_size={head_size}')
         run_multi_query_kv_attention(
...
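The updated tests sweep the Cartesian product of the new lists. A quick standalone sketch of the resulting coverage (36 single-query and 12 multi-query combinations), assuming the same lists as above:

```python
import itertools
import torch

DTYPES = [torch.half, torch.bfloat16, torch.float]  # torch.float added by this commit
BLOCK_SIZES = [8, 16, 32]
HEAD_SIZES = [64, 80, 96, 128]

single_query_cases = list(itertools.product(DTYPES, BLOCK_SIZES, HEAD_SIZES))
multi_query_cases = list(itertools.product(DTYPES, HEAD_SIZES))
print(len(single_query_cases), len(multi_query_cases))  # 36 12
```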