SIYIXNI / vllm · Commits · a8e98aee

Fix Mistral model (#1220)

Unverified commit a8e98aee, authored Sep 28, 2023 by Woosuk Kwon, committed by GitHub on Sep 28, 2023.
Parent: bb1ba58f
Showing 4 changed files with 27 additions and 14 deletions (+27, -14):
vllm/model_executor/models/mistral.py (+1, -1)
vllm/transformers_utils/config.py (+9, -0)
vllm/transformers_utils/configs/__init__.py (+2, -0)
vllm/worker/worker.py (+15, -13)
vllm/model_executor/models/mistral.py

@@ -29,7 +29,6 @@ from typing import List, Optional, Tuple
 import torch
 from torch import nn

-from vllm.transformers_utils.configs.mistral import MistralConfig
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.activation import SiluAndMul

@@ -46,6 +45,7 @@ from vllm.model_executor.weight_utils import (
     convert_pyslice_to_tensor, hf_model_weights_iterator,
     load_tensor_parallel_weights, load_padded_tensor_parallel_vocab)
 from vllm.sequence import SamplerOutput
+from vllm.transformers_utils.configs.mistral import MistralConfig

 KVCache = Tuple[torch.Tensor, torch.Tensor]
vllm/transformers_utils/config.py

@@ -17,6 +17,15 @@ _CONFIG_REGISTRY = {
 def get_config(model: str,
                trust_remote_code: bool,
                revision: Optional[str] = None) -> PretrainedConfig:
+    # NOTE: Because the Mistral model in HF hub does not have
+    # `configuration_mistral.py`, we cannot use `AutoConfig` to load the
+    # config. Instead, we use `MistralConfig` directly.
+    # NOTE: This is a hack. This does not work for local models.
+    # FIXME: Remove this once the Mistral model is available in the stable
+    # version of HF transformers.
+    if "mistral" in model.lower():
+        return MistralConfig.from_pretrained(model, revision=revision)
+
     try:
         config = AutoConfig.from_pretrained(
             model, trust_remote_code=trust_remote_code, revision=revision)
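The new branch short-circuits configuration loading on a plain substring match before the existing AutoConfig path runs. A minimal sketch of the resulting behavior, using the signature from the hunk above (the model names are illustrative):

    from vllm.transformers_utils.config import get_config

    # "mistral" appears in the model name, so vLLM's bundled MistralConfig
    # is returned directly and AutoConfig is never consulted.
    config = get_config("mistralai/Mistral-7B-v0.1", trust_remote_code=False)

    # Any other name falls through to the AutoConfig/_CONFIG_REGISTRY path
    # exactly as before this commit.
    config = get_config("tiiuae/falcon-7b", trust_remote_code=True)

As the NOTE and FIXME comments say, this is a stopgap: the heuristic does not work for local checkpoint paths, and the branch is meant to disappear once MistralConfig lands in a stable HF transformers release.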
vllm/transformers_utils/configs/__init__.py

@@ -6,6 +6,7 @@ from vllm.transformers_utils.configs.qwen import QWenConfig
 # tiiuae/falcon-7b(-instruct) models. Newer Falcon models will use the
 # `FalconConfig` class from the official HuggingFace transformers library.
 from vllm.transformers_utils.configs.falcon import RWConfig
+from vllm.transformers_utils.configs.mistral import MistralConfig

 __all__ = [
     "MPTConfig",
@@ -13,4 +14,5 @@ __all__ = [
     "AquilaConfig",
     "QWenConfig",
     "RWConfig",
+    "MistralConfig",
 ]
vllm/worker/worker.py

@@ -42,6 +42,7 @@ class Worker:
         # self.init_cache_engine().
         self.cache_config = None
         self.block_size = None
+        self.sliding_window = None
         self.cache_engine = None
         self.cache_events = None
         self.gpu_cache = None
@@ -136,10 +137,13 @@ class Worker:
     def init_cache_engine(self, cache_config: CacheConfig) -> None:
         self.cache_config = cache_config
         self.block_size = cache_config.block_size
-        max_seq_len = min(self.scheduler_config.max_model_len,
-                          cache_config.sliding_window or float("inf"))
+        self.sliding_window = cache_config.sliding_window
+        if self.sliding_window is None:
+            max_seq_len = self.scheduler_config.max_model_len
+        else:
+            max_seq_len = min(self.scheduler_config.max_model_len,
+                              self.sliding_window)
         _check_if_can_support_max_seq_len(max_seq_len, self.block_size)
         self.cache_engine = CacheEngine(self.cache_config, self.model_config,
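In isolation the new branch behaves as below; a sketch, with 32768/4096 chosen to match Mistral-7B's published context length and sliding window (illustrative values, not read from any config here):

    from typing import Optional

    def compute_max_seq_len(max_model_len: int,
                            sliding_window: Optional[int]) -> int:
        # Mirrors the branch above: without a sliding window the full model
        # context length is usable; with one, the window caps it.
        if sliding_window is None:
            return max_model_len
        return min(max_model_len, sliding_window)

    assert compute_max_seq_len(2048, None) == 2048     # no window configured
    assert compute_max_seq_len(32768, 4096) == 4096    # window caps the length

The explicit `is None` test also replaces the truthiness-based `or float("inf")` fallback from the deleted lines.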
@@ -151,6 +155,8 @@ class Worker:
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
     ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata]:
+        assert self.block_size is not None
         seq_groups: List[Tuple[List[int], SamplingParams]] = []
         input_tokens: List[int] = []
         input_positions: List[int] = []
@@ -193,9 +199,6 @@ class Worker:
                 slot = block_number * self.block_size + block_offset
                 slot_mapping.append(slot)
-        sliding_window = getattr(self.model_config.hf_config,
-                                 "sliding_window", float("inf"))
         # Add generation tokens.
         max_context_len = 0
         max_num_blocks_per_seq = 0
@@ -216,8 +219,8 @@ class Worker:
             context_len = seq_data.get_len()
             position = context_len - 1
-            if sliding_window:
-                context_len = min(context_len, sliding_window)
+            if self.sliding_window is not None:
+                context_len = min(context_len, self.sliding_window)
             input_positions.append(position)
             block_table = seq_group_metadata.block_tables[seq_id]
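During generation this caps how far back attention reaches, while position ids keep growing past the window (the hunk computes position before clamping); a quick worked example with illustrative numbers:

    sliding_window = 4096          # e.g. Mistral-7B's window size
    context_len = 10000            # tokens in the sequence so far
    position = context_len - 1     # position id is computed first: 9999

    context_len = min(context_len, sliding_window)
    assert context_len == 4096     # attention sees only the last 4096 tokens
    assert position == 9999        # the position id itself is not clamped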
@@ -232,10 +235,9 @@ class Worker:
             slot = block_number * self.block_size + block_offset
             slot_mapping.append(slot)
-            if sliding_window:
-                assert self.cache_config is not None
-                sliding_window_blocks = (sliding_window //
-                                         self.cache_config.block_size)
+            if self.sliding_window is not None:
+                sliding_window_blocks = (self.sliding_window //
+                                         self.block_size)
                 block_table = block_table[-sliding_window_blocks:]
             generation_block_tables.append(block_table)
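The same cap is applied at block granularity when building the decode block tables; a minimal sketch of the arithmetic with illustrative numbers (the hunk also swaps self.cache_config.block_size for the equivalent self.block_size, and drops the assert on self.cache_config since self.block_size is already asserted non-None earlier in _prepare_inputs):

    sliding_window = 4096           # tokens (illustrative)
    block_size = 16                 # KV-cache block size (illustrative)
    sliding_window_blocks = sliding_window // block_size   # 4096 // 16 == 256

    block_table = list(range(300))  # 300 physical blocks allocated so far
    block_table = block_table[-sliding_window_blocks:]
    assert len(block_table) == 256  # only the newest blocks are kept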
@@ -277,7 +279,7 @@ class Worker:
             context_lens=context_lens_tensor,
             max_context_len=max_context_len,
             block_tables=block_tables_tensor,
-            sliding_window=sliding_window,
+            sliding_window=self.sliding_window,
         )
         return tokens_tensor, positions_tensor, input_metadata