Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
0fca3cdc
Unverified
Commit
0fca3cdc
authored
May 13, 2024
by
Woosuk Kwon
Committed by
GitHub
May 13, 2024
Browse files
[Misc] Enhance attention selector (#4751)
parent
e7c46b95
Changes
49
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
256 additions
and
114 deletions
+256
-114
tests/worker/test_model_runner.py
tests/worker/test_model_runner.py
+0
-1
vllm/attention/__init__.py
vllm/attention/__init__.py
+2
-2
vllm/attention/backends/abstract.py
vllm/attention/backends/abstract.py
+2
-3
vllm/attention/backends/flash_attn.py
vllm/attention/backends/flash_attn.py
+7
-6
vllm/attention/backends/flashinfer.py
vllm/attention/backends/flashinfer.py
+23
-10
vllm/attention/backends/rocm_flash_attn.py
vllm/attention/backends/rocm_flash_attn.py
+9
-7
vllm/attention/backends/torch_sdpa.py
vllm/attention/backends/torch_sdpa.py
+17
-11
vllm/attention/backends/xformers.py
vllm/attention/backends/xformers.py
+6
-6
vllm/attention/layer.py
vllm/attention/layer.py
+17
-2
vllm/attention/selector.py
vllm/attention/selector.py
+23
-5
vllm/model_executor/model_loader/__init__.py
vllm/model_executor/model_loader/__init__.py
+11
-8
vllm/model_executor/model_loader/loader.py
vllm/model_executor/model_loader/loader.py
+39
-22
vllm/model_executor/models/arctic.py
vllm/model_executor/models/arctic.py
+13
-3
vllm/model_executor/models/baichuan.py
vllm/model_executor/models/baichuan.py
+22
-7
vllm/model_executor/models/bloom.py
vllm/model_executor/models/bloom.py
+11
-4
vllm/model_executor/models/chatglm.py
vllm/model_executor/models/chatglm.py
+14
-6
vllm/model_executor/models/commandr.py
vllm/model_executor/models/commandr.py
+11
-3
vllm/model_executor/models/dbrx.py
vllm/model_executor/models/dbrx.py
+13
-4
vllm/model_executor/models/decilm.py
vllm/model_executor/models/decilm.py
+3
-1
vllm/model_executor/models/deepseek.py
vllm/model_executor/models/deepseek.py
+13
-3
No files found.
tests/worker/test_model_runner.py
View file @
0fca3cdc
...
...
@@ -307,7 +307,6 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
assert
len
(
attn_metadata
.
slot_mapping
)
==
len
(
input_tokens
)
assert
len
(
input_positions
)
==
len
(
input_tokens
)
assert
attn_metadata
.
kv_cache_dtype
==
"auto"
assert
attn_metadata
.
num_prefills
==
prefill_batch_size
if
enforce_eager
:
assert
attn_metadata
.
num_decode_tokens
==
decode_batch_size
...
...
vllm/attention/__init__.py
View file @
0fca3cdc
...
...
@@ -5,9 +5,9 @@ from vllm.attention.layer import Attention
from
vllm.attention.selector
import
get_attn_backend
__all__
=
[
"Attention"
,
"AttentionBackend"
,
"AttentionMetadata"
,
"Attention"
,
"get_attn_backend"
,
"AttentionMetadataPerStage"
,
"get_attn_backend"
,
]
vllm/attention/backends/abstract.py
View file @
0fca3cdc
...
...
@@ -94,8 +94,6 @@ class AttentionMetadata(Generic[T]):
# is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
# in block 0, and 1st slot in block 1, respectively.
slot_mapping
:
torch
.
Tensor
# The kv cache's data type.
kv_cache_dtype
:
str
def
__post_init__
(
self
):
if
self
.
num_prefill_tokens
>
0
:
...
...
@@ -116,6 +114,7 @@ class AttentionImpl(ABC):
num_kv_heads
:
Optional
[
int
]
=
None
,
alibi_slopes
:
Optional
[
List
[
float
]]
=
None
,
sliding_window
:
Optional
[
int
]
=
None
,
kv_cache_dtype
:
str
=
"auto"
,
)
->
None
:
raise
NotImplementedError
...
...
@@ -127,6 +126,6 @@ class AttentionImpl(ABC):
value
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
kv_scale
:
float
,
kv_scale
:
float
=
1.0
,
)
->
torch
.
Tensor
:
raise
NotImplementedError
vllm/attention/backends/flash_attn.py
View file @
0fca3cdc
...
...
@@ -140,16 +140,18 @@ class FlashAttentionImpl(AttentionImpl):
num_kv_heads
:
Optional
[
int
]
=
None
,
alibi_slopes
:
Optional
[
List
[
float
]]
=
None
,
sliding_window
:
Optional
[
int
]
=
None
,
kv_cache_dtype
:
str
=
"auto"
,
)
->
None
:
self
.
num_heads
=
num_heads
self
.
head_size
=
head_size
self
.
scale
=
float
(
scale
)
self
.
num_kv_heads
=
num_heads
if
num_kv_heads
is
None
else
num_kv_heads
self
.
sliding_window
=
((
sliding_window
,
sliding_window
)
if
sliding_window
is
not
None
else
(
-
1
,
-
1
))
if
alibi_slopes
is
not
None
:
alibi_slopes
=
torch
.
tensor
(
alibi_slopes
,
dtype
=
torch
.
float32
)
self
.
alibi_slopes
=
alibi_slopes
self
.
sliding_window
=
((
sliding_window
,
sliding_window
)
if
sliding_window
is
not
None
else
(
-
1
,
-
1
))
self
.
kv_cache_dtype
=
kv_cache_dtype
assert
self
.
num_heads
%
self
.
num_kv_heads
==
0
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
...
...
@@ -167,7 +169,7 @@ class FlashAttentionImpl(AttentionImpl):
value
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
[
FlashAttentionMetadata
],
kv_scale
:
float
,
kv_scale
:
float
=
1.0
,
)
->
torch
.
Tensor
:
"""Forward pass with FlashAttention and PagedAttention.
...
...
@@ -196,8 +198,7 @@ class FlashAttentionImpl(AttentionImpl):
PagedAttention
.
write_to_paged_cache
(
key
,
value
,
key_cache
,
value_cache
,
attn_metadata
.
slot_mapping
,
attn_metadata
.
kv_cache_dtype
,
kv_scale
)
self
.
kv_cache_dtype
,
kv_scale
)
num_prefill_tokens
=
attn_metadata
.
num_prefill_tokens
num_decode_tokens
=
attn_metadata
.
num_decode_tokens
...
...
@@ -264,7 +265,7 @@ class FlashAttentionImpl(AttentionImpl):
decode_meta
.
block_tables
,
decode_meta
.
seq_lens_tensor
,
decode_meta
.
max_seq_len
,
attn_metadata
.
kv_cache_dtype
,
self
.
kv_cache_dtype
,
self
.
num_kv_heads
,
self
.
scale
,
self
.
alibi_slopes
,
...
...
vllm/attention/backends/flashinfer.py
View file @
0fca3cdc
...
...
@@ -149,20 +149,33 @@ class FlashInferImpl(AttentionImpl):
num_kv_heads
:
Optional
[
int
]
=
None
,
alibi_slopes
:
Optional
[
List
[
float
]]
=
None
,
sliding_window
:
Optional
[
int
]
=
None
,
kv_cache_dtype
:
str
=
"auto"
,
)
->
None
:
if
sliding_window
is
not
None
:
raise
ValueError
(
"Sliding window is not supported in FlashInfer."
)
self
.
sliding_window
=
(
-
1
,
-
1
)
self
.
alibi_slopes
=
alibi_slopes
self
.
scale
=
scale
self
.
num_heads
=
num_heads
self
.
head_size
=
head_size
self
.
scale
=
float
(
scale
)
self
.
num_kv_heads
=
num_heads
if
num_kv_heads
is
None
else
num_kv_heads
if
alibi_slopes
is
not
None
:
alibi_slopes
=
torch
.
tensor
(
alibi_slopes
,
dtype
=
torch
.
float32
)
self
.
alibi_slopes
=
alibi_slopes
if
sliding_window
is
not
None
:
raise
ValueError
(
"Sliding window is not supported in FlashInfer."
)
self
.
sliding_window
=
(
-
1
,
-
1
)
self
.
kv_cache_dtype
=
kv_cache_dtype
def
forward
(
self
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
kv_cache
:
Optional
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
[
FlashInferMetadata
],
kv_scale
:
float
):
assert
self
.
num_heads
%
self
.
num_kv_heads
==
0
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
def
forward
(
self
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
kv_cache
:
Optional
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
[
FlashInferMetadata
],
kv_scale
:
float
=
1.0
,
)
->
torch
.
Tensor
:
assert
kv_scale
==
1.0
num_tokens
,
hidden_size
=
query
.
shape
query
=
query
.
view
(
-
1
,
self
.
num_heads
,
self
.
head_size
)
key
=
key
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
...
...
@@ -183,7 +196,7 @@ class FlashInferImpl(AttentionImpl):
kv_cache
[:,
0
],
kv_cache
[:,
1
],
attn_metadata
.
slot_mapping
.
flatten
(),
attn_metadata
.
kv_cache_dtype
,
self
.
kv_cache_dtype
,
)
if
prefill_meta
:
=
attn_metadata
.
prefill_metadata
:
...
...
vllm/attention/backends/rocm_flash_attn.py
View file @
0fca3cdc
...
...
@@ -138,25 +138,27 @@ class ROCmFlashAttentionImpl(AttentionImpl):
num_kv_heads
:
Optional
[
int
]
=
None
,
alibi_slopes
:
Optional
[
List
[
float
]]
=
None
,
sliding_window
:
Optional
[
int
]
=
None
,
kv_cache_dtype
:
str
=
"auto"
,
)
->
None
:
self
.
num_heads
=
num_heads
self
.
head_size
=
head_size
self
.
scale
=
float
(
scale
)
self
.
num_kv_heads
=
num_heads
if
num_kv_heads
is
None
else
num_kv_heads
self
.
sliding_window
=
((
sliding_window
,
sliding_window
)
if
sliding_window
is
not
None
else
(
-
1
,
-
1
))
if
alibi_slopes
is
not
None
:
alibi_slopes
=
torch
.
tensor
(
alibi_slopes
,
dtype
=
torch
.
float32
)
self
.
alibi_slopes
=
alibi_slopes
self
.
sliding_window
=
((
sliding_window
,
sliding_window
)
if
sliding_window
is
not
None
else
(
-
1
,
-
1
))
self
.
kv_cache_dtype
=
kv_cache_dtype
assert
self
.
num_heads
%
self
.
num_kv_heads
==
0
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
suppored_head_sizes
=
PagedAttention
.
get_supported_head_sizes
()
if
head_size
not
in
suppored_head_sizes
:
suppor
t
ed_head_sizes
=
PagedAttention
.
get_supported_head_sizes
()
if
head_size
not
in
suppor
t
ed_head_sizes
:
raise
ValueError
(
f
"Head size
{
head_size
}
is not supported by PagedAttention. "
f
"Supported head sizes are:
{
suppored_head_sizes
}
."
)
f
"Supported head sizes are:
{
suppor
t
ed_head_sizes
}
."
)
self
.
use_naive_attn
=
False
# NOTE: Allow for switching between Triton and CK. Defaulting to triton.
...
...
@@ -229,7 +231,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
key_cache
,
value_cache
,
attn_metadata
.
slot_mapping
,
attn_metadata
.
kv_cache_dtype
,
self
.
kv_cache_dtype
,
kv_scale
,
)
...
...
@@ -323,7 +325,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
decode_meta
.
block_tables
,
decode_meta
.
seq_lens_tensor
,
decode_meta
.
max_seq_len
,
attn_metadata
.
kv_cache_dtype
,
self
.
kv_cache_dtype
,
self
.
num_kv_heads
,
self
.
scale
,
self
.
alibi_slopes
,
...
...
vllm/attention/backends/torch_sdpa.py
View file @
0fca3cdc
...
...
@@ -83,26 +83,32 @@ class TorchSDPABackendImpl(AttentionImpl):
num_kv_heads
:
Optional
[
int
]
=
None
,
alibi_slopes
:
Optional
[
List
[
float
]]
=
None
,
sliding_window
:
Optional
[
int
]
=
None
,
kv_cache_dtype
:
str
=
"auto"
,
)
->
None
:
self
.
num_heads
=
num_heads
self
.
head_size
=
head_size
self
.
scale
=
float
(
scale
)
self
.
num_kv_heads
=
num_heads
if
num_kv_heads
is
None
else
num_kv_heads
self
.
sliding_window
=
sliding_window
if
alibi_slopes
is
not
None
:
assert
len
(
alibi_slopes
)
==
num_heads
alibi_slopes
=
torch
.
tensor
(
alibi_slopes
,
dtype
=
torch
.
float32
)
self
.
alibi_slopes
=
alibi_slopes
self
.
need_mask
=
(
self
.
alibi_slopes
is
not
None
or
self
.
sliding_window
is
not
None
)
self
.
sliding_window
=
sliding_window
self
.
kv_cache_dtype
=
kv_cache_dtype
assert
self
.
num_heads
%
self
.
num_kv_heads
==
0
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
suppored_head_sizes
=
PagedAttention
.
get_supported_head_sizes
()
if
head_size
not
in
suppored_head_sizes
:
self
.
need_mask
=
(
self
.
alibi_slopes
is
not
None
or
self
.
sliding_window
is
not
None
)
supported_head_sizes
=
PagedAttention
.
get_supported_head_sizes
()
if
head_size
not
in
supported_head_sizes
:
raise
ValueError
(
f
"Head size
{
head_size
}
is not supported by PagedAttention. "
f
"Supported head sizes are:
{
suppored_head_sizes
}
."
)
f
"Supported head sizes are:
{
supported_head_sizes
}
."
)
if
kv_cache_dtype
!=
"auto"
:
raise
NotImplementedError
(
"Torch SDPA backend does not support FP8 KV cache. "
"Please use xFormers backend instead."
)
def
forward
(
self
,
...
...
@@ -111,7 +117,7 @@ class TorchSDPABackendImpl(AttentionImpl):
value
:
torch
.
Tensor
,
kv_cache
:
Optional
[
torch
.
Tensor
],
attn_metadata
:
TorchSDPAMetadata
,
# type: ignore
kv_scale
:
float
,
kv_scale
:
float
=
1.0
,
)
->
torch
.
Tensor
:
"""Forward pass with torch SDPA and PagedAttention.
...
...
@@ -124,6 +130,7 @@ class TorchSDPABackendImpl(AttentionImpl):
Returns:
shape = [num_tokens, num_heads * head_size]
"""
assert
kv_scale
==
1.0
num_tokens
,
hidden_size
=
query
.
shape
# Reshape the query, key, and value tensors.
query
=
query
.
view
(
-
1
,
self
.
num_heads
,
self
.
head_size
)
...
...
@@ -136,8 +143,7 @@ class TorchSDPABackendImpl(AttentionImpl):
PagedAttention
.
write_to_paged_cache
(
key
,
value
,
key_cache
,
value_cache
,
attn_metadata
.
slot_mapping
,
attn_metadata
.
kv_cache_dtype
,
kv_scale
)
self
.
kv_cache_dtype
,
kv_scale
)
if
attn_metadata
.
is_prompt
:
assert
attn_metadata
.
seq_lens
is
not
None
...
...
@@ -195,7 +201,7 @@ class TorchSDPABackendImpl(AttentionImpl):
attn_metadata
.
block_tables
,
attn_metadata
.
seq_lens_tensor
,
attn_metadata
.
max_seq_len
,
attn_metadata
.
kv_cache_dtype
,
self
.
kv_cache_dtype
,
self
.
num_kv_heads
,
self
.
scale
,
self
.
alibi_slopes
,
...
...
vllm/attention/backends/xformers.py
View file @
0fca3cdc
...
...
@@ -149,15 +149,17 @@ class XFormersImpl(AttentionImpl):
num_kv_heads
:
Optional
[
int
]
=
None
,
alibi_slopes
:
Optional
[
List
[
float
]]
=
None
,
sliding_window
:
Optional
[
int
]
=
None
,
kv_cache_dtype
:
str
=
"auto"
,
)
->
None
:
self
.
num_heads
=
num_heads
self
.
head_size
=
head_size
self
.
scale
=
float
(
scale
)
self
.
num_kv_heads
=
num_heads
if
num_kv_heads
is
None
else
num_kv_heads
self
.
sliding_window
=
sliding_window
if
alibi_slopes
is
not
None
:
alibi_slopes
=
torch
.
tensor
(
alibi_slopes
,
dtype
=
torch
.
float32
)
self
.
alibi_slopes
=
alibi_slopes
self
.
sliding_window
=
sliding_window
self
.
kv_cache_dtype
=
kv_cache_dtype
assert
self
.
num_heads
%
self
.
num_kv_heads
==
0
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
...
...
@@ -175,7 +177,7 @@ class XFormersImpl(AttentionImpl):
value
:
torch
.
Tensor
,
kv_cache
:
Optional
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
[
XFormersMetadata
],
kv_scale
:
float
,
kv_scale
:
float
=
1.0
,
)
->
torch
.
Tensor
:
"""Forward pass with xFormers and PagedAttention.
...
...
@@ -188,7 +190,6 @@ class XFormersImpl(AttentionImpl):
Returns:
shape = [num_tokens, num_heads * head_size]
"""
num_tokens
,
hidden_size
=
query
.
shape
query
=
query
.
view
(
-
1
,
self
.
num_heads
,
self
.
head_size
)
key
=
key
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
value
=
value
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
...
...
@@ -203,8 +204,7 @@ class XFormersImpl(AttentionImpl):
PagedAttention
.
write_to_paged_cache
(
key
,
value
,
key_cache
,
value_cache
,
attn_metadata
.
slot_mapping
,
attn_metadata
.
kv_cache_dtype
,
kv_scale
)
self
.
kv_cache_dtype
,
kv_scale
)
num_prefill_tokens
=
attn_metadata
.
num_prefill_tokens
num_decode_tokens
=
attn_metadata
.
num_decode_tokens
...
...
@@ -262,7 +262,7 @@ class XFormersImpl(AttentionImpl):
decode_meta
.
block_tables
,
decode_meta
.
seq_lens_tensor
,
decode_meta
.
max_seq_len
,
attn_metadata
.
kv_cache_dtype
,
self
.
kv_cache_dtype
,
self
.
num_kv_heads
,
self
.
scale
,
self
.
alibi_slopes
,
...
...
vllm/attention/layer.py
View file @
0fca3cdc
...
...
@@ -7,6 +7,7 @@ import torch.nn as nn
from
vllm.attention.backends.abstract
import
(
AttentionMetadata
,
AttentionMetadataPerStage
)
from
vllm.attention.selector
import
get_attn_backend
from
vllm.config
import
CacheConfig
class
Attention
(
nn
.
Module
):
...
...
@@ -29,10 +30,24 @@ class Attention(nn.Module):
num_kv_heads
:
Optional
[
int
]
=
None
,
alibi_slopes
:
Optional
[
List
[
float
]]
=
None
,
sliding_window
:
Optional
[
int
]
=
None
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
)
->
None
:
super
().
__init__
()
self
.
backend
=
get_attn_backend
(
torch
.
get_default_dtype
())
impl_cls
=
self
.
backend
.
get_impl_cls
()
if
cache_config
is
not
None
:
kv_cache_dtype
=
cache_config
.
cache_dtype
block_size
=
cache_config
.
block_size
else
:
kv_cache_dtype
=
"auto"
block_size
=
16
if
num_kv_heads
is
None
:
num_kv_heads
=
num_heads
# During model initialization, the default dtype is set as the model
# weight and activation dtype.
dtype
=
torch
.
get_default_dtype
()
attn_backend
=
get_attn_backend
(
num_heads
,
head_size
,
num_kv_heads
,
sliding_window
,
dtype
,
kv_cache_dtype
,
block_size
)
impl_cls
=
attn_backend
.
get_impl_cls
()
self
.
impl
=
impl_cls
(
num_heads
,
head_size
,
scale
,
num_kv_heads
,
alibi_slopes
,
sliding_window
)
...
...
vllm/attention/selector.py
View file @
0fca3cdc
import
enum
from
functools
import
lru_cache
from
typing
import
Type
from
typing
import
Optional
,
Type
import
torch
...
...
@@ -21,8 +21,18 @@ class _Backend(enum.Enum):
@
lru_cache
(
maxsize
=
None
)
def
get_attn_backend
(
dtype
:
torch
.
dtype
)
->
Type
[
AttentionBackend
]:
backend
=
_which_attn_to_use
(
dtype
)
def
get_attn_backend
(
num_heads
:
int
,
head_size
:
int
,
num_kv_heads
:
int
,
sliding_window
:
Optional
[
int
],
dtype
:
torch
.
dtype
,
kv_cache_dtype
:
Optional
[
str
],
block_size
:
int
,
)
->
Type
[
AttentionBackend
]:
backend
=
_which_attn_to_use
(
num_heads
,
head_size
,
num_kv_heads
,
sliding_window
,
dtype
,
kv_cache_dtype
,
block_size
)
if
backend
==
_Backend
.
FLASH_ATTN
:
logger
.
info
(
"Using FlashAttention-2 backend."
)
from
vllm.attention.backends.flash_attn
import
(
# noqa: F401
...
...
@@ -44,14 +54,22 @@ def get_attn_backend(dtype: torch.dtype) -> Type[AttentionBackend]:
return
TorchSDPABackend
elif
backend
==
_Backend
.
FLASHINFER
:
logger
.
info
(
"Using Flashinfer backend."
)
logger
.
warning
(
"Eager mode is enforced for the Flashinfer backend.
"
)
logger
.
warning
(
"Eager mode is enforced for the Flashinfer backend."
)
from
vllm.attention.backends.flashinfer
import
FlashInferBackend
return
FlashInferBackend
else
:
raise
ValueError
(
"Invalid attention backend."
)
def
_which_attn_to_use
(
dtype
:
torch
.
dtype
)
->
_Backend
:
def
_which_attn_to_use
(
num_heads
:
int
,
head_size
:
int
,
num_kv_heads
:
int
,
sliding_window
:
Optional
[
int
],
dtype
:
torch
.
dtype
,
kv_cache_dtype
:
Optional
[
str
],
block_size
:
int
,
)
->
_Backend
:
"""Returns which flash attention backend to use."""
if
is_cpu
():
return
_Backend
.
TORCH_SDPA
...
...
vllm/model_executor/model_loader/__init__.py
View file @
0fca3cdc
...
...
@@ -2,26 +2,29 @@ from typing import Optional
from
torch
import
nn
from
vllm.config
import
(
DeviceConfig
,
LoadConfig
,
LoRAConfig
,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
,
VisionLanguageConfig
)
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoadConfig
,
LoRAConfig
,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
,
VisionLanguageConfig
)
from
vllm.model_executor.model_loader.loader
import
(
BaseModelLoader
,
get_model_loader
)
from
vllm.model_executor.model_loader.utils
import
(
get_architecture_class_name
,
get_model_architecture
)
def
get_model
(
*
,
model_config
:
ModelConfig
,
load_config
:
LoadConfig
,
device_config
:
DeviceConfig
,
parallel_config
:
ParallelConfig
,
scheduler_config
:
SchedulerConfig
,
lora_config
:
Optional
[
LoRAConfig
],
vision_language_config
:
Optional
[
VisionLanguageConfig
])
->
nn
.
Module
:
def
get_model
(
*
,
model_config
:
ModelConfig
,
load_config
:
LoadConfig
,
device_config
:
DeviceConfig
,
parallel_config
:
ParallelConfig
,
scheduler_config
:
SchedulerConfig
,
lora_config
:
Optional
[
LoRAConfig
],
vision_language_config
:
Optional
[
VisionLanguageConfig
],
cache_config
:
CacheConfig
)
->
nn
.
Module
:
loader
=
get_model_loader
(
load_config
)
return
loader
.
load_model
(
model_config
=
model_config
,
device_config
=
device_config
,
lora_config
=
lora_config
,
vision_language_config
=
vision_language_config
,
parallel_config
=
parallel_config
,
scheduler_config
=
scheduler_config
)
scheduler_config
=
scheduler_config
,
cache_config
=
cache_config
)
__all__
=
[
...
...
vllm/model_executor/model_loader/loader.py
View file @
0fca3cdc
...
...
@@ -9,9 +9,9 @@ import huggingface_hub
import
torch
from
torch
import
nn
from
vllm.config
import
(
DeviceConfig
,
LoadConfig
,
LoadFormat
,
LoRAConfig
,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
,
VisionLanguageConfig
)
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoadConfig
,
LoadFormat
,
LoRAConfig
,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
,
VisionLanguageConfig
)
from
vllm.envs
import
VLLM_USE_MODELSCOPE
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.quantization.base_config
import
(
...
...
@@ -77,15 +77,16 @@ def _get_model_initialization_kwargs(
return
extra_kwargs
def
_initialize_model
(
model_config
:
ModelConfig
,
loa
d
_config
:
Load
Config
,
lora
_config
:
Optional
[
LoRA
Config
],
vision_language_config
:
Optional
[
VisionLanguag
eConfig
]
)
->
nn
.
Module
:
def
_initialize_model
(
model_config
:
ModelConfig
,
load_config
:
LoadConfig
,
lo
r
a_config
:
Optional
[
LoRA
Config
]
,
vision_language
_config
:
Optional
[
VisionLanguage
Config
],
cache_config
:
Cach
eConfig
)
->
nn
.
Module
:
"""Initialize a model with the given configurations."""
model_class
=
get_model_architecture
(
model_config
)[
0
]
quant_config
=
_get_quantization_config
(
model_config
,
load_config
)
return
model_class
(
config
=
model_config
.
hf_config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
,
**
_get_model_initialization_kwargs
(
model_class
,
lora_config
,
vision_language_config
))
...
...
@@ -103,7 +104,8 @@ class BaseModelLoader(ABC):
lora_config
:
Optional
[
LoRAConfig
],
vision_language_config
:
Optional
[
VisionLanguageConfig
],
parallel_config
:
ParallelConfig
,
scheduler_config
:
SchedulerConfig
)
->
nn
.
Module
:
scheduler_config
:
SchedulerConfig
,
cache_config
:
CacheConfig
)
->
nn
.
Module
:
"""Load a model with the given configurations."""
...
...
...
@@ -216,11 +218,13 @@ class DefaultModelLoader(BaseModelLoader):
lora_config
:
Optional
[
LoRAConfig
],
vision_language_config
:
Optional
[
VisionLanguageConfig
],
parallel_config
:
ParallelConfig
,
scheduler_config
:
SchedulerConfig
)
->
nn
.
Module
:
scheduler_config
:
SchedulerConfig
,
cache_config
:
CacheConfig
)
->
nn
.
Module
:
with
set_default_torch_dtype
(
model_config
.
dtype
):
with
torch
.
device
(
device_config
.
device
):
model
=
_initialize_model
(
model_config
,
self
.
load_config
,
lora_config
,
vision_language_config
)
lora_config
,
vision_language_config
,
cache_config
)
model
.
load_weights
(
self
.
_get_weights_iterator
(
model_config
.
model
,
model_config
.
revision
,
...
...
@@ -253,11 +257,13 @@ class DummyModelLoader(BaseModelLoader):
lora_config
:
Optional
[
LoRAConfig
],
vision_language_config
:
Optional
[
VisionLanguageConfig
],
parallel_config
:
ParallelConfig
,
scheduler_config
:
SchedulerConfig
)
->
nn
.
Module
:
scheduler_config
:
SchedulerConfig
,
cache_config
:
CacheConfig
)
->
nn
.
Module
:
with
set_default_torch_dtype
(
model_config
.
dtype
):
with
torch
.
device
(
device_config
.
device
):
model
=
_initialize_model
(
model_config
,
self
.
load_config
,
lora_config
,
vision_language_config
)
lora_config
,
vision_language_config
,
cache_config
)
# NOTE(woosuk): For accurate performance evaluation, we assign
# random values to the weights.
initialize_dummy_weights
(
model
)
...
...
@@ -286,9 +292,12 @@ class TensorizerLoader(BaseModelLoader):
return
tensorizer_weights_iterator
(
tensorizer_args
)
def
_load_model_unserialized
(
self
,
model_config
:
ModelConfig
,
device_config
:
DeviceConfig
,
lora_config
:
Optional
[
LoRAConfig
],
vision_language_config
:
Optional
[
VisionLanguageConfig
]
self
,
model_config
:
ModelConfig
,
device_config
:
DeviceConfig
,
lora_config
:
Optional
[
LoRAConfig
],
vision_language_config
:
Optional
[
VisionLanguageConfig
],
cache_config
:
CacheConfig
,
)
->
nn
.
Module
:
"""Load an unserialized model with tensorizer.
...
...
@@ -299,15 +308,19 @@ class TensorizerLoader(BaseModelLoader):
with
set_default_torch_dtype
(
model_config
.
dtype
):
with
torch
.
device
(
device_config
.
device
):
model
=
_initialize_model
(
model_config
,
self
.
load_config
,
lora_config
,
vision_language_config
)
lora_config
,
vision_language_config
,
cache_config
)
model
.
load_weights
(
self
.
_get_weights_iterator
())
return
model
.
eval
()
def
_load_model_serialized
(
self
,
model_config
:
ModelConfig
,
device_config
:
DeviceConfig
,
lora_config
:
Optional
[
LoRAConfig
],
vision_language_config
:
Optional
[
VisionLanguageConfig
]
self
,
model_config
:
ModelConfig
,
device_config
:
DeviceConfig
,
lora_config
:
Optional
[
LoRAConfig
],
vision_language_config
:
Optional
[
VisionLanguageConfig
],
cache_config
:
CacheConfig
,
)
->
nn
.
Module
:
"""Load a serialized model with tensorizer.
...
...
@@ -321,6 +334,7 @@ class TensorizerLoader(BaseModelLoader):
extra_kwargs
=
_get_model_initialization_kwargs
(
model_class
,
lora_config
,
vision_language_config
)
extra_kwargs
[
"quant_config"
]
=
quant_config
extra_kwargs
[
"cache_config"
]
=
cache_config
tensorizer_config
=
copy
.
copy
(
self
.
tensorizer_config
)
tensorizer_config
.
model_class
=
model_class
...
...
@@ -335,16 +349,19 @@ class TensorizerLoader(BaseModelLoader):
lora_config
:
Optional
[
LoRAConfig
],
vision_language_config
:
Optional
[
VisionLanguageConfig
],
parallel_config
:
ParallelConfig
,
scheduler_config
:
SchedulerConfig
)
->
nn
.
Module
:
scheduler_config
:
SchedulerConfig
,
cache_config
:
CacheConfig
)
->
nn
.
Module
:
self
.
_verify_config
(
model_config
,
parallel_config
)
if
is_vllm_serialized_tensorizer
(
self
.
tensorizer_config
):
return
self
.
_load_model_serialized
(
model_config
,
device_config
,
lora_config
,
vision_language_config
)
vision_language_config
,
cache_config
)
return
self
.
_load_model_unserialized
(
model_config
,
device_config
,
lora_config
,
vision_language_config
)
vision_language_config
,
cache_config
)
def
get_model_loader
(
load_config
:
LoadConfig
)
->
BaseModelLoader
:
...
...
vllm/model_executor/models/arctic.py
View file @
0fca3cdc
...
...
@@ -5,6 +5,7 @@ import torch
from
torch
import
nn
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.config
import
CacheConfig
from
vllm.distributed
import
(
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
,
tensor_model_parallel_all_reduce
)
...
...
@@ -215,6 +216,7 @@ class ArcticAttention(nn.Module):
self
,
config
:
ArcticConfig
,
layer_idx
:
Optional
[
int
]
=
None
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
...
...
@@ -265,7 +267,8 @@ class ArcticAttention(nn.Module):
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
head_dim
,
self
.
scaling
,
num_kv_heads
=
self
.
num_kv_heads
)
num_kv_heads
=
self
.
num_kv_heads
,
cache_config
=
cache_config
)
def
forward
(
self
,
...
...
@@ -288,6 +291,7 @@ class ArcticDecoderLayer(nn.Module):
self
,
config
:
ArcticConfig
,
layer_idx
:
int
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
None
:
super
().
__init__
()
...
...
@@ -297,6 +301,7 @@ class ArcticDecoderLayer(nn.Module):
self
.
use_residual
=
config
.
use_residual
and
is_moe_layer
self
.
self_attn
=
ArcticAttention
(
config
,
layer_idx
,
cache_config
,
quant_config
=
quant_config
)
self
.
block_sparse_moe
=
ArcticMoE
(
config
,
...
...
@@ -356,6 +361,7 @@ class ArcticModel(nn.Module):
def
__init__
(
self
,
config
:
ArcticConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
None
:
super
().
__init__
()
...
...
@@ -366,7 +372,10 @@ class ArcticModel(nn.Module):
config
.
hidden_size
,
org_num_embeddings
=
self
.
vocab_size
)
self
.
layers
=
nn
.
ModuleList
([
ArcticDecoderLayer
(
config
,
layer_idx
,
quant_config
=
quant_config
)
ArcticDecoderLayer
(
config
,
layer_idx
,
cache_config
,
quant_config
=
quant_config
)
for
layer_idx
in
range
(
config
.
num_hidden_layers
)
])
self
.
_attn_implementation
=
config
.
_attn_implementation
...
...
@@ -392,11 +401,12 @@ class ArcticForCausalLM(nn.Module):
def
__init__
(
self
,
config
:
ArcticConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
**
kwargs
)
->
None
:
super
().
__init__
()
self
.
config
=
config
self
.
model
=
ArcticModel
(
config
,
quant_config
)
self
.
model
=
ArcticModel
(
config
,
cache_config
,
quant_config
)
self
.
vocab_size
=
config
.
vocab_size
self
.
lm_head
=
ParallelLMHead
(
self
.
vocab_size
,
...
...
vllm/model_executor/models/baichuan.py
View file @
0fca3cdc
...
...
@@ -26,7 +26,7 @@ from torch import nn
from
transformers
import
PretrainedConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.config
import
LoRAConfig
from
vllm.config
import
CacheConfig
,
LoRAConfig
from
vllm.distributed
import
(
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
)
from
vllm.model_executor.layers.activation
import
SiluAndMul
...
...
@@ -111,6 +111,7 @@ class BaiChuanAttention(nn.Module):
position_embedding
:
str
,
rope_theta
:
float
=
10000
,
max_position_embeddings
:
int
=
8192
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
...
...
@@ -162,7 +163,10 @@ class BaiChuanAttention(nn.Module):
base
=
self
.
rope_theta
,
)
self
.
scaling
=
self
.
head_dim
**-
0.5
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
head_dim
,
self
.
scaling
)
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
head_dim
,
self
.
scaling
,
cache_config
=
cache_config
)
def
forward
(
self
,
...
...
@@ -185,6 +189,7 @@ class BaiChuanDecoderLayer(nn.Module):
def
__init__
(
self
,
config
:
PretrainedConfig
,
position_embedding
:
str
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
):
super
().
__init__
()
self
.
hidden_size
=
config
.
hidden_size
...
...
@@ -197,6 +202,7 @@ class BaiChuanDecoderLayer(nn.Module):
position_embedding
=
position_embedding
,
rope_theta
=
rope_theta
,
max_position_embeddings
=
max_position_embeddings
,
cache_config
=
cache_config
,
quant_config
=
quant_config
,
)
self
.
mlp
=
BaiChuanMLP
(
...
...
@@ -244,6 +250,7 @@ class BaiChuanModel(nn.Module):
def
__init__
(
self
,
config
:
PretrainedConfig
,
position_embedding
:
str
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
):
super
().
__init__
()
self
.
config
=
config
...
...
@@ -255,7 +262,8 @@ class BaiChuanModel(nn.Module):
config
.
hidden_size
,
)
self
.
layers
=
nn
.
ModuleList
([
BaiChuanDecoderLayer
(
config
,
position_embedding
,
quant_config
)
BaiChuanDecoderLayer
(
config
,
position_embedding
,
cache_config
,
quant_config
)
for
_
in
range
(
config
.
num_hidden_layers
)
])
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
...
...
@@ -304,13 +312,15 @@ class BaiChuanBaseForCausalLM(nn.Module):
self
,
config
,
position_embedding
:
str
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
):
super
().
__init__
()
self
.
config
=
config
self
.
quant_config
=
quant_config
self
.
model
=
BaiChuanModel
(
config
,
position_embedding
,
quant_config
)
self
.
model
=
BaiChuanModel
(
config
,
position_embedding
,
cache_config
,
quant_config
)
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
...
...
@@ -389,13 +399,16 @@ class BaichuanForCausalLM(BaiChuanBaseForCausalLM):
def
__init__
(
self
,
config
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
):
if
config
.
hidden_size
==
4096
:
# baichuan2 7b
super
().
__init__
(
config
,
"ROPE"
,
quant_config
,
lora_config
)
super
().
__init__
(
config
,
"ROPE"
,
cache_config
,
quant_config
,
lora_config
)
else
:
# baichuan 13b, baichuan2 13b
super
().
__init__
(
config
,
"ALIBI"
,
quant_config
,
lora_config
)
super
().
__init__
(
config
,
"ALIBI"
,
cache_config
,
quant_config
,
lora_config
)
class
BaiChuanForCausalLM
(
BaiChuanBaseForCausalLM
):
...
...
@@ -404,7 +417,9 @@ class BaiChuanForCausalLM(BaiChuanBaseForCausalLM):
def
__init__
(
self
,
config
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
):
super
().
__init__
(
config
,
"ROPE"
,
quant_config
,
lora_config
)
super
().
__init__
(
config
,
"ROPE"
,
cache_config
,
quant_config
,
lora_config
)
vllm/model_executor/models/bloom.py
View file @
0fca3cdc
...
...
@@ -24,6 +24,7 @@ from torch import nn
from
transformers
import
BloomConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.config
import
CacheConfig
from
vllm.distributed
import
(
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
)
from
vllm.model_executor.layers.activation
import
get_act_fn
...
...
@@ -71,6 +72,7 @@ class BloomAttention(nn.Module):
def
__init__
(
self
,
config
:
BloomConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
...
...
@@ -108,7 +110,8 @@ class BloomAttention(nn.Module):
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
head_dim
,
scaling
,
alibi_slopes
=
alibi_slopes
)
alibi_slopes
=
alibi_slopes
,
cache_config
=
cache_config
)
def
forward
(
self
,
...
...
@@ -158,6 +161,7 @@ class BloomBlock(nn.Module):
def
__init__
(
self
,
config
:
BloomConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
...
...
@@ -165,7 +169,8 @@ class BloomBlock(nn.Module):
self
.
input_layernorm
=
nn
.
LayerNorm
(
hidden_size
,
eps
=
config
.
layer_norm_epsilon
)
self
.
self_attention
=
BloomAttention
(
config
,
quant_config
)
self
.
self_attention
=
BloomAttention
(
config
,
cache_config
,
quant_config
)
self
.
post_attention_layernorm
=
nn
.
LayerNorm
(
hidden_size
,
eps
=
config
.
layer_norm_epsilon
)
self
.
mlp
=
BloomMLP
(
config
,
quant_config
)
...
...
@@ -214,6 +219,7 @@ class BloomModel(nn.Module):
def
__init__
(
self
,
config
:
BloomConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
...
...
@@ -229,7 +235,7 @@ class BloomModel(nn.Module):
# Transformer blocks
self
.
h
=
nn
.
ModuleList
([
BloomBlock
(
config
,
quant_config
)
BloomBlock
(
config
,
cache_config
,
quant_config
)
for
_
in
range
(
config
.
num_hidden_layers
)
])
...
...
@@ -262,12 +268,13 @@ class BloomForCausalLM(nn.Module):
def
__init__
(
self
,
config
:
BloomConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
self
.
config
=
config
self
.
quant_config
=
quant_config
self
.
transformer
=
BloomModel
(
config
,
quant_config
)
self
.
transformer
=
BloomModel
(
config
,
cache_config
,
quant_config
)
self
.
lm_head_weight
=
self
.
transformer
.
word_embeddings
.
weight
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
...
...
vllm/model_executor/models/chatglm.py
View file @
0fca3cdc
...
...
@@ -9,7 +9,7 @@ from torch import nn
from
torch.nn
import
LayerNorm
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.config
import
LoRAConfig
from
vllm.config
import
CacheConfig
,
LoRAConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.layernorm
import
RMSNorm
...
...
@@ -34,6 +34,7 @@ class GLMAttention(nn.Module):
def
__init__
(
self
,
config
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
...
...
@@ -90,6 +91,7 @@ class GLMAttention(nn.Module):
self
.
head_dim
,
self
.
scaling
,
num_kv_heads
=
self
.
num_kv_heads
,
cache_config
=
cache_config
,
)
def
forward
(
...
...
@@ -167,6 +169,7 @@ class GLMBlock(nn.Module):
def
__init__
(
self
,
config
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
...
...
@@ -181,7 +184,7 @@ class GLMBlock(nn.Module):
eps
=
config
.
layernorm_epsilon
)
# Self attention.
self
.
self_attention
=
GLMAttention
(
config
,
quant_config
)
self
.
self_attention
=
GLMAttention
(
config
,
cache_config
,
quant_config
)
self
.
hidden_dropout
=
config
.
hidden_dropout
# Layernorm on the attention output
...
...
@@ -237,6 +240,7 @@ class GLMTransformer(nn.Module):
def
__init__
(
self
,
config
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
...
...
@@ -246,8 +250,10 @@ class GLMTransformer(nn.Module):
self
.
num_layers
=
config
.
num_layers
# Transformer layers.
self
.
layers
=
nn
.
ModuleList
(
[
GLMBlock
(
config
,
quant_config
)
for
i
in
range
(
self
.
num_layers
)])
self
.
layers
=
nn
.
ModuleList
([
GLMBlock
(
config
,
cache_config
,
quant_config
)
for
i
in
range
(
self
.
num_layers
)
])
if
self
.
post_layer_norm
:
layer_norm_func
=
RMSNorm
if
config
.
rmsnorm
else
LayerNorm
...
...
@@ -282,6 +288,7 @@ class ChatGLMModel(nn.Module):
def
__init__
(
self
,
config
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
...
...
@@ -292,7 +299,7 @@ class ChatGLMModel(nn.Module):
self
.
num_layers
=
config
.
num_layers
self
.
multi_query_group_num
=
config
.
multi_query_group_num
self
.
kv_channels
=
config
.
kv_channels
self
.
encoder
=
GLMTransformer
(
config
,
quant_config
)
self
.
encoder
=
GLMTransformer
(
config
,
cache_config
,
quant_config
)
self
.
output_layer
=
ParallelLMHead
(
config
.
padded_vocab_size
,
config
.
hidden_size
)
...
...
@@ -334,13 +341,14 @@ class ChatGLMForCausalLM(nn.Module):
def
__init__
(
self
,
config
:
ChatGLMConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
):
super
().
__init__
()
self
.
config
:
ChatGLMConfig
=
config
self
.
quant_config
=
quant_config
self
.
transformer
=
ChatGLMModel
(
config
,
quant_config
)
self
.
transformer
=
ChatGLMModel
(
config
,
cache_config
,
quant_config
)
self
.
lm_head_weight
=
self
.
transformer
.
output_layer
.
weight
self
.
logits_processor
=
LogitsProcessor
(
config
.
padded_vocab_size
)
self
.
sampler
=
Sampler
()
...
...
vllm/model_executor/models/commandr.py
View file @
0fca3cdc
...
...
@@ -29,6 +29,7 @@ from torch.nn.parameter import Parameter
from
transformers
import
CohereConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.config
import
CacheConfig
from
vllm.distributed
import
(
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
)
from
vllm.model_executor.layers.activation
import
SiluAndMul
...
...
@@ -124,6 +125,7 @@ class CohereAttention(nn.Module):
def
__init__
(
self
,
config
:
CohereConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
...
...
@@ -180,6 +182,7 @@ class CohereAttention(nn.Module):
self
.
head_dim
,
self
.
scaling
,
num_kv_heads
=
self
.
num_kv_heads
,
cache_config
=
cache_config
,
)
if
self
.
use_qk_norm
:
self
.
q_norm
=
LayerNorm
(
param_shape
=
(
self
.
num_heads
,
...
...
@@ -219,11 +222,14 @@ class CohereDecoderLayer(nn.Module):
def
__init__
(
self
,
config
:
CohereConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
):
super
().
__init__
()
self
.
hidden_size
=
config
.
hidden_size
self
.
self_attn
=
CohereAttention
(
config
,
quant_config
=
quant_config
)
self
.
self_attn
=
CohereAttention
(
config
,
cache_config
,
quant_config
=
quant_config
)
self
.
mlp
=
CohereMLP
(
config
,
quant_config
=
quant_config
)
self
.
input_layernorm
=
LayerNorm
(
param_shape
=
(
config
.
hidden_size
),
...
...
@@ -258,6 +264,7 @@ class CohereModel(nn.Module):
def
__init__
(
self
,
config
:
CohereConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
...
...
@@ -266,7 +273,7 @@ class CohereModel(nn.Module):
self
.
embed_tokens
=
VocabParallelEmbedding
(
config
.
vocab_size
,
config
.
hidden_size
)
self
.
layers
=
nn
.
ModuleList
([
CohereDecoderLayer
(
config
,
quant_config
=
quant_config
)
CohereDecoderLayer
(
config
,
cache_config
,
quant_config
=
quant_config
)
for
_
in
range
(
config
.
num_hidden_layers
)
])
self
.
norm
=
LayerNorm
(
param_shape
=
(
config
.
hidden_size
),
...
...
@@ -299,6 +306,7 @@ class CohereForCausalLM(nn.Module):
def
__init__
(
self
,
config
:
CohereConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
None
:
super
().
__init__
()
...
...
@@ -306,7 +314,7 @@ class CohereForCausalLM(nn.Module):
self
.
quant_config
=
quant_config
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
,
scale
=
config
.
logit_scale
)
self
.
model
=
CohereModel
(
config
,
quant_config
)
self
.
model
=
CohereModel
(
config
,
cache_config
,
quant_config
)
self
.
sampler
=
Sampler
()
@
torch
.
no_grad
()
...
...
vllm/model_executor/models/dbrx.py
View file @
0fca3cdc
...
...
@@ -5,6 +5,7 @@ import torch
import
torch.nn
as
nn
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.config
import
CacheConfig
from
vllm.distributed
import
(
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
,
tensor_model_parallel_all_reduce
)
...
...
@@ -166,6 +167,7 @@ class DbrxAttention(nn.Module):
def
__init__
(
self
,
config
:
DbrxConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
...
...
@@ -221,6 +223,7 @@ class DbrxAttention(nn.Module):
self
.
head_dim
,
self
.
scaling
,
num_kv_heads
=
self
.
num_kv_heads
,
cache_config
=
cache_config
,
)
def
forward
(
...
...
@@ -279,10 +282,12 @@ class DbrxBlock(nn.Module):
def
__init__
(
self
,
config
:
DbrxConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
self
.
norm_attn_norm
=
DbrxFusedNormAttention
(
config
,
quant_config
)
self
.
norm_attn_norm
=
DbrxFusedNormAttention
(
config
,
cache_config
,
quant_config
)
self
.
ffn
=
DbrxExperts
(
config
,
quant_config
)
def
forward
(
...
...
@@ -308,6 +313,7 @@ class DbrxModel(nn.Module):
def
__init__
(
self
,
config
:
DbrxConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
...
...
@@ -315,8 +321,10 @@ class DbrxModel(nn.Module):
config
.
vocab_size
,
config
.
d_model
,
)
self
.
blocks
=
nn
.
ModuleList
(
[
DbrxBlock
(
config
,
quant_config
)
for
_
in
range
(
config
.
n_layers
)])
self
.
blocks
=
nn
.
ModuleList
([
DbrxBlock
(
config
,
cache_config
,
quant_config
)
for
_
in
range
(
config
.
n_layers
)
])
self
.
norm_f
=
nn
.
LayerNorm
(
config
.
d_model
,
eps
=
1e-5
)
for
module
in
self
.
modules
():
if
hasattr
(
module
,
"bias"
)
and
isinstance
(
module
.
bias
,
...
...
@@ -349,13 +357,14 @@ class DbrxForCausalLM(nn.Module):
def
__init__
(
self
,
config
:
DbrxConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
self
.
config
=
config
self
.
quant_config
=
quant_config
self
.
unpadded_vocab_size
=
config
.
vocab_size
self
.
transformer
=
DbrxModel
(
config
,
quant_config
)
self
.
transformer
=
DbrxModel
(
config
,
cache_config
,
quant_config
)
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
d_model
,
...
...
vllm/model_executor/models/decilm.py
View file @
0fca3cdc
...
...
@@ -28,7 +28,7 @@ from typing import Iterable, Optional, Tuple
import
torch
from
transformers
import
PretrainedConfig
from
vllm.config
import
LoRAConfig
from
vllm.config
import
CacheConfig
,
LoRAConfig
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
...
...
@@ -56,12 +56,14 @@ class DeciLMForCausalLM(LlamaForCausalLM):
def
__init__
(
self
,
config
:
Optional
[
PretrainedConfig
]
=
None
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
)
->
None
:
config
.
num_key_value_heads
=
max
(
config
.
num_key_value_heads_per_layer
)
delattr
(
config
,
"num_key_value_heads_per_layer"
)
super
().
__init__
(
config
=
config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
,
lora_config
=
lora_config
)
...
...
vllm/model_executor/models/deepseek.py
View file @
0fca3cdc
...
...
@@ -28,6 +28,7 @@ from torch import nn
from
transformers
import
PretrainedConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.config
import
CacheConfig
from
vllm.distributed
import
(
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
,
tensor_model_parallel_all_reduce
)
...
...
@@ -178,6 +179,7 @@ class DeepseekAttention(nn.Module):
rope_theta
:
float
=
10000
,
rope_scaling
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
max_position_embeddings
:
int
=
8192
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
None
:
super
().
__init__
()
...
...
@@ -229,7 +231,8 @@ class DeepseekAttention(nn.Module):
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
head_dim
,
self
.
scaling
,
num_kv_heads
=
self
.
num_kv_heads
)
num_kv_heads
=
self
.
num_kv_heads
,
cache_config
=
cache_config
)
def
forward
(
self
,
...
...
@@ -252,6 +255,7 @@ class DeepseekDecoderLayer(nn.Module):
self
,
config
:
PretrainedConfig
,
layer_idx
:
int
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
None
:
super
().
__init__
()
...
...
@@ -267,6 +271,7 @@ class DeepseekDecoderLayer(nn.Module):
rope_theta
=
rope_theta
,
rope_scaling
=
rope_scaling
,
max_position_embeddings
=
max_position_embeddings
,
cache_config
=
cache_config
,
quant_config
=
quant_config
,
)
if
(
config
.
n_routed_experts
is
not
None
...
...
@@ -321,6 +326,7 @@ class DeepseekModel(nn.Module):
def
__init__
(
self
,
config
:
PretrainedConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
None
:
super
().
__init__
()
...
...
@@ -332,7 +338,10 @@ class DeepseekModel(nn.Module):
config
.
hidden_size
,
)
self
.
layers
=
nn
.
ModuleList
([
DeepseekDecoderLayer
(
config
,
layer_idx
,
quant_config
=
quant_config
)
DeepseekDecoderLayer
(
config
,
layer_idx
,
cache_config
,
quant_config
=
quant_config
)
for
layer_idx
in
range
(
config
.
num_hidden_layers
)
])
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
...
...
@@ -360,12 +369,13 @@ class DeepseekForCausalLM(nn.Module):
def
__init__
(
self
,
config
:
PretrainedConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
None
:
super
().
__init__
()
self
.
config
=
config
self
.
quant_config
=
quant_config
self
.
model
=
DeepseekModel
(
config
,
quant_config
)
self
.
model
=
DeepseekModel
(
config
,
cache_config
,
quant_config
)
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment