norm / vllm · Commits

Commit a8e98aee (unverified)
Fix Mistral model (#1220)

Authored Sep 28, 2023 by Woosuk Kwon; committed by GitHub on Sep 28, 2023
Parent: bb1ba58f
Showing 4 changed files with 27 additions and 14 deletions (+27 −14)
vllm/model_executor/models/mistral.py         +1  −1
vllm/transformers_utils/config.py             +9  −0
vllm/transformers_utils/configs/__init__.py   +2  −0
vllm/worker/worker.py                        +15 −13
vllm/model_executor/models/mistral.py

@@ -29,7 +29,6 @@ from typing import List, Optional, Tuple
 import torch
 from torch import nn
-from vllm.transformers_utils.configs.mistral import MistralConfig
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.activation import SiluAndMul
@@ -46,6 +45,7 @@ from vllm.model_executor.weight_utils import (
     convert_pyslice_to_tensor, hf_model_weights_iterator,
     load_tensor_parallel_weights, load_padded_tensor_parallel_vocab)
 from vllm.sequence import SamplerOutput
+from vllm.transformers_utils.configs.mistral import MistralConfig

 KVCache = Tuple[torch.Tensor, torch.Tensor]
vllm/transformers_utils/config.py

@@ -17,6 +17,15 @@ _CONFIG_REGISTRY = {
 def get_config(model: str,
                trust_remote_code: bool,
                revision: Optional[str] = None) -> PretrainedConfig:
+    # NOTE: Because the Mistral model in HF hub does not have
+    # `configuration_mistral.py`, we cannot use `AutoConfig` to load the
+    # config. Instead, we use `MistralConfig` directly.
+    # NOTE: This is a hack. This does not work for local models.
+    # FIXME: Remove this once the Mistral model is available in the stable
+    # version of HF transformers.
+    if "mistral" in model.lower():
+        return MistralConfig.from_pretrained(model, revision=revision)
+
     try:
         config = AutoConfig.from_pretrained(
             model, trust_remote_code=trust_remote_code, revision=revision)
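Note on the config.py hunk: any model whose name contains "mistral" now bypasses AutoConfig entirely. A minimal sketch of how the routing behaves (the model names below are illustrative and assume hub access; they are not part of the diff):

    from vllm.transformers_utils.config import get_config

    # Names containing "mistral" are loaded with the vendored MistralConfig,
    # since the HF hub repo lacks configuration_mistral.py at this point.
    config = get_config("mistralai/Mistral-7B-v0.1", trust_remote_code=False)
    print(type(config).__name__)   # MistralConfig
    print(config.sliding_window)   # 4096 for this checkpoint

    # All other models still resolve through AutoConfig as before.
    config = get_config("facebook/opt-125m", trust_remote_code=False)

As the in-code NOTE warns, the name-based check is a stopgap and does not cover local models.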
vllm/transformers_utils/configs/__init__.py

@@ -6,6 +6,7 @@ from vllm.transformers_utils.configs.qwen import QWenConfig
 # tiiuae/falcon-7b(-instruct) models. Newer Falcon models will use the
 # `FalconConfig` class from the official HuggingFace transformers library.
 from vllm.transformers_utils.configs.falcon import RWConfig
+from vllm.transformers_utils.configs.mistral import MistralConfig

 __all__ = [
     "MPTConfig",
@@ -13,4 +14,5 @@ __all__ = [
     "AquilaConfig",
     "QWenConfig",
     "RWConfig",
+    "MistralConfig",
 ]
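Note on the __init__.py hunk: re-exporting MistralConfig makes it importable from the configs package alongside the other vendored configs. A small sketch, assuming the vendored class mirrors the Hugging Face MistralConfig defaults:

    from vllm.transformers_utils.configs import MistralConfig

    # Default construction; sliding_window defaults to 4096 in the HF
    # definition this module vendors.
    cfg = MistralConfig()
    print(cfg.model_type)       # "mistral"
    print(cfg.sliding_window)   # 4096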
vllm/worker/worker.py
@@ -42,6 +42,7 @@ class Worker:
         # self.init_cache_engine().
         self.cache_config = None
         self.block_size = None
+        self.sliding_window = None
         self.cache_engine = None
         self.cache_events = None
         self.gpu_cache = None
@@ -136,10 +137,13 @@ class Worker:
     def init_cache_engine(self, cache_config: CacheConfig) -> None:
         self.cache_config = cache_config
         self.block_size = cache_config.block_size
+        self.sliding_window = cache_config.sliding_window
-        max_seq_len = min(self.scheduler_config.max_model_len,
-                          cache_config.sliding_window or float("inf"))
+        if self.sliding_window is None:
+            max_seq_len = self.scheduler_config.max_model_len
+        else:
+            max_seq_len = min(self.scheduler_config.max_model_len,
+                              self.sliding_window)
         _check_if_can_support_max_seq_len(max_seq_len, self.block_size)
         self.cache_engine = CacheEngine(self.cache_config, self.model_config,
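Note on the init_cache_engine hunk: the sliding window now only caps the maximum sequence length when it is actually set. A standalone restatement of the new branch with illustrative numbers (the helper name is ours, not part of the diff):

    from typing import Optional

    def effective_max_seq_len(max_model_len: int,
                              sliding_window: Optional[int]) -> int:
        # Mirrors the new logic: no window means the scheduler's
        # max_model_len is the only limit.
        if sliding_window is None:
            return max_model_len
        return min(max_model_len, sliding_window)

    assert effective_max_seq_len(32768, 4096) == 4096  # Mistral-7B-style window
    assert effective_max_seq_len(2048, None) == 2048   # model without a window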
@@ -151,6 +155,8 @@ class Worker:
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
     ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata]:
+        assert self.block_size is not None
+
         seq_groups: List[Tuple[List[int], SamplingParams]] = []
         input_tokens: List[int] = []
         input_positions: List[int] = []
@@ -193,9 +199,6 @@ class Worker:
                 slot = block_number * self.block_size + block_offset
                 slot_mapping.append(slot)

-        sliding_window = getattr(self.model_config.hf_config,
-                                 "sliding_window", float("inf"))
-
         # Add generation tokens.
         max_context_len = 0
         max_num_blocks_per_seq = 0
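Note on the removal above: getattr(..., float("inf")) gave every model without a sliding_window attribute an inf sentinel, and inf is truthy, so the later `if sliding_window:` guards still ran for those models. A small illustration of that gotcha (our own snippet, not code from the diff):

    # Models whose hf_config has no `sliding_window` got inf under the old code.
    sliding_window = getattr(object(), "sliding_window", float("inf"))

    assert sliding_window == float("inf")
    assert bool(sliding_window)       # inf is truthy, so the old guard fired
    print(sliding_window // 16)       # inf -- not usable as a block count

    # The new code keeps None on the Worker and checks `is not None`, so
    # models without a window skip the capping logic entirely.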
@@ -216,8 +219,8 @@ class Worker:
                 context_len = seq_data.get_len()
                 position = context_len - 1
-                if sliding_window:
-                    context_len = min(context_len, sliding_window)
+                if self.sliding_window is not None:
+                    context_len = min(context_len, self.sliding_window)
                 input_positions.append(position)

                 block_table = seq_group_metadata.block_tables[seq_id]
@@ -232,10 +235,9 @@ class Worker:
                 slot = block_number * self.block_size + block_offset
                 slot_mapping.append(slot)

-                if sliding_window:
-                    assert self.cache_config is not None
-                    sliding_window_blocks = (sliding_window //
-                                             self.cache_config.block_size)
+                if self.sliding_window is not None:
+                    sliding_window_blocks = (self.sliding_window //
+                                             self.block_size)
                     block_table = block_table[-sliding_window_blocks:]
                 generation_block_tables.append(block_table)
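Note on the two generation-token hunks: capping the context length and truncating the block table are the same window applied once in tokens and once in blocks. The arithmetic with illustrative values (Mistral-7B uses a 4096-token window; vLLM's default block size is 16):

    sliding_window = 4096
    block_size = 16

    context_len = 6000                   # tokens in the sequence so far
    block_table = list(range(375))       # 6000 / 16 = 375 physical blocks

    if sliding_window is not None:
        context_len = min(context_len, sliding_window)          # -> 4096
        sliding_window_blocks = sliding_window // block_size     # -> 256
        block_table = block_table[-sliding_window_blocks:]       # last 256 blocks

    assert context_len == 4096
    assert len(block_table) == 256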
@@ -277,7 +279,7 @@ class Worker:
             context_lens=context_lens_tensor,
             max_context_len=max_context_len,
             block_tables=block_tables_tensor,
-            sliding_window=sliding_window,
+            sliding_window=self.sliding_window,
         )
         return tokens_tensor, positions_tensor, input_metadata
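Taken together, the four files let a Mistral checkpoint load and respect its sliding window end to end. A minimal usage sketch against the vLLM API of this commit (model name illustrative; it assumes the checkpoint is reachable):

    from vllm import LLM, SamplingParams

    # get_config() resolves this name to the vendored MistralConfig, and the
    # worker caps context lengths and block tables to the model's window.
    llm = LLM(model="mistralai/Mistral-7B-v0.1")
    outputs = llm.generate(["The capital of France is"],
                           SamplingParams(temperature=0.0, max_tokens=16))
    print(outputs[0].outputs[0].text)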