norm / vllm · Commits

Commit a8e98aee (unverified)
Fix Mistral model (#1220)
Authored Sep 28, 2023 by Woosuk Kwon; committed via GitHub on Sep 28, 2023
Parent: bb1ba58f

Showing 4 changed files with 27 additions and 14 deletions:

  vllm/model_executor/models/mistral.py        +1   -1
  vllm/transformers_utils/config.py            +9   -0
  vllm/transformers_utils/configs/__init__.py  +2   -0
  vllm/worker/worker.py                        +15  -13
vllm/model_executor/models/mistral.py

```diff
@@ -29,7 +29,6 @@ from typing import List, Optional, Tuple
 import torch
 from torch import nn
-from vllm.transformers_utils.configs.mistral import MistralConfig
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.activation import SiluAndMul
@@ -46,6 +45,7 @@ from vllm.model_executor.weight_utils import (
     convert_pyslice_to_tensor, hf_model_weights_iterator,
     load_tensor_parallel_weights, load_padded_tensor_parallel_vocab)
 from vllm.sequence import SamplerOutput
+from vllm.transformers_utils.configs.mistral import MistralConfig

 KVCache = Tuple[torch.Tensor, torch.Tensor]
```
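The only change in this file moves the `MistralConfig` import down into sorted position among the other `vllm` imports. For context, the `KVCache` alias visible in the hunk pairs the key and value cache tensors of one attention layer; a minimal allocation sketch follows, where the flat tensor shape is an illustrative assumption rather than vLLM's actual paged-cache layout:

```python
from typing import Tuple

import torch

KVCache = Tuple[torch.Tensor, torch.Tensor]

def allocate_layer_cache(num_blocks: int, block_size: int,
                         num_heads: int, head_size: int) -> KVCache:
    # One (key, value) pair per attention layer. The
    # (blocks, heads, head_size, block_size) shape here is a
    # simplification for illustration only.
    key_cache = torch.empty(num_blocks, num_heads, head_size, block_size)
    value_cache = torch.empty(num_blocks, num_heads, head_size, block_size)
    return key_cache, value_cache
```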
vllm/transformers_utils/config.py

```diff
@@ -17,6 +17,15 @@ _CONFIG_REGISTRY = {
 def get_config(model: str,
                trust_remote_code: bool,
                revision: Optional[str] = None) -> PretrainedConfig:
+    # NOTE: Because the Mistral model in HF hub does not have
+    # `configuration_mistral.py`, we cannot use `AutoConfig` to load the
+    # config. Instead, we use `MistralConfig` directly.
+    # NOTE: This is a hack. This does not work for local models.
+    # FIXME: Remove this once the Mistral model is available in the stable
+    # version of HF transformers.
+    if "mistral" in model.lower():
+        return MistralConfig.from_pretrained(model, revision=revision)
+
     try:
         config = AutoConfig.from_pretrained(
             model, trust_remote_code=trust_remote_code, revision=revision)
```
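The added block short-circuits config resolution on the model name before `AutoConfig` runs. A minimal sketch of the resulting dispatch, with the `try`/`except` and `_CONFIG_REGISTRY` fallback of the full function elided:

```python
from typing import Optional

from transformers import AutoConfig, PretrainedConfig

from vllm.transformers_utils.configs.mistral import MistralConfig


def get_config_sketch(model: str,
                      trust_remote_code: bool,
                      revision: Optional[str] = None) -> PretrainedConfig:
    # Name-based special case: the Mistral checkpoints on the HF hub lacked
    # `configuration_mistral.py` at the time, so AutoConfig could not
    # resolve them. Per the NOTE in the diff, this is a hack and does not
    # work for local models.
    if "mistral" in model.lower():
        return MistralConfig.from_pretrained(model, revision=revision)
    # Every other model takes the normal AutoConfig path.
    return AutoConfig.from_pretrained(
        model, trust_remote_code=trust_remote_code, revision=revision)
```

For example, `get_config_sketch("mistralai/Mistral-7B-v0.1", False)` returns a `MistralConfig` without consulting `AutoConfig` at all.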
vllm/transformers_utils/configs/__init__.py

```diff
@@ -6,6 +6,7 @@ from vllm.transformers_utils.configs.qwen import QWenConfig
 # tiiuae/falcon-7b(-instruct) models. Newer Falcon models will use the
 # `FalconConfig` class from the official HuggingFace transformers library.
 from vllm.transformers_utils.configs.falcon import RWConfig
+from vllm.transformers_utils.configs.mistral import MistralConfig

 __all__ = [
     "MPTConfig",
@@ -13,4 +14,5 @@ __all__ = [
     "AquilaConfig",
     "QWenConfig",
     "RWConfig",
+    "MistralConfig",
 ]
```
vllm/worker/worker.py

```diff
@@ -42,6 +42,7 @@ class Worker:
         # self.init_cache_engine().
         self.cache_config = None
         self.block_size = None
+        self.sliding_window = None
         self.cache_engine = None
         self.cache_events = None
         self.gpu_cache = None
```
```diff
@@ -136,10 +137,13 @@ class Worker:
     def init_cache_engine(self, cache_config: CacheConfig) -> None:
         self.cache_config = cache_config
         self.block_size = cache_config.block_size
-        max_seq_len = min(self.scheduler_config.max_model_len,
-                          cache_config.sliding_window or float("inf"))
+        self.sliding_window = cache_config.sliding_window
+        if self.sliding_window is None:
+            max_seq_len = self.scheduler_config.max_model_len
+        else:
+            max_seq_len = min(self.scheduler_config.max_model_len,
+                              self.sliding_window)
         _check_if_can_support_max_seq_len(max_seq_len, self.block_size)
         self.cache_engine = CacheEngine(self.cache_config, self.model_config,
```
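The rewritten bound replaces the old truthiness trick, where a missing window turned into `float("inf")` inside `min`, with an explicit `None` check. A worked example showing the two formulations agree; the numbers are assumptions for illustration (a Mistral-7B-like setup):

```python
max_model_len = 32768   # assumed scheduler_config.max_model_len
sliding_window = 4096   # assumed cache_config.sliding_window

# Old formulation: the None case is smuggled in via `or float("inf")`.
old_max_seq_len = min(max_model_len, sliding_window or float("inf"))

# New formulation: branch explicitly on None.
if sliding_window is None:
    new_max_seq_len = max_model_len
else:
    new_max_seq_len = min(max_model_len, sliding_window)

assert old_max_seq_len == new_max_seq_len == 4096
```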
```diff
@@ -151,6 +155,8 @@ class Worker:
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
     ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata]:
+        assert self.block_size is not None
+
         seq_groups: List[Tuple[List[int], SamplingParams]] = []
         input_tokens: List[int] = []
         input_positions: List[int] = []
```
```diff
@@ -193,9 +199,6 @@ class Worker:
             slot = block_number * self.block_size + block_offset
             slot_mapping.append(slot)

-        sliding_window = getattr(self.model_config.hf_config,
-                                 "sliding_window", float("inf"))
-
         # Add generation tokens.
         max_context_len = 0
         max_num_blocks_per_seq = 0
```
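The deleted lines re-read the window from the HF config on every `_prepare_inputs` call, with `float("inf")` as the "no window" sentinel; after this commit the value is cached once in `init_cache_engine` as `self.sliding_window`, with `None` meaning no window. A small sketch of what the old `getattr` fallback did, using a hypothetical config stub:

```python
class HfConfigStub:
    """Hypothetical stand-in for model_config.hf_config on a model
    without a sliding window (no `sliding_window` attribute at all)."""

sliding_window = getattr(HfConfigStub(), "sliding_window", float("inf"))

# The missing attribute silently becomes an infinite float, so every
# downstream check had to treat float("inf") as "disabled"; the new
# `self.sliding_window is not None` checks avoid the sentinel entirely.
assert sliding_window == float("inf")
```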
```diff
@@ -216,8 +219,8 @@ class Worker:
             context_len = seq_data.get_len()
             position = context_len - 1
-            if sliding_window:
-                context_len = min(context_len, sliding_window)
+            if self.sliding_window is not None:
+                context_len = min(context_len, self.sliding_window)
             input_positions.append(position)

             block_table = seq_group_metadata.block_tables[seq_id]
```
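With a sliding window, attention only sees the last `self.sliding_window` tokens, so the effective context length is clamped while the position index keeps growing. A worked example with assumed numbers:

```python
sliding_window = 4096      # assumed window size
context_len = 10000        # tokens in the sequence so far

position = context_len - 1                      # next position id: 9999
context_len = min(context_len, sliding_window)  # clamp to the window

assert (position, context_len) == (9999, 4096)
```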
```diff
@@ -232,10 +235,9 @@ class Worker:
             slot = block_number * self.block_size + block_offset
             slot_mapping.append(slot)

-            if sliding_window:
-                sliding_window_blocks = (sliding_window //
-                                         self.block_size)
+            if self.sliding_window is not None:
+                assert self.cache_config is not None
+                sliding_window_blocks = (self.sliding_window //
+                                         self.cache_config.block_size)
                 block_table = block_table[-sliding_window_blocks:]
             generation_block_tables.append(block_table)
```
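Because only the last `sliding_window` tokens can be attended to, the block table handed to the attention kernel can be truncated to its trailing `sliding_window // block_size` entries. A worked example with assumed sizes:

```python
sliding_window = 4096              # assumed window size, in tokens
block_size = 16                    # assumed KV-cache block size
block_table = list(range(400))     # 400 blocks cover 6400 tokens

sliding_window_blocks = sliding_window // block_size   # 256 blocks
block_table = block_table[-sliding_window_blocks:]

# Only the most recent 256 blocks (4096 tokens) remain visible.
assert len(block_table) == 256 and block_table[0] == 144
```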
```diff
@@ -277,7 +279,7 @@ class Worker:
             context_lens=context_lens_tensor,
             max_context_len=max_context_len,
             block_tables=block_tables_tensor,
-            sliding_window=sliding_window,
+            sliding_window=self.sliding_window,
         )
         return tokens_tensor, positions_tensor, input_metadata
```