norm / vllm · Commits

Commit b322fd16 (unverified), authored May 14, 2023 by Woosuk Kwon, committed via GitHub on May 14, 2023.
Parent: 667ba399

    Add docstrings to some modules and classes (#100)
Showing 17 changed files with 166 additions and 31 deletions (+166 -31):
  cacheflow/block.py                              +9   -2
  cacheflow/core/block_manager.py                 +20  -15
  cacheflow/core/server.py                        +2   -2
  cacheflow/model_executor/layers/activation.py   +5   -0
  cacheflow/model_executor/layers/attention.py    +28  -1
  cacheflow/model_executor/layers/layernorm.py    +6   -0
  cacheflow/model_executor/layers/sampler.py      +14  -0
  cacheflow/model_executor/model_loader.py        +2   -2
  cacheflow/model_executor/models/gpt2.py         +5   -1
  cacheflow/model_executor/models/gpt_neox.py     +6   -1
  cacheflow/model_executor/models/llama.py        +5   -1
  cacheflow/model_executor/models/opt.py          +5   -1
  cacheflow/model_executor/utils.py               +6   -5
  cacheflow/model_executor/weight_utils.py        +8   -0
  cacheflow/sampling_params.py                    +30  -0
  cacheflow/worker/cache_engine.py                +8   -0
  cacheflow/worker/worker.py                      +7   -0
cacheflow/block.py (+9 -2)

+"""Token blocks."""
 from typing import List

 from cacheflow.utils import Device

-BLANK_TOKEN_ID = -1
+_BLANK_TOKEN_ID = -1


 class LogicalTokenBlock:
+    """A block that stores a contiguous chunk of tokens from left to right.
+
+    Logical blocks are used to represent the states of the corresponding
+    physical blocks in the KV cache.
+    """

     def __init__(
         self,
         ...

@@ -15,7 +21,7 @@ class LogicalTokenBlock:
         self.block_number = block_number
         self.block_size = block_size

-        self.token_ids = [BLANK_TOKEN_ID] * block_size
+        self.token_ids = [_BLANK_TOKEN_ID] * block_size
         self.num_tokens = 0

     def is_empty(self) -> bool:
         ...

@@ -41,6 +47,7 @@ class LogicalTokenBlock:

 class PhysicalTokenBlock:
+    """Represents the state of a block in the KV cache."""

     def __init__(
         self,
         ...
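For readers skimming the diff, here is a minimal sketch of the logical-block behavior the new docstring describes: the block fills with token IDs strictly left to right. Only `is_empty` appears in the diff; `is_full` and `append_tokens` below are assumed helpers for illustration.

```python
from typing import List

_BLANK_TOKEN_ID = -1

class LogicalTokenBlock:
    def __init__(self, block_number: int, block_size: int) -> None:
        self.block_number = block_number
        self.block_size = block_size
        self.token_ids: List[int] = [_BLANK_TOKEN_ID] * block_size
        self.num_tokens = 0

    def is_empty(self) -> bool:
        return self.num_tokens == 0

    def is_full(self) -> bool:  # assumed helper, not in the diff
        return self.num_tokens == self.block_size

    def append_tokens(self, token_ids: List[int]) -> None:  # assumed helper
        # Tokens fill the block strictly left to right.
        assert len(token_ids) <= self.block_size - self.num_tokens
        start = self.num_tokens
        self.token_ids[start:start + len(token_ids)] = token_ids
        self.num_tokens += len(token_ids)
```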
cacheflow/core/block_manager.py (+20 -15)

+"""A block manager that manages token blocks."""
 from typing import Dict, List, Optional, Set, Tuple

 from cacheflow.block import PhysicalTokenBlock
-from cacheflow.sequence import Sequence
-from cacheflow.sequence import SequenceGroup
-from cacheflow.sequence import SequenceStatus
+from cacheflow.sequence import Sequence, SequenceGroup, SequenceStatus
 from cacheflow.utils import Device


 class BlockAllocator:
+    """Manages free physical token blocks for a device.
+
+    The allocator maintains a list of free blocks and allocates a block when
+    requested. When a block is freed, its reference count is decremented. If
+    the reference count becomes zero, the block is added back to the free list.
+    """

     def __init__(
         self,
         ...

@@ -20,24 +25,22 @@ class BlockAllocator:
         self.num_blocks = num_blocks

         # Initialize the free blocks.
-        self.free_blocks: List[PhysicalTokenBlock] = []
-        for i in range(num_blocks):
-            block = PhysicalTokenBlock(
-                device=device, block_number=i, block_size=block_size)
-            self.free_blocks.append(block)
+        # TODO(woosuk): Make this a priority queue.
+        self.free_blocks = [
+            PhysicalTokenBlock(device=device, block_number=i, block_size=block_size)
+            for i in range(num_blocks)
+        ]

     def allocate(self) -> PhysicalTokenBlock:
         if not self.free_blocks:
-            raise ValueError('Out of memory! '
-                             f'No more free blocks are available.')
+            raise ValueError("Out of memory! No free blocks are available.")
         block = self.free_blocks.pop()
         block.ref_count = 1
         return block

     def free(self, block: PhysicalTokenBlock) -> None:
         if block.ref_count == 0:
-            raise ValueError('Double free! '
-                             f'The block {block} is already freed.')
+            raise ValueError(f"Double free! {block} is already freed.")
         block.ref_count -= 1
         if block.ref_count == 0:
             self.free_blocks.append(block)
         ...

@@ -51,6 +54,7 @@ BlockTable = List[PhysicalTokenBlock]

 class BlockSpaceManager:
+    """Manages the mapping between logical and physical token blocks."""

     def __init__(
         self,
         ...

@@ -66,9 +70,10 @@ class BlockSpaceManager:
         assert watermark >= 0.0
         self.watermark_blocks = int(watermark * num_gpu_blocks)

-        self.gpu_allocator = BlockAllocator(Device.GPU, block_size, num_gpu_blocks)
-        self.cpu_allocator = BlockAllocator(Device.CPU, block_size, num_cpu_blocks)
+        self.gpu_allocator = BlockAllocator(Device.GPU, block_size,
+                                            num_gpu_blocks)
+        self.cpu_allocator = BlockAllocator(Device.CPU, block_size,
+                                            num_cpu_blocks)
         # Mapping: seq_id -> BlockTable.
         self.block_tables: Dict[int, BlockTable] = {}
         ...
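The reference counting in `allocate`/`free` is what makes block sharing safe, e.g. when sequences are forked and share a prefix. A usage sketch of the semantics visible in this hunk (the caller-side increment of `ref_count` is an assumption about how sharing would be expressed, not code from the diff):

```python
allocator = BlockAllocator(Device.GPU, block_size=16, num_blocks=1024)

block = allocator.allocate()  # pops a free block; ref_count == 1
block.ref_count += 1          # share the block with a forked sequence

allocator.free(block)         # ref_count -> 1; block is still in use
allocator.free(block)         # ref_count -> 0; block returns to the free list
# allocator.free(block)       # would raise ValueError("Double free! ...")
```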
cacheflow/core/server.py (+2 -2)

 import argparse
-from typing import List, Tuple, Optional
 import random
+from typing import List, Optional, Tuple

-import torch
 try:
     import ray
 except ImportError:
     ray = None

+import torch

 from cacheflow.core.scheduler import Scheduler
 from cacheflow.frontend.simple_frontend import SimpleFrontend
 ...
cacheflow/model_executor/layers/activation.py (+5 -0)

+"""Custom activation functions."""
 import torch
 import torch.nn as nn
 ...

@@ -5,6 +6,10 @@ from cacheflow import activation_ops

 class SiluAndMul(nn.Module):
+    """An activation function for SwiGLU.
+
+    The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[1] // 2.
+    """

     def __init__(self):
         super().__init__()
         ...
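The fused CUDA kernel behind `SiluAndMul` (from `cacheflow.activation_ops`) is not part of this diff. A plain-PyTorch sketch of the formula stated in the new docstring:

```python
import torch
import torch.nn.functional as F

def silu_and_mul_ref(x: torch.Tensor) -> torch.Tensor:
    """Reference computation: x -> silu(x[..., :d]) * x[..., d:]."""
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]

x = torch.randn(4, 2 * 11008)   # e.g., twice LLaMA-7B's intermediate size
out = silu_and_mul_ref(x)       # shape: (4, 11008)
```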
cacheflow/model_executor/layers/attention.py (+28 -1)

+"""Multi-head attention."""
 from typing import Optional

 import torch
 ...

@@ -11,6 +12,32 @@ from cacheflow.model_executor.input_metadata import InputMetadata

 class GPTCacheFlowAttention(nn.Module):
+    """GPT-style multi-head attention.
+
+    This class takes flattened 1D query, key, and value tensors as input. The
+    input 1D tensors can be split into three parts: the prompt tokens, the
+    generation tokens, and the paddings.
+
+    |<------------------------------------- num_valid_tokens ------------------------------------->|
+    |<--------------- num_prompt_tokens -------------->|<------- num_generation_tokens (M) ------->|
+    |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|<--generation_0-->|...|<--generation_M-1-->|<--padding-->|
+
+    The prompts might have different lengths, while the generation tokens always
+    have length 1. The paddings are appended to make the input length a multiple
+    of 8, which is desirable for Tensor Cores.
+
+    The class does the following:
+    1. Perform multi_query_kv_attention for the prompts. This operation does
+       not use the KV cache.
+    2. Wait for the cache operations (e.g., swap, copy) to finish. The cache
+       operations are issued by the cache engine before executing the forward
+       pass of the model, and they are executed asynchronously.
+    3. Reshape and store the input key and value tensors in the KV cache.
+    4. Perform single_query_cached_kv_attention for the generation tokens.
+       This operation reads the previous key and value tensors from the KV
+       cache.
+    5. Output a flattened 1D tensor.
+    """

     def __init__(self, scale: float) -> None:
         super().__init__()
         ...

@@ -157,7 +184,7 @@ class GPTNeoXCacheFlowAttention(GPTCacheFlowAttention):
         torch_dtype = torch.get_default_dtype()
         cache = cache.to(torch_dtype)
         # Embedding size: [max_position, rotary_dim]
-        self.register_buffer('cos_sin_cache', cache, persistent=False)
+        self.register_buffer("cos_sin_cache", cache, persistent=False)

     def forward(
         self,
         ...
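To make the layout diagram concrete, here is a small sketch (not from the diff) of the flattened-input bookkeeping: prompt tokens of varying lengths, one generation token per decoding sequence, then padding up to a multiple of 8.

```python
from typing import Dict, List

def flatten_input(prompt_lens: List[int], num_generation_tokens: int) -> Dict[str, int]:
    num_prompt_tokens = sum(prompt_lens)
    num_valid_tokens = num_prompt_tokens + num_generation_tokens
    padded_len = (num_valid_tokens + 7) // 8 * 8  # round up for Tensor Cores
    return {
        "num_prompt_tokens": num_prompt_tokens,
        "num_valid_tokens": num_valid_tokens,
        "padding": padded_len - num_valid_tokens,
        "total_len": padded_len,
    }

# Two prompts of lengths 5 and 9, plus 3 sequences in the generation phase:
print(flatten_input([5, 9], 3))
# {'num_prompt_tokens': 14, 'num_valid_tokens': 17, 'padding': 7, 'total_len': 24}
```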
cacheflow/model_executor/layers/layernorm.py (+6 -0)

+"""Custom normalization layers."""
 import torch
 import torch.nn as nn
 ...

@@ -5,6 +6,11 @@ from cacheflow import layernorm_ops

 class RMSNorm(nn.Module):
+    """Root mean square normalization.
+
+    Computes x -> w * x / sqrt(E[x^2] + eps) where w is the learned weight.
+    Refer to https://arxiv.org/abs/1910.07467
+    """

     def __init__(
         self,
         ...
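The actual layer calls a fused kernel from `cacheflow.layernorm_ops`; this is a plain-PyTorch sketch of the formula in the docstring, w * x / sqrt(E[x^2] + eps), for reference only:

```python
import torch
import torch.nn as nn

class RMSNormRef(nn.Module):
    """Reference (non-fused) RMSNorm: w * x / sqrt(E[x^2] + eps)."""

    def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Mean of squares over the hidden dimension, per token.
        variance = x.pow(2).mean(dim=-1, keepdim=True)
        return self.weight * x * torch.rsqrt(variance + self.eps)
```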
cacheflow/model_executor/layers/sampler.py (+14 -0)

+"""A layer that samples the next tokens from the model's outputs."""
 from typing import Dict, List, Tuple

 import numpy as np
 ...

@@ -12,6 +13,19 @@ from cacheflow.sequence import SequenceOutputs

 class Sampler(nn.Module):
+    """Samples the next tokens from the model's outputs.
+
+    This layer does the following:
+    1. Discard the hidden states that are not used for sampling (i.e., all
+       tokens except the final one in each prompt).
+    2. Compute the logits for the next tokens.
+    3. Apply presence and frequency penalties.
+    4. Apply temperature scaling.
+    5. Apply top-p and top-k truncation.
+    6. Sample the next tokens.
+    Here, each sequence group within the batch can have different sampling
+    parameters (e.g., sampling method, temperature, top-p, top-k, etc.).
+    """

     def __init__(self, vocab_size: int) -> None:
         super().__init__()
         ...
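Steps 4-6 of the new docstring, sketched for a single sequence in plain PyTorch. The real layer batches this across sequence groups with per-group parameters; the function below is an illustration under those assumptions, not the implementation.

```python
import torch

def sample_next_token(logits: torch.Tensor, temperature: float,
                      top_p: float, top_k: int) -> int:
    # temperature == 0 means greedy sampling and is handled separately.
    assert temperature > 0
    logits = logits / temperature                     # step 4: temperature
    if top_k > 0:                                     # step 5: top-k (-1 disables)
        kth_value = torch.topk(logits, top_k).values[-1]
        logits = logits.masked_fill(logits < kth_value, float("-inf"))
    probs = torch.softmax(logits, dim=-1)
    sorted_probs, sorted_idx = probs.sort(descending=True)
    cumsum = sorted_probs.cumsum(dim=-1)              # step 5: top-p
    # Keep tokens until the cumulative mass before them exceeds top_p.
    sorted_probs[cumsum - sorted_probs > top_p] = 0.0
    sorted_probs /= sorted_probs.sum()
    choice = torch.multinomial(sorted_probs, num_samples=1)  # step 6: sample
    return int(sorted_idx[choice])
```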
cacheflow/model_executor/model_loader.py (+2 -2)

+"""Utilities for selecting and loading models."""
 from typing import Optional

 import torch
 import torch.nn as nn
-from transformers import AutoConfig
-from transformers import PretrainedConfig
+from transformers import AutoConfig, PretrainedConfig

 from cacheflow.model_executor.memory_analyzer import (
     CacheFlowMemoryAnalyzer, GPT2MemoryAnalyzer, GPTNeoXMemoryAnalyzer,
     ...
cacheflow/model_executor/models/gpt2.py (+5 -1)

@@ -15,7 +15,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""1D GPT-2 model compatible with HuggingFace weights."""
+"""Inference-only GPT-2 model compatible with HuggingFace weights.
+
+The input of the model is flattened to a 1D tensor of tokens. The model uses
+InputMetadata to extract the original 2D shape of the input.
+"""
 from typing import Dict, List, Optional, Tuple

 import torch
 ...
cacheflow/model_executor/models/gpt_neox.py (+6 -1)

@@ -14,7 +14,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""1D GPT-NeoX model compatible with HuggingFace weights."""
+"""Inference-only GPT-NeoX model compatible with HuggingFace weights.
+
+The input of the model is flattened to a 1D tensor of tokens. The model uses
+InputMetadata to extract the original 2D shape of the input.
+"""
 from typing import Dict, List, Optional, Tuple

 import torch
 ...

@@ -79,6 +83,7 @@ class GPTNeoXAttention(nn.Module):

 class GPTNeoXMLP(nn.Module):
+
     def __init__(self, config: GPTNeoXConfig):
         super().__init__()
         self.dense_h_to_4h = ColumnParallelLinear(config.hidden_size,
         ...
cacheflow/model_executor/models/llama.py (+5 -1)

@@ -19,7 +19,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""1D LLaMA model compatible with HuggingFace weights."""
+"""Inference-only LLaMA model compatible with HuggingFace weights.
+
+The input of the model is flattened to a 1D tensor of tokens. The model uses
+InputMetadata to extract the original 2D shape of the input.
+"""
 from typing import Dict, List, Optional, Tuple

 import torch
 ...
cacheflow/model_executor/models/opt.py (+5 -1)

@@ -14,7 +14,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""1D OPT model compatible with HuggingFace weights."""
+"""Inference-only OPT model compatible with HuggingFace weights.
+
+The input of the model is flattened to a 1D tensor of tokens. The model uses
+InputMetadata to extract the original 2D shape of the input.
+"""
 from typing import Dict, List, Optional, Tuple

 import torch
 ...
cacheflow/model_executor/utils.py (+6 -5)

+"""Utils for model executor."""
 import random
 from typing import Union
 ...

@@ -9,11 +10,11 @@ from cacheflow.model_executor.parallel_utils.tensor_parallel import model_parall

 _STR_DTYPE_TO_TORCH_DTYPE = {
-    'half': torch.half,
-    'float': torch.float,
-    'float16': torch.float16,
-    'float32': torch.float32,
-    'bfloat16': torch.bfloat16,
+    "half": torch.half,
+    "float": torch.float,
+    "float16": torch.float16,
+    "float32": torch.float32,
+    "bfloat16": torch.bfloat16,
 }
 ...
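A hypothetical lookup helper to show how the mapping is meant to be used (`get_torch_dtype` is assumed for illustration; it is not part of this diff):

```python
import torch

def get_torch_dtype(dtype: str) -> torch.dtype:
    # Resolve a user-provided dtype string to a torch.dtype.
    return _STR_DTYPE_TO_TORCH_DTYPE[dtype.lower()]

assert get_torch_dtype("half") is torch.float16  # torch.half == torch.float16
```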
cacheflow/model_executor/weight_utils.py (+8 -0)

+"""Utilities for downloading and initializing model weights."""
 import filelock
 import glob
 import json
 ...

@@ -106,5 +107,12 @@ def initialize_dummy_weights(
     low: float = -1e-3,
     high: float = 1e-3,
 ) -> None:
+    """Initialize model weights with random values.
+
+    The model weights must be randomly initialized for accurate performance
+    measurements. Additionally, the model weights should not cause NaNs in the
+    forward pass. We empirically found that initializing the weights with
+    values between -1e-3 and 1e-3 works well for most models.
+    """
     for param in model.state_dict().values():
         param.data.uniform_(low, high)
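A usage sketch of the docstring's rationale: profiling with dummy weights instead of downloaded checkpoints. Any `nn.Module` works; every parameter ends up uniform in [-1e-3, 1e-3].

```python
import torch.nn as nn

model = nn.Sequential(nn.Linear(1024, 4096), nn.Linear(4096, 1024))
initialize_dummy_weights(model)  # random but NaN-safe weights for benchmarking
```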
cacheflow/sampling_params.py (+30 -0)

+"""Sampling parameters for text generation."""
 from typing import Set


 class SamplingParams:
+    """Sampling parameters for text generation.
+
+    Overall, we follow the sampling parameters from the OpenAI text completion
+    API (https://platform.openai.com/docs/api-reference/completions/create).
+    In addition, we support beam search, which is not supported by OpenAI.
+
+    Args:
+        n: Number of output sequences to generate from the given prompt. This is
+            regarded as the beam width when using beam search.
+        presence_penalty: Float that penalizes new tokens based on whether they
+            appear in the generated text so far. Values > 0 encourage the model
+            to use new tokens, while values < 0 encourage the model to repeat
+            tokens.
+        frequency_penalty: Float that penalizes new tokens based on their
+            frequency in the generated text so far. Values > 0 encourage the
+            model to use new tokens, while values < 0 encourage the model to
+            repeat tokens.
+        temperature: Float that controls the randomness of the sampling. Lower
+            values make the model more deterministic, while higher values make
+            the model more random. Zero means greedy sampling.
+        top_p: Float that controls the cumulative probability of the top tokens
+            to consider. Must be in (0, 1]. Set to 1 to consider all tokens.
+        top_k: Integer that controls the number of top tokens to consider. Set
+            to -1 to consider all tokens.
+        use_beam_search: Whether to use beam search instead of sampling.
+        stop_token_ids: Set of token IDs that indicate the end of a sequence.
+        max_tokens: Maximum number of tokens to generate per output sequence.
+        logprobs: Number of log probabilities to return per output token.
+    """

     def __init__(
         self,
         ...
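A construction sketch using the argument names from the Args section above. The `__init__` body, defaults, and validation are elided in this diff, so treat the exact signature and values as assumptions:

```python
params = SamplingParams(
    n=1,
    presence_penalty=0.0,
    frequency_penalty=0.0,
    temperature=0.8,        # 0 would mean greedy sampling
    top_p=0.95,             # consider tokens covering 95% of probability mass
    top_k=-1,               # -1: no top-k truncation
    use_beam_search=False,
    stop_token_ids={2},     # e.g., an EOS token id
    max_tokens=128,
    logprobs=0,
)
```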
cacheflow/worker/cache_engine.py (+8 -0)

+"""CacheEngine class for managing the KV cache."""
 from typing import Dict, List, Tuple

 import torch

 from cacheflow import cache_ops

 KVCache = Tuple[torch.Tensor, torch.Tensor]


 class CacheEngine:
+    """Manages the KV cache.
+
+    This class is responsible for initializing and managing the GPU and CPU KV
+    caches. It also provides methods for performing KV cache operations, such
+    as swapping and copying.
+    """

     def __init__(
         self,
         ...
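The diff only adds the docstring, so the cache shapes are not visible here. A sketch of what "initializing the GPU KV cache" could look like, one `(key, value)` tensor pair per layer, with assumed tensor shapes and dtype:

```python
from typing import List, Tuple

import torch

KVCache = Tuple[torch.Tensor, torch.Tensor]

def allocate_gpu_cache(num_layers: int, num_blocks: int, block_size: int,
                       num_heads: int, head_size: int) -> List[KVCache]:
    # Shapes below are illustrative assumptions, not taken from the diff.
    gpu_cache: List[KVCache] = []
    for _ in range(num_layers):
        key = torch.empty(num_blocks, num_heads, head_size, block_size,
                          dtype=torch.float16, device="cuda")
        value = torch.empty(num_blocks, num_heads, head_size, block_size,
                            dtype=torch.float16, device="cuda")
        gpu_cache.append((key, value))
    return gpu_cache
```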
cacheflow/worker/worker.py (+7 -0)

+"""A GPU worker class."""
 from typing import Dict, List, Optional, Tuple

 import torch
 ...

@@ -14,6 +15,12 @@ from cacheflow.worker.cache_engine import CacheEngine

 class Worker:
+    """A worker class that executes (a partition of) the model on a GPU.
+
+    Each worker is associated with a single GPU. The worker is responsible for
+    maintaining the KV cache and executing the model on the GPU. In case of
+    distributed inference, each worker is assigned a partition of the model.
+    """

     def __init__(
         self,
         ...