Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
6d2051cc
Commit
6d2051cc
authored
Oct 21, 2024
by
zhuwenwen
Browse files
Merge tag 'v0.6.3.post1' into v0.6.3.post1-dev
parents
2c7f740a
a2c71c54
Changes
457
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
830 additions
and
455 deletions
+830
-455
vllm/inputs/preprocess.py
vllm/inputs/preprocess.py
+117
-81
vllm/inputs/registry.py
vllm/inputs/registry.py
+24
-22
vllm/lora/layers.py
vllm/lora/layers.py
+3
-0
vllm/lora/models.py
vllm/lora/models.py
+49
-6
vllm/lora/utils.py
vllm/lora/utils.py
+34
-1
vllm/model_executor/custom_op.py
vllm/model_executor/custom_op.py
+2
-1
vllm/model_executor/guided_decoding/__init__.py
vllm/model_executor/guided_decoding/__init__.py
+18
-50
vllm/model_executor/guided_decoding/guided_fields.py
vllm/model_executor/guided_decoding/guided_fields.py
+1
-0
vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py
...l_executor/guided_decoding/lm_format_enforcer_decoding.py
+16
-74
vllm/model_executor/guided_decoding/outlines_decoding.py
vllm/model_executor/guided_decoding/outlines_decoding.py
+19
-53
vllm/model_executor/layers/activation.py
vllm/model_executor/layers/activation.py
+27
-0
vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_L40S.json
...fused_moe/configs/E=8,N=3584,device_name=NVIDIA_L40S.json
+173
-0
vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
+111
-33
vllm/model_executor/layers/fused_moe/fused_moe.py
vllm/model_executor/layers/fused_moe/fused_moe.py
+3
-0
vllm/model_executor/layers/layernorm.py
vllm/model_executor/layers/layernorm.py
+30
-2
vllm/model_executor/layers/linear.py
vllm/model_executor/layers/linear.py
+58
-46
vllm/model_executor/layers/mamba/ops/causal_conv1d.py
vllm/model_executor/layers/mamba/ops/causal_conv1d.py
+63
-53
vllm/model_executor/layers/mamba/ops/mamba_ssm.py
vllm/model_executor/layers/mamba/ops/mamba_ssm.py
+73
-33
vllm/model_executor/layers/pooler.py
vllm/model_executor/layers/pooler.py
+7
-0
vllm/model_executor/layers/quantization/__init__.py
vllm/model_executor/layers/quantization/__init__.py
+2
-0
No files found.
Too many changes to show.
To preserve performance only
457 of 457+
files are displayed.
Plain diff
Email patch
vllm/inputs/preprocess.py
View file @
6d2051cc
import
asyncio
from
typing
import
TYPE_CHECKING
,
List
,
Optional
,
Tuple
,
Union
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Union
from
typing_extensions
import
assert_never
...
...
@@ -8,9 +8,10 @@ from vllm.logger import init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.transformers_utils.tokenizer_group
import
BaseTokenizerGroup
from
vllm.utils
import
print_warning_once
from
.data
import
(
Encoder
Decoder
LLM
Inputs
,
LLM
Inputs
,
Prompt
Inputs
,
SingletonPrompt
Inputs
)
from
.data
import
(
Decoder
Only
Inputs
,
EncoderDecoder
Inputs
,
Prompt
Type
,
SingletonPrompt
)
from
.parse
import
is_explicit_encoder_decoder_prompt
,
parse_singleton_prompt
if
TYPE_CHECKING
:
...
...
@@ -19,9 +20,11 @@ if TYPE_CHECKING:
logger
=
init_logger
(
__name__
)
PromptComponents
=
Tuple
[
Optional
[
str
],
List
[
int
],
Optional
[
"MultiModalDataDict"
]]
Optional
[
"MultiModalDataDict"
],
Optional
[
Dict
[
str
,
Any
]]]
DecoderPromptComponents
=
Tuple
[
Optional
[
str
],
Optional
[
List
[
int
]],
Optional
[
"MultiModalDataDict"
]]
Optional
[
"MultiModalDataDict"
],
Optional
[
Dict
[
str
,
Any
]]]
class
InputPreprocessor
:
...
...
@@ -71,20 +74,21 @@ class InputPreprocessor:
'''
if
not
self
.
is_encoder_decoder_model
():
logger
.
warning
(
"Using None for decoder start token id because "
"this is not an encoder/decoder model."
)
print_
warning
_once
(
"Using None for decoder start token id because "
"this is not an encoder/decoder model."
)
return
None
if
(
self
.
model_config
is
None
or
self
.
model_config
.
hf_config
is
None
):
logger
.
warning
(
"Using None for decoder start token id because "
"model config is not available."
)
print_
warning
_once
(
"Using None for decoder start token id because "
"model config is not available."
)
return
None
dec_start_token_id
=
getattr
(
self
.
model_config
.
hf_config
,
'decoder_start_token_id'
,
None
)
if
dec_start_token_id
is
None
:
logger
.
warning
(
"Falling back on <BOS> for decoder start token id "
"because decoder start token id is not available."
)
print_warning_once
(
"Falling back on <BOS> for decoder start token "
"id because decoder start token id is not "
"available."
)
dec_start_token_id
=
self
.
get_bos_token_id
()
return
dec_start_token_id
...
...
@@ -207,7 +211,7 @@ class InputPreprocessor:
def
_extract_prompt_components
(
self
,
inputs
:
SingletonPrompt
Inputs
,
prompt
:
SingletonPrompt
,
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
)
->
PromptComponents
:
...
...
@@ -217,7 +221,7 @@ class InputPreprocessor:
Arguments:
* request_id
*
inputs
: single encoder or decoder input prompt
*
prompt
: single encoder or decoder input prompt
* lora_request: this is only valid for decoder prompts
Returns:
...
...
@@ -225,77 +229,89 @@ class InputPreprocessor:
* prompt
* prompt_token_ids
* multi_modal_data
* mm_processor_kwargs (request-level input processor/mapper overrides)
'''
parsed
=
parse_singleton_prompt
(
inputs
)
parsed
=
parse_singleton_prompt
(
prompt
)
if
parsed
[
"type"
]
==
"str"
:
prompt
=
parsed
[
"content"
]
prompt
_text
=
parsed
[
"content"
]
prompt_token_ids
=
self
.
_tokenize_prompt
(
prompt
,
prompt
_text
,
request_id
=
request_id
,
lora_request
=
lora_request
,
)
multi_modal_data
=
None
mm_processor_kwargs
=
None
elif
parsed
[
"type"
]
==
"tokens"
:
prompt
=
None
prompt
_text
=
None
prompt_token_ids
=
parsed
[
"content"
][
"prompt_token_ids"
]
multi_modal_data
=
parsed
[
"content"
].
get
(
"multi_modal_data"
)
mm_processor_kwargs
=
parsed
[
"content"
].
get
(
"mm_processor_kwargs"
)
elif
parsed
[
"type"
]
==
"text"
:
prompt
=
parsed
[
"content"
][
"prompt"
]
prompt
_text
=
parsed
[
"content"
][
"prompt"
]
prompt_token_ids
=
self
.
_tokenize_prompt
(
prompt
,
prompt
_text
,
request_id
=
request_id
,
lora_request
=
lora_request
,
)
multi_modal_data
=
parsed
[
"content"
].
get
(
"multi_modal_data"
)
mm_processor_kwargs
=
parsed
[
"content"
].
get
(
"mm_processor_kwargs"
)
else
:
assert_never
(
parsed
)
return
prompt
,
prompt_token_ids
,
multi_modal_data
return
(
prompt_text
,
prompt_token_ids
,
multi_modal_data
,
mm_processor_kwargs
)
async
def
_extract_prompt_components_async
(
self
,
inputs
:
SingletonPrompt
Inputs
,
prompt
:
SingletonPrompt
,
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
)
->
PromptComponents
:
"""Async version of :meth:`_extract_prompt_components`."""
parsed
=
parse_singleton_prompt
(
inputs
)
parsed
=
parse_singleton_prompt
(
prompt
)
if
parsed
[
"type"
]
==
"str"
:
prompt
=
parsed
[
"content"
]
prompt
_text
=
parsed
[
"content"
]
prompt_token_ids
=
await
self
.
_tokenize_prompt_async
(
prompt
,
prompt
_text
,
request_id
=
request_id
,
lora_request
=
lora_request
,
)
multi_modal_data
=
None
mm_processor_kwargs
=
None
elif
parsed
[
"type"
]
==
"tokens"
:
prompt
=
None
prompt
_text
=
None
prompt_token_ids
=
parsed
[
"content"
][
"prompt_token_ids"
]
multi_modal_data
=
parsed
[
"content"
].
get
(
"multi_modal_data"
)
mm_processor_kwargs
=
parsed
[
"content"
].
get
(
"mm_processor_kwargs"
)
elif
parsed
[
"type"
]
==
"text"
:
prompt
=
parsed
[
"content"
][
"prompt"
]
prompt
_text
=
parsed
[
"content"
][
"prompt"
]
prompt_token_ids
=
await
self
.
_tokenize_prompt_async
(
prompt
,
prompt
_text
,
request_id
=
request_id
,
lora_request
=
lora_request
,
)
multi_modal_data
=
parsed
[
"content"
].
get
(
"multi_modal_data"
)
mm_processor_kwargs
=
parsed
[
"content"
].
get
(
"mm_processor_kwargs"
)
else
:
assert_never
(
parsed
)
return
prompt
,
prompt_token_ids
,
multi_modal_data
return
(
prompt_text
,
prompt_token_ids
,
multi_modal_data
,
mm_processor_kwargs
)
def
_build_enc_dec_llm_inputs
(
self
,
encoder_comps
:
PromptComponents
,
decoder_comps
:
DecoderPromptComponents
,
)
->
EncoderDecoderLLMInputs
:
encoder_prompt
,
encoder_prompt_ids
,
encoder_mm_data
=
encoder_comps
decoder_prompt
,
decoder_prompt_ids
,
decoder_mm_data
=
decoder_comps
mm_processor_kwargs
:
Dict
[
str
,
Any
],
)
->
EncoderDecoderInputs
:
encoder_prompt
,
encoder_prompt_ids
,
encoder_mm_data
,
_
=
encoder_comps
decoder_prompt
,
decoder_prompt_ids
,
decoder_mm_data
,
_
=
decoder_comps
# Reminder: Please update docs/source/serving/compatibility_matrix.rst
# If the feature combo become valid
if
decoder_mm_data
is
not
None
:
raise
ValueError
(
"Multi-modality decoder inputs of encoder-decoder models are "
...
...
@@ -308,10 +324,11 @@ class InputPreprocessor:
decoder_prompt_ids
,
force_bos
=
(
encoder_mm_data
is
None
and
decoder_mm_data
is
None
)))
return
EncoderDecoder
LLM
Inputs
(
return
EncoderDecoderInputs
(
prompt_token_ids
=
decoder_prompt_ids
,
prompt
=
decoder_prompt
,
multi_modal_data
=
decoder_mm_data
,
mm_processor_kwargs
=
mm_processor_kwargs
,
encoder_prompt_token_ids
=
encoder_prompt_ids
,
encoder_prompt
=
encoder_prompt
,
encoder_multi_modal_data
=
encoder_mm_data
,
...
...
@@ -319,13 +336,13 @@ class InputPreprocessor:
def
_process_encoder_decoder_prompt
(
self
,
inputs
:
Prompt
Inputs
,
prompt
:
Prompt
Type
,
request_id
:
str
,
)
->
EncoderDecoder
LLM
Inputs
:
)
->
EncoderDecoderInputs
:
'''
For encoder/decoder models only:
Process an input prompt into an
:class:`EncoderDecoder
LLM
Inputs` instance.
:class:`EncoderDecoderInputs` instance.
There are two types of input prompts:
singleton prompts which carry only the
...
...
@@ -347,58 +364,67 @@ class InputPreprocessor:
Arguments:
*
inputs
: an input prompt
*
prompt
: an input prompt
* request_id
Returns:
* :class:`EncoderDecoder
LLM
Inputs` instance
* :class:`EncoderDecoderInputs` instance
'''
encoder_comps
:
PromptComponents
decoder_comps
:
DecoderPromptComponents
if
is_explicit_encoder_decoder_prompt
(
inputs
):
if
is_explicit_encoder_decoder_prompt
(
prompt
):
encoder_comps
=
self
.
_extract_prompt_components
(
inputs
[
"encoder_prompt"
],
prompt
[
"encoder_prompt"
],
request_id
=
request_id
,
)
if
(
decoder_input
:
=
inputs
[
"decoder_prompt"
])
is
None
:
decoder_comps
=
None
,
None
,
None
if
(
decoder_input
:
=
prompt
[
"decoder_prompt"
])
is
None
:
decoder_comps
=
None
,
None
,
None
,
None
else
:
decoder_comps
=
self
.
_extract_prompt_components
(
decoder_input
,
request_id
=
request_id
,
)
# Handle this carefully in case it was directly initialized by user
mm_processor_kwargs
=
prompt
.
get
(
"mm_processor_kwargs"
,
{})
else
:
encoder_comps
=
self
.
_extract_prompt_components
(
inputs
,
prompt
,
request_id
=
request_id
,
)
decoder_comps
=
None
,
None
,
None
return
self
.
_build_enc_dec_llm_inputs
(
encoder_comps
,
decoder_comps
)
# If there are no decoder components, we assume the
# mm_processor_kwargs are in the encoder prompt
mm_processor_kwargs
=
encoder_comps
[
-
1
]
if
encoder_comps
[
-
1
]
is
not
None
else
{}
decoder_comps
=
None
,
None
,
None
,
None
return
self
.
_build_enc_dec_llm_inputs
(
encoder_comps
,
decoder_comps
,
mm_processor_kwargs
,
)
async
def
_process_encoder_decoder_prompt_async
(
self
,
inputs
:
Prompt
Inputs
,
prompt
:
Prompt
Type
,
request_id
:
str
,
)
->
EncoderDecoder
LLM
Inputs
:
)
->
EncoderDecoderInputs
:
"""Async version of :meth:`_process_encoder_decoder_prompt`."""
encoder_comps
:
PromptComponents
decoder_comps
:
DecoderPromptComponents
if
is_explicit_encoder_decoder_prompt
(
inputs
):
if
is_explicit_encoder_decoder_prompt
(
prompt
):
encoder_task
=
self
.
_extract_prompt_components_async
(
inputs
[
"encoder_prompt"
],
prompt
[
"encoder_prompt"
],
request_id
=
request_id
,
)
if
(
decoder_input
:
=
inputs
[
"decoder_prompt"
])
is
None
:
if
(
decoder_input
:
=
prompt
[
"decoder_prompt"
])
is
None
:
encoder_comps
=
await
encoder_task
decoder_comps
=
None
,
None
,
None
decoder_comps
=
None
,
None
,
None
,
None
else
:
decoder_task
=
self
.
_extract_prompt_components_async
(
decoder_input
,
...
...
@@ -407,55 +433,65 @@ class InputPreprocessor:
encoder_comps
,
decoder_comps
=
await
asyncio
.
gather
(
encoder_task
,
decoder_task
)
mm_processor_kwargs
=
prompt
[
"mm_processor_kwargs"
]
else
:
encoder_comps
=
await
self
.
_extract_prompt_components_async
(
inputs
,
prompt
,
request_id
=
request_id
,
)
decoder_comps
=
None
,
None
,
None
return
self
.
_build_enc_dec_llm_inputs
(
encoder_comps
,
decoder_comps
)
# If there are no decoder components, we assume the
# mm_processor_kwargs are in the encoder prompt
mm_processor_kwargs
=
encoder_comps
[
-
1
]
if
encoder_comps
[
-
1
]
is
not
None
else
{}
decoder_comps
=
None
,
None
,
None
,
None
return
self
.
_build_enc_dec_llm_inputs
(
encoder_comps
,
decoder_comps
,
mm_processor_kwargs
,
)
def
_build_decoder_only_llm_inputs
(
self
,
prompt_comps
:
PromptComponents
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
],
)
->
LLMInputs
:
prompt
,
prompt_token_ids
,
multi_modal_data
=
prompt_comps
)
->
DecoderOnlyInputs
:
(
prompt
,
prompt_token_ids
,
multi_modal_data
,
mm_processor_kwargs
)
=
prompt_comps
prompt_token_ids
=
self
.
_apply_prompt_adapter
(
prompt_token_ids
,
prompt_adapter_request
=
prompt_adapter_request
)
return
LLMInputs
(
prompt_token_ids
=
prompt_token_ids
,
prompt
=
prompt
,
multi_modal_data
=
multi_modal_data
)
return
DecoderOnlyInputs
(
prompt_token_ids
=
prompt_token_ids
,
prompt
=
prompt
,
multi_modal_data
=
multi_modal_data
,
mm_processor_kwargs
=
mm_processor_kwargs
)
def
_process_decoder_only_prompt
(
self
,
inputs
:
SingletonPrompt
Inputs
,
prompt
:
SingletonPrompt
,
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
)
->
LLM
Inputs
:
)
->
DecoderOnly
Inputs
:
'''
For decoder-only models:
Process an input prompt into an :class:`
LLM
Inputs` instance.
Process an input prompt into an :class:`
DecoderOnly
Inputs` instance.
Arguments:
*
inputs
: input prompt
*
prompt
: input prompt
* request_id
* lora_request
* prompt_adapter_request
Returns:
* :class:`
LLM
Inputs` instance
* :class:`
DecoderOnly
Inputs` instance
'''
prompt_comps
=
self
.
_extract_prompt_components
(
inputs
,
prompt
,
request_id
=
request_id
,
lora_request
=
lora_request
,
)
...
...
@@ -467,14 +503,14 @@ class InputPreprocessor:
async
def
_process_decoder_only_prompt_async
(
self
,
inputs
:
SingletonPrompt
Inputs
,
prompt
:
SingletonPrompt
,
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
)
->
LLM
Inputs
:
)
->
DecoderOnly
Inputs
:
"""Async version of :meth:`_process_decoder_only_prompt`."""
prompt_comps
=
await
self
.
_extract_prompt_components_async
(
inputs
,
prompt
,
request_id
=
request_id
,
lora_request
=
lora_request
,
)
...
...
@@ -486,27 +522,27 @@ class InputPreprocessor:
def
preprocess
(
self
,
inputs
:
Prompt
Inputs
,
prompt
:
Prompt
Type
,
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
)
->
Union
[
LLM
Inputs
,
EncoderDecoder
LLM
Inputs
]:
)
->
Union
[
DecoderOnly
Inputs
,
EncoderDecoderInputs
]:
"""Preprocess the input prompt."""
if
self
.
is_encoder_decoder_model
():
# Encoder-decoder model requires special mapping of
# input prompts to encoder & decoder
return
self
.
_process_encoder_decoder_prompt
(
inputs
,
prompt
,
request_id
=
request_id
,
)
if
is_explicit_encoder_decoder_prompt
(
inputs
):
if
is_explicit_encoder_decoder_prompt
(
prompt
):
raise
ValueError
(
"Cannot pass encoder-decoder prompt "
"to decoder-only models"
)
# Decoder-only operation
return
self
.
_process_decoder_only_prompt
(
inputs
,
prompt
,
request_id
=
request_id
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
,
...
...
@@ -514,27 +550,27 @@ class InputPreprocessor:
async
def
preprocess_async
(
self
,
inputs
:
Prompt
Inputs
,
prompt
:
Prompt
Type
,
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
)
->
Union
[
LLM
Inputs
,
EncoderDecoder
LLM
Inputs
]:
)
->
Union
[
DecoderOnly
Inputs
,
EncoderDecoderInputs
]:
"""Async version of :meth:`preprocess`."""
if
self
.
is_encoder_decoder_model
():
# Encoder-decoder model requires special mapping of
# input prompts to encoder & decoder
return
await
self
.
_process_encoder_decoder_prompt_async
(
inputs
,
prompt
,
request_id
=
request_id
,
)
if
is_explicit_encoder_decoder_prompt
(
inputs
):
if
is_explicit_encoder_decoder_prompt
(
prompt
):
raise
ValueError
(
"Cannot pass encoder-decoder prompt "
"to decoder-only models"
)
# Decoder-only operation
return
await
self
.
_process_decoder_only_prompt_async
(
inputs
,
prompt
,
request_id
=
request_id
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
,
...
...
vllm/inputs/registry.py
View file @
6d2051cc
...
...
@@ -9,9 +9,10 @@ from transformers import PretrainedConfig
from
typing_extensions
import
TypeVar
from
vllm.logger
import
init_logger
from
vllm.utils
import
get_allowed_kwarg_only_overrides
from
vllm.utils
import
(
get_allowed_kwarg_only_overrides
,
print_warning_once
,
resolve_mm_processor_kwargs
)
from
.data
import
LLM
Inputs
from
.data
import
DecoderOnly
Inputs
if
TYPE_CHECKING
:
from
vllm.config
import
ModelConfig
...
...
@@ -99,7 +100,7 @@ class _MultiModalCounts(UserDict):
raise
KeyError
(
msg
)
from
exc
InputProcessor
=
Callable
[[
InputContext
,
LLMInputs
],
LLM
Inputs
]
InputProcessor
=
Callable
[[
InputContext
,
DecoderOnlyInputs
],
DecoderOnly
Inputs
]
"""Preprocess the inputs to the model."""
...
...
@@ -133,7 +134,7 @@ class InputRegistry:
# Avoid circular import
from
vllm.sequence
import
SequenceData
dummy_seq_data
=
SequenceData
.
from_token_counts
((
0
,
seq_len
))
dummy_seq_data
=
SequenceData
.
from_
prompt_
token_counts
((
0
,
seq_len
))
dummy_multi_modal_data
=
None
return
dummy_seq_data
,
dummy_multi_modal_data
...
...
@@ -185,16 +186,8 @@ class InputRegistry:
return
wrapper
def
_get_dummy_encoder_data_factory
(
self
,
model_cls
:
Type
[
nn
.
Module
]):
if
model_cls
in
self
.
_dummy_encoder_factories_by_model_type
:
dummy_factory
=
self
.
_dummy_encoder_factories_by_model_type
[
model_cls
]
else
:
logger
.
warning
(
"No dummy encoder data factory registered to %s. "
"Using the dummy data factory for the model instead."
,
model_cls
)
dummy_factory
=
self
.
_get_dummy_data_factory
(
model_cls
)
return
dummy_factory
return
self
.
_dummy_encoder_factories_by_model_type
\
.
get
(
model_cls
,
self
.
_default_dummy_data_factory
)
def
dummy_data_for_profiling
(
self
,
...
...
@@ -235,9 +228,9 @@ class InputRegistry:
num_tokens
=
seq_data
.
prompt_token_ids
if
len
(
num_tokens
)
<
seq_len
:
if
is_encoder_data
:
logger
.
warning
(
"Expected at least
%d
dummy encoder tokens for
profiling,
"
"but found %d tokens instead."
,
seq_len
,
len
(
num_tokens
)
)
print_
warning
_once
(
f
"Expected at least
{
seq_len
}
dummy encoder tokens for "
f
"profiling, but found
{
len
(
num_tokens
)
}
tokens instead."
)
else
:
raise
AssertionError
(
f
"Expected at least
{
seq_len
}
dummy tokens for profiling, "
...
...
@@ -252,8 +245,11 @@ class InputRegistry:
return
seq_data
,
mm_data
def
_default_input_processor
(
self
,
ctx
:
InputContext
,
inputs
:
LLMInputs
)
->
LLMInputs
:
def
_default_input_processor
(
self
,
ctx
:
InputContext
,
inputs
:
DecoderOnlyInputs
,
)
->
DecoderOnlyInputs
:
"""The default input processor is a no-op."""
return
inputs
...
...
@@ -286,7 +282,7 @@ class InputRegistry:
.
get
(
model_cls
,
self
.
_default_input_processor
)
def
process_input
(
self
,
model_config
:
"ModelConfig"
,
inputs
:
LLMInputs
)
->
LLM
Inputs
:
inputs
:
DecoderOnlyInputs
)
->
DecoderOnly
Inputs
:
"""
Apply an input processor to an instance of model inputs.
...
...
@@ -301,8 +297,14 @@ class InputRegistry:
model_cls
,
_
=
get_model_architecture
(
model_config
)
processor
=
self
.
_get_model_input_processor
(
model_cls
)
mm_processor_kwargs
=
get_allowed_kwarg_only_overrides
(
processor
,
overrides
=
model_config
.
mm_processor_kwargs
)
# Handle multimodal processor kwargs with priority:
# Inference kwargs -> Init kwargs -> {}
# If it's empty, it'll fall back to the default kwarg values
mm_processor_kwargs
=
resolve_mm_processor_kwargs
(
model_config
.
mm_processor_kwargs
,
inputs
.
get
(
"mm_processor_kwargs"
),
processor
,
)
return
processor
(
InputContext
(
model_config
),
inputs
,
**
mm_processor_kwargs
)
...
...
vllm/lora/layers.py
View file @
6d2051cc
...
...
@@ -39,6 +39,9 @@ def _get_lora_device(base_layer: nn.Module) -> torch.device:
# unquantizedLinear
if
hasattr
(
base_layer
,
"weight"
):
return
base_layer
.
weight
.
device
# Compressed Tensor
elif
hasattr
(
base_layer
,
"weight_packed"
):
return
base_layer
.
weight_packed
.
device
# GPTQ/AWQ
elif
hasattr
(
base_layer
,
"qweight"
):
return
base_layer
.
qweight
.
device
...
...
vllm/lora/models.py
View file @
6d2051cc
...
...
@@ -23,8 +23,10 @@ from vllm.lora.layers import (BaseLayerWithLoRA,
from
vllm.lora.lora
import
LoRALayerWeights
,
PackedLoRALayerWeights
from
vllm.lora.punica
import
PunicaWrapper
from
vllm.lora.utils
import
(
from_layer
,
from_layer_logits_processor
,
is_regex_target_modules
,
parse_fine_tuned_lora_name
,
replace_submodule
)
from
vllm.model_executor.models.interfaces
import
SupportsLoRA
from
vllm.model_executor.models
import
SupportsLoRA
,
supports_multimodal
from
vllm.model_executor.models.module_mapping
import
MultiModelKeys
from
vllm.model_executor.models.utils
import
PPMissingLayer
from
vllm.utils
import
is_pin_memory_available
...
...
@@ -232,6 +234,8 @@ class LoRAModel(AdapterModel):
# modules.
unexpected_modules
=
[]
target_modules
=
config
[
"target_modules"
]
if
not
isinstance
(
target_modules
,
list
):
target_modules
=
[
target_modules
]
for
module
in
target_modules
:
# Compatible with more modules,
# such as:layers.11.self_attn.k_proj
...
...
@@ -242,8 +246,8 @@ class LoRAModel(AdapterModel):
# expected_lora_modules. It is not reliable. See
# https://github.com/vllm-project/vllm/pull/5909. But there's no
# other better mechanism.
if
unexpected_modules
:
print
(
unexpected_modules
,
"
modules
"
)
if
unexpected_modules
and
not
is_regex_target_modules
(
config
[
"target_modules"
],
expected_lora_
modules
)
:
raise
ValueError
(
f
"While loading
{
lora_dir
}
, expected"
f
" target modules in
{
expected_lora_modules
}
"
...
...
@@ -332,6 +336,12 @@ class LoRAModelManager(AdapterModelManager):
self
.
supported_lora_modules
.
append
(
"rotary_emb"
)
self
.
packed_modules_mapping
=
copy
.
deepcopy
(
self
.
model
.
packed_modules_mapping
)
# Used to indicate whether the model is a multimodal model
self
.
supports_mm
:
bool
=
(
supports_multimodal
(
self
.
model
)
# In case the model only supports LoRA for
# text modules (e.g. ChatGLM)
and
hasattr
(
self
.
model
,
"get_mm_mapping"
))
self
.
packed_modules
:
Dict
[
str
,
List
[
str
]]
=
{}
self
.
modules
:
Dict
[
str
,
"BaseLayerWithLoRA"
]
=
{}
# Dict instead of a Set for compatibility with LRUCache.
...
...
@@ -437,12 +447,22 @@ class LoRAModelManager(AdapterModelManager):
continue
if
not
self
.
_match_target_modules
(
module_name
):
continue
# A temporary approach for multimodal models to support LoRA
# TODO: Remove this restriction
if
self
.
_filter_unsupported_mm_module
(
module_name
):
logger
.
warning
(
"Regarding multimodal models, vLLM currently only supports "
"adding LoRA to language model, %s will be ignored."
,
module_name
,
)
continue
parts
=
module_name
.
split
(
"."
)[
-
1
]
packed_moduled_lst
=
self
.
packed_modules_mapping
.
get
(
parts
,
[])
new_module
=
replace_submodule
(
self
.
model
,
module_name
,
from_layer
(
module
,
self
.
lora_slots
,
self
.
lora_config
,
packed_moduled_lst
,
self
.
model
.
config
))
# LinearScalingRotaryEmbeddingWithLora is used to handle
# long context lora. Register relevant metadata.
if
isinstance
(
new_module
,
LinearScalingRotaryEmbeddingWithLora
):
...
...
@@ -460,6 +480,15 @@ class LoRAModelManager(AdapterModelManager):
module
,
self
.
lora_slots
,
self
.
lora_config
,
self
.
model
.
config
))
# In some models, especially multimodal ones, layers with the same
# name may have different types, such as nn.Linear and
# ReplicatedLinear. The nn.Linear layers cannot be replaced with
# LoRA layers, leading to assertion error. The following check
# aims to prevent this error
if
self
.
supports_mm
and
not
isinstance
(
new_module
,
BaseLayerWithLoRA
):
continue
self
.
register_module
(
module_name
,
new_module
)
self
.
_register_packed_modules
(
module_name
)
# All lora layers share the same punica_wrapper based on reference.
...
...
@@ -478,9 +507,10 @@ class LoRAModelManager(AdapterModelManager):
"""Create zero-initialized LoRAModel for warmup."""
model
=
LoRAModel
(
lora_id
,
rank
,
{},
scaling_factor
)
for
module_name
,
module
in
self
.
model
.
named_modules
():
if
not
self
.
_match_target_modules
(
module_name
)
or
not
isinstance
(
module
,
BaseLayerWithLoRA
)
or
isinstance
(
module
,
LinearScalingRotaryEmbeddingWithLora
):
if
(
not
self
.
_match_target_modules
(
module_name
)
or
not
isinstance
(
module
,
BaseLayerWithLoRA
)
or
isinstance
(
module
,
LinearScalingRotaryEmbeddingWithLora
)
or
self
.
_filter_unsupported_mm_module
(
module_name
)):
continue
parts
=
module_name
.
split
(
"."
)
if
module_name
not
in
self
.
packed_modules
:
...
...
@@ -541,6 +571,19 @@ class LoRAModelManager(AdapterModelManager):
module_name
)
or
target_module
==
module_name
for
target_module
in
self
.
supported_lora_modules
)
def
_filter_unsupported_mm_module
(
self
,
module_name
:
str
)
->
bool
:
"""
Regarding multimodal models, vLLM currently only supports adding LoRA to
language model. LoRA for other modules, such as the vision tower, will
be filtered out.
"""
if
self
.
supports_mm
:
prefix
=
module_name
.
split
(
"."
)[
0
]
module_mapping
:
MultiModelKeys
=
self
.
model
.
get_mm_mapping
()
return
(
prefix
in
module_mapping
.
connector
or
prefix
in
module_mapping
.
tower_model
)
return
False
def
_register_packed_modules
(
self
,
module_full_name
:
str
)
->
None
:
parts
=
module_full_name
.
split
(
"."
)
module_name
=
parts
[
-
1
]
...
...
vllm/lora/utils.py
View file @
6d2051cc
import
os
from
typing
import
List
,
Optional
,
Set
,
Tuple
,
Type
import
re
from
typing
import
List
,
Optional
,
Set
,
Tuple
,
Type
,
Union
import
huggingface_hub
from
huggingface_hub.utils
import
(
EntryNotFoundError
,
HfHubHTTPError
,
...
...
@@ -113,6 +114,38 @@ def parse_fine_tuned_lora_name(name: str) -> Tuple[str, bool]:
raise
ValueError
(
f
"
{
name
}
is unsupported LoRA weight"
)
def
is_regex_target_modules
(
load_modules
:
Union
[
str
,
List
[
str
]],
expected_lora_modules
:
List
[
str
])
->
bool
:
"""
PEFT supports passing `target_modules` in the form of regular expressions,
such as `model.*(q_proj|k_proj|v_proj)$`. This function is mainly used to
determine whether the suffix in the regular expression is present in the
`expected_lora_modules`.
"""
def
is_valid_regex
(
pattern
):
try
:
re
.
compile
(
pattern
)
return
True
except
re
.
error
:
return
False
def
is_subset
(
sub_list
,
full_list
):
return
set
(
sub_list
).
issubset
(
set
(
full_list
))
# Similar to PEFT's processing logic, regex-related operations are only
# executed when the load_modules is a `str`.
if
not
isinstance
(
load_modules
,
str
):
return
False
if
is_valid_regex
(
load_modules
):
match
=
re
.
search
(
r
"\((.*?)\)\$?$"
,
load_modules
)
if
match
:
suffix
=
match
.
group
(
1
).
split
(
"|"
)
return
is_subset
(
suffix
,
expected_lora_modules
)
return
False
def
get_adapter_absolute_path
(
lora_path
:
str
)
->
str
:
"""
Resolves the given lora_path to an absolute local path.
...
...
vllm/model_executor/custom_op.py
View file @
6d2051cc
import
torch.nn
as
nn
import
vllm.envs
as
envs
from
vllm.compilation.levels
import
CompilationLevel
from
vllm.platforms
import
current_platform
from
vllm.utils
import
is_cpu
,
is_hip
,
is_xpu
...
...
@@ -55,7 +56,7 @@ class CustomOp(nn.Module):
# NOTE(woosuk): Here we assume that vLLM was built for only one
# specific backend. Currently, we do not support dynamic dispatching.
if
envs
.
VLLM_T
EST
_COMPILE_
NO_CUSTOM_OPS
:
if
envs
.
VLLM_T
ORCH
_COMPILE_
LEVEL
>=
CompilationLevel
.
INDUCTOR
:
return
self
.
forward_native
if
is_hip
():
...
...
vllm/model_executor/guided_decoding/__init__.py
View file @
6d2051cc
from
typing
import
Optional
,
Union
from
typing
import
Optional
from
vllm.entrypoints.openai.protocol
import
(
ChatCompletionNamedToolChoiceParam
,
ChatCompletionRequest
,
CompletionRequest
)
from
vllm.model_executor.guided_decoding.guided_fields
import
(
GuidedDecodingRequest
)
from
vllm.sampling_params
import
LogitsProcessor
from
vllm.sampling_params
import
GuidedDecodingParams
,
LogitsProcessor
async
def
get_guided_decoding_logits_processor
(
guided_decoding_backend
:
str
,
request
:
Union
[
CompletionRequest
,
ChatCompletionRequest
],
guided_params
:
GuidedDecodingParams
,
tokenizer
)
->
Optional
[
LogitsProcessor
]:
request
=
_adapt_request_for_tool_use
(
request
)
if
guided_decoding_backend
==
'outlines'
:
# CFG grammar not supported by LMFE, so we use outlines instead
if
guided_params
.
backend
==
'outlines'
or
guided_params
.
grammar
:
# NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
from
vllm.model_executor.guided_decoding.outlines_decoding
import
(
# noqa
get_outlines_guided_decoding_logits_processor
)
return
await
get_outlines_guided_decoding_logits_processor
(
request
,
tokenizer
)
if
guided_
decoding_
backend
==
'lm-format-enforcer'
:
guided_params
,
tokenizer
)
if
guided_
params
.
backend
==
'lm-format-enforcer'
:
from
vllm.model_executor.guided_decoding.lm_format_enforcer_decoding
import
(
# noqa
get_lm_format_enforcer_guided_decoding_logits_processor
)
return
await
get
_lm_format_enforcer_guided_decoding_logits_processor
(
request
,
tokenizer
)
get_
local_
lm_format_enforcer_guided_decoding_logits_processor
)
return
get_local
_lm_format_enforcer_guided_decoding_logits_processor
(
guided_params
,
tokenizer
)
raise
ValueError
(
f
"Unknown guided decoding backend '
{
guided_
decoding_
backend
}
'. "
f
"Unknown guided decoding backend '
{
guided_
params
.
backend
}
'. "
"Must be one of 'outlines, 'lm-format-enforcer'"
)
def
get_local_guided_decoding_logits_processor
(
guided_
decoding_backend
:
str
,
guided_option
s
:
GuidedDecoding
Request
,
guided_
param
s
:
GuidedDecoding
Params
,
tokenizer
)
->
Optional
[
LogitsProcessor
]:
# request = _adapt_request_for_tool_use(request)
if
guided_decoding_backend
==
'outlines'
:
# CFG grammar not supported by LMFE, so we use outlines instead
if
guided_params
.
backend
==
'outlines'
or
guided_params
.
grammar
:
# NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
from
vllm.model_executor.guided_decoding.outlines_decoding
import
(
# noqa
get_local_outlines_guided_decoding_logits_processor
)
return
get_local_outlines_guided_decoding_logits_processor
(
guided_
option
s
,
tokenizer
)
if
guided_
decoding_
backend
==
'lm-format-enforcer'
:
guided_
param
s
,
tokenizer
)
if
guided_
params
.
backend
==
'lm-format-enforcer'
:
from
vllm.model_executor.guided_decoding.lm_format_enforcer_decoding
import
(
# noqa
get_local_lm_format_enforcer_guided_decoding_logits_processor
)
return
get_local_lm_format_enforcer_guided_decoding_logits_processor
(
guided_
option
s
,
tokenizer
)
guided_
param
s
,
tokenizer
)
raise
ValueError
(
f
"Unknown guided decoding backend '
{
guided_
decoding_
backend
}
'. "
f
"Unknown guided decoding backend '
{
guided_
params
.
backend
}
'. "
"Must be one of 'outlines, 'lm-format-enforcer'"
)
def
_adapt_request_for_tool_use
(
request
:
Union
[
CompletionRequest
,
ChatCompletionRequest
]):
# the legacy completion API does not support tool use
if
type
(
request
)
is
CompletionRequest
:
return
request
# user has chosen to not use any tool,
# OR is allowing the model to choose a tool.
if
request
.
tool_choice
==
"none"
or
request
.
tool_choice
==
"auto"
:
return
request
# user has chosen to use a named tool
if
type
(
request
.
tool_choice
)
is
ChatCompletionNamedToolChoiceParam
:
tool_name
=
request
.
tool_choice
.
function
.
name
tools
=
{
tool
.
function
.
name
:
tool
.
function
for
tool
in
request
.
tools
}
if
tool_name
not
in
tools
:
raise
ValueError
(
f
"Tool '
{
tool_name
}
' has not been passed in `tools`."
)
tool
=
tools
[
tool_name
]
request
.
guided_json
=
tool
.
parameters
return
request
vllm/model_executor/guided_decoding/guided_fields.py
View file @
6d2051cc
...
...
@@ -4,6 +4,7 @@ from typing import Dict, List, Optional, TypedDict, Union
from
pydantic
import
BaseModel
# These classes are deprecated, see SamplingParams
class
LLMGuidedOptions
(
TypedDict
,
total
=
False
):
guided_json
:
Union
[
Dict
,
BaseModel
,
str
]
guided_regex
:
str
...
...
vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py
View file @
6d2051cc
...
...
@@ -7,66 +7,13 @@ from lmformatenforcer import (CharacterLevelParser, JsonSchemaParser,
TokenEnforcerTokenizerData
,
UnionParser
)
from
lmformatenforcer.integrations.vllm
import
(
build_vllm_logits_processor
,
build_vllm_token_enforcer_tokenizer_data
)
from
pydantic
import
BaseModel
from
transformers
import
PreTrainedTokenizerBase
from
vllm.entrypoints.openai.protocol
import
(
ChatCompletionRequest
,
CompletionRequest
)
from
vllm.model_executor.guided_decoding.guided_fields
import
(
GuidedDecodingRequest
)
from
vllm.sampling_params
import
LogitsProcessor
async
def
get_lm_format_enforcer_guided_decoding_logits_processor
(
request
:
Union
[
CompletionRequest
,
ChatCompletionRequest
],
tokenizer
)
->
Optional
[
LogitsProcessor
]:
"""
Given an OpenAI-compatible request, check for guided decoding parameters
and get the necessary logits processor for the given guide.
We cache logit processors by (guide, tokenizer), and on cache hit
we make a shallow copy to reuse the same underlying FSM.
"""
tokenizer_data
=
_cached_build_vllm_token_enforcer_tokenizer_data
(
tokenizer
)
character_level_parser
:
CharacterLevelParser
if
request
.
guided_json
:
schema
=
_normalize_json_schema_object
(
request
.
guided_json
)
character_level_parser
=
JsonSchemaParser
(
schema
)
elif
request
.
guided_choice
:
character_level_parser
=
UnionParser
(
[
StringParser
(
choice
)
for
choice
in
request
.
guided_choice
])
elif
request
.
guided_regex
:
character_level_parser
=
RegexParser
(
request
.
guided_regex
)
elif
request
.
guided_grammar
:
# CFG grammar not supported by LMFE, revert to outlines
# NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
from
vllm.model_executor.guided_decoding.outlines_decoding
import
(
get_outlines_guided_decoding_logits_processor
)
return
await
get_outlines_guided_decoding_logits_processor
(
request
,
tokenizer
)
elif
(
request
.
response_format
is
not
None
and
request
.
response_format
.
type
==
"json_object"
):
character_level_parser
=
JsonSchemaParser
(
None
)
# None means any json object
elif
(
request
.
response_format
is
not
None
and
request
.
response_format
.
type
==
"json_schema"
and
request
.
response_format
.
json_schema
is
not
None
and
request
.
response_format
.
json_schema
.
json_schema
is
not
None
):
schema
=
_normalize_json_schema_object
(
request
.
response_format
.
json_schema
.
json_schema
)
character_level_parser
=
JsonSchemaParser
(
schema
)
else
:
return
None
logits_processor
=
build_vllm_logits_processor
(
tokenizer_data
,
character_level_parser
)
return
logits_processor
from
vllm.sampling_params
import
GuidedDecodingParams
,
LogitsProcessor
def
get_local_lm_format_enforcer_guided_decoding_logits_processor
(
guided_
option
s
:
GuidedDecoding
Request
,
guided_
param
s
:
GuidedDecoding
Params
,
tokenizer
)
->
Optional
[
LogitsProcessor
]:
"""
Given an OpenAI-compatible request, check for guided decoding parameters
...
...
@@ -78,23 +25,20 @@ def get_local_lm_format_enforcer_guided_decoding_logits_processor(
tokenizer_data
=
_cached_build_vllm_token_enforcer_tokenizer_data
(
tokenizer
)
character_level_parser
:
CharacterLevelParser
if
guided_
options
.
guided_
json
:
schema
=
_normalize_json_schema_object
(
guided_
options
.
guided_
json
)
character_level_parser
=
JsonSchemaParser
(
schema
)
elif
guided_
options
.
guided_
choice
:
if
guided_
params
.
json
:
schema
_dict
=
_normalize_json_schema_object
(
guided_
params
.
json
)
character_level_parser
=
JsonSchemaParser
(
schema
_dict
)
elif
guided_
params
.
choice
:
character_level_parser
=
UnionParser
(
[
StringParser
(
choice
)
for
choice
in
guided_options
.
guided_choice
])
elif
guided_options
.
guided_regex
:
character_level_parser
=
RegexParser
(
guided_options
.
guided_regex
)
elif
guided_options
.
guided_grammar
:
# CFG grammar not supported by LMFE, revert to outlines
# NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
from
vllm.model_executor.guided_decoding.outlines_decoding
import
(
get_local_outlines_guided_decoding_logits_processor
)
return
get_local_outlines_guided_decoding_logits_processor
(
guided_options
,
tokenizer
)
elif
guided_options
.
guided_json_object
:
[
StringParser
(
choice
)
for
choice
in
guided_params
.
choice
])
elif
guided_params
.
regex
:
character_level_parser
=
RegexParser
(
guided_params
.
regex
)
elif
guided_params
.
grammar
:
# CFG grammar not supported by LMFE
raise
ValueError
(
"Cannot construct a guided decoding logits processor"
" using the grammar option with the"
" lm_format_enforcer backend."
)
elif
guided_params
.
json_object
:
# None means any json object
character_level_parser
=
JsonSchemaParser
(
None
)
else
:
...
...
@@ -105,13 +49,11 @@ def get_local_lm_format_enforcer_guided_decoding_logits_processor(
return
logits_processor
def
_normalize_json_schema_object
(
schema
:
Union
[
str
,
dict
,
BaseModel
])
->
dict
:
def
_normalize_json_schema_object
(
schema
:
Union
[
str
,
dict
])
->
dict
:
if
isinstance
(
schema
,
str
):
return
json_loads
(
schema
)
if
isinstance
(
schema
,
dict
):
return
schema
if
isinstance
(
schema
,
BaseModel
):
return
schema
.
model_json_schema
()
raise
AssertionError
(
f
"Unsupported schema type
{
schema
}
"
)
...
...
vllm/model_executor/guided_decoding/outlines_decoding.py
View file @
6d2051cc
...
...
@@ -5,16 +5,11 @@ from json import dumps as json_dumps
from
re
import
escape
as
regex_escape
from
typing
import
Tuple
,
Union
from
pydantic
import
BaseModel
from
transformers
import
PreTrainedTokenizerBase
from
vllm.entrypoints.openai.protocol
import
(
ChatCompletionNamedToolChoiceParam
,
ChatCompletionRequest
,
CompletionRequest
)
from
vllm.model_executor.guided_decoding.guided_fields
import
(
GuidedDecodingRequest
)
from
vllm.model_executor.guided_decoding.outlines_logits_processors
import
(
CFGLogitsProcessor
,
JSONLogitsProcessor
,
RegexLogitsProcessor
)
from
vllm.sampling_params
import
GuidedDecodingParams
class
GuidedDecodingMode
(
Enum
):
...
...
@@ -55,8 +50,7 @@ global_thread_pool = None # used for generating logits processor fsm
async
def
get_outlines_guided_decoding_logits_processor
(
request
:
Union
[
CompletionRequest
,
ChatCompletionRequest
],
tokenizer
:
PreTrainedTokenizerBase
guided_params
:
GuidedDecodingParams
,
tokenizer
:
PreTrainedTokenizerBase
)
->
Union
[
JSONLogitsProcessor
,
RegexLogitsProcessor
,
CFGLogitsProcessor
,
None
]:
"""
...
...
@@ -66,7 +60,7 @@ async def get_outlines_guided_decoding_logits_processor(
we make a shallow copy to reuse the same underlying FSM.
"""
global
global_thread_pool
guide
,
mode
=
_get_guide_and_mode
(
request
)
guide
,
mode
=
_get_guide_and_mode
(
guided_params
)
if
not
guide
or
not
mode
:
return
None
...
...
@@ -77,11 +71,11 @@ async def get_outlines_guided_decoding_logits_processor(
return
await
loop
.
run_in_executor
(
global_thread_pool
,
_get_logits_processor
,
guide
,
tokenizer
,
mode
,
request
.
guided_whitespace_pattern
)
mode
,
guided_
params
.
whitespace_pattern
)
def
get_local_outlines_guided_decoding_logits_processor
(
guided_
option
s
:
GuidedDecoding
Request
,
tokenizer
:
PreTrainedTokenizerBase
guided_
param
s
:
GuidedDecoding
Params
,
tokenizer
:
PreTrainedTokenizerBase
)
->
Union
[
JSONLogitsProcessor
,
RegexLogitsProcessor
,
CFGLogitsProcessor
,
None
]:
"""
...
...
@@ -90,65 +84,37 @@ def get_local_outlines_guided_decoding_logits_processor(
We cache logit processors by (guide, tokenizer), and on cache hit
we make a shallow copy to reuse the same underlying FSM.
"""
guide
,
mode
=
_get_guide_and_mode
(
guided_
option
s
)
guide
,
mode
=
_get_guide_and_mode
(
guided_
param
s
)
if
not
guide
or
not
mode
:
return
None
return
_get_logits_processor
(
guide
,
tokenizer
,
mode
,
guided_
options
.
guided_
whitespace_pattern
)
guided_
params
.
whitespace_pattern
)
def
_get_guide_and_mode
(
request
:
Union
[
CompletionRequest
,
ChatCompletionRequest
,
GuidedDecodingRequest
]
guided_params
:
GuidedDecodingParams
)
->
Union
[
Tuple
[
str
,
GuidedDecodingMode
],
Tuple
[
None
,
None
]]:
# if the request is a chat completion request, AND the tool choice is a
# named tool choice, do guided decoding
# using that tool as the JSON schema
if
isinstance
(
request
,
ChatCompletionRequest
)
and
isinstance
(
request
.
tool_choice
,
ChatCompletionNamedToolChoiceParam
):
# Guided generation for tools/functions parameters
if
request
.
tool_choice
.
type
==
"function"
:
for
tool
in
request
.
tools
:
if
(
tool
.
type
==
"function"
and
tool
.
function
.
name
==
request
.
tool_choice
.
function
.
name
):
json
=
json_dumps
(
tool
.
function
.
parameters
,
sort_keys
=
True
)
return
json
,
GuidedDecodingMode
.
JSON
return
None
,
None
elif
request
.
guided_json
:
if
isinstance
(
request
.
guided_json
,
dict
):
if
guided_params
.
json
:
if
isinstance
(
guided_params
.
json
,
dict
):
# turn dict into hashable string
json
=
json_dumps
(
request
.
guided_json
)
elif
isinstance
(
request
.
guided_json
,
BaseModel
):
# use pydantic signature so that different model classes
# with the same fields will get hashed the same
json
=
str
(
request
.
guided_json
.
__signature__
)
json
=
json_dumps
(
guided_params
.
json
)
else
:
json
=
request
.
guided_json
json
=
guided_
params
.
json
return
json
,
GuidedDecodingMode
.
JSON
elif
request
.
guided_regex
:
return
request
.
guided_regex
,
GuidedDecodingMode
.
REGEX
elif
request
.
guided_choice
:
elif
guided_
params
.
regex
:
return
guided_
params
.
regex
,
GuidedDecodingMode
.
REGEX
elif
guided_
params
.
choice
:
# choice just uses regex
choices
=
[
regex_escape
(
str
(
choice
))
for
choice
in
request
.
guided_choice
regex_escape
(
str
(
choice
))
for
choice
in
guided_
params
.
choice
]
choices_regex
=
"("
+
"|"
.
join
(
choices
)
+
")"
return
choices_regex
,
GuidedDecodingMode
.
CHOICE
elif
request
.
guided_grammar
:
return
request
.
guided_grammar
,
GuidedDecodingMode
.
GRAMMAR
elif
(
not
isinstance
(
request
,
GuidedDecodingRequest
)
and
request
.
response_format
is
not
None
and
request
.
response_format
.
type
==
"json_object"
):
elif
guided_params
.
grammar
:
return
guided_params
.
grammar
,
GuidedDecodingMode
.
GRAMMAR
elif
guided_params
.
json_object
:
return
JSON_GRAMMAR
,
GuidedDecodingMode
.
GRAMMAR
elif
(
not
isinstance
(
request
,
GuidedDecodingRequest
)
and
request
.
response_format
is
not
None
and
request
.
response_format
.
type
==
"json_schema"
and
request
.
response_format
.
json_schema
is
not
None
and
request
.
response_format
.
json_schema
.
json_schema
is
not
None
):
json
=
json_dumps
(
request
.
response_format
.
json_schema
.
json_schema
)
return
json
,
GuidedDecodingMode
.
JSON
else
:
return
None
,
None
...
...
vllm/model_executor/layers/activation.py
View file @
6d2051cc
...
...
@@ -14,6 +14,33 @@ from vllm.model_executor.utils import set_weight_attrs
import
vllm.envs
as
envs
class
FatreluAndMul
(
CustomOp
):
"""An activation function for FATReLU.
The function computes x -> FATReLU(x[:d]) * x[d:] where
d = x.shape[-1] // 2.
This is used in openbmb/MiniCPM-S-1B-sft.
Shapes:
x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
return: (num_tokens, d) or (batch_size, seq_len, d)
"""
def
__init__
(
self
,
threshold
:
float
=
0.
):
super
().
__init__
()
self
.
threshold
=
threshold
def
forward_native
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
d
=
x
.
shape
[
-
1
]
//
2
x1
=
x
[...,
:
d
]
x2
=
x
[...,
d
:]
x1
=
F
.
threshold
(
x1
,
self
.
threshold
,
0.0
)
return
x1
*
x2
def
forward_cuda
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
forward_native
(
x
)
class
SiluAndMul
(
CustomOp
):
"""An activation function for SwiGLU.
...
...
vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_L40S.json
0 → 100644
View file @
6d2051cc
{
"1"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_ctas"
:
1
,
"num_stages"
:
2
},
"2"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
32
,
"GROUP_SIZE_M"
:
2
,
"num_warps"
:
4
,
"num_ctas"
:
1
,
"num_stages"
:
7
},
"4"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
32
,
"GROUP_SIZE_M"
:
128
,
"num_warps"
:
2
,
"num_ctas"
:
1
,
"num_stages"
:
4
},
"8"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
2
,
"num_ctas"
:
1
,
"num_stages"
:
1
},
"16"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
2
,
"num_ctas"
:
1
,
"num_stages"
:
1
},
"32"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
2
,
"num_warps"
:
4
,
"num_ctas"
:
1
,
"num_stages"
:
2
},
"48"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
2
,
"num_warps"
:
4
,
"num_ctas"
:
1
,
"num_stages"
:
2
},
"64"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
2
,
"num_warps"
:
4
,
"num_ctas"
:
1
,
"num_stages"
:
2
},
"96"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
8
,
"num_warps"
:
8
,
"num_ctas"
:
1
,
"num_stages"
:
2
},
"128"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
8
,
"num_warps"
:
8
,
"num_ctas"
:
1
,
"num_stages"
:
2
},
"192"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
8
,
"num_warps"
:
8
,
"num_ctas"
:
1
,
"num_stages"
:
2
},
"256"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
16
,
"num_ctas"
:
1
,
"num_stages"
:
2
},
"512"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
128
,
"num_warps"
:
2
,
"num_ctas"
:
1
,
"num_stages"
:
8
},
"1024"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
4
,
"num_warps"
:
8
,
"num_ctas"
:
1
,
"num_stages"
:
3
},
"2048"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
16
,
"num_ctas"
:
1
,
"num_stages"
:
2
},
"3072"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
16
,
"num_ctas"
:
1
,
"num_stages"
:
2
},
"4096"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
32
,
"GROUP_SIZE_M"
:
64
,
"num_warps"
:
8
,
"num_ctas"
:
1
,
"num_stages"
:
2
},
"6144"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
32
,
"GROUP_SIZE_M"
:
64
,
"num_warps"
:
8
,
"num_ctas"
:
1
,
"num_stages"
:
2
},
"8192"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
16
,
"num_ctas"
:
1
,
"num_stages"
:
2
}
}
\ No newline at end of file
vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
View file @
6d2051cc
...
...
@@ -10,17 +10,27 @@ from vllm.model_executor.layers.fused_moe.fused_moe import (
from
vllm.scalar_type
import
scalar_types
def
get_scalar_type
(
num_bits
:
int
,
has_zp
:
bool
):
if
has_zp
:
assert
num_bits
==
4
return
scalar_types
.
uint4
else
:
return
scalar_types
.
uint4b8
if
num_bits
==
4
else
scalar_types
.
uint8b128
def
single_marlin_moe
(
hidden_states
:
torch
.
Tensor
,
w
:
torch
.
Tensor
,
scales
:
torch
.
Tensor
,
gating_output
:
torch
.
Tensor
,
g_idx
:
torch
.
Tensor
,
perm
:
torch
.
Tensor
,
topk
:
int
,
renormalize
:
bool
,
g_idx
:
Optional
[
torch
.
Tensor
]
=
None
,
sort_indices
:
Optional
[
torch
.
Tensor
]
=
None
,
w_zeros
:
Optional
[
torch
.
Tensor
]
=
None
,
override_config
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
num_bits
:
int
=
8
,
is_k_full
:
bool
=
True
,
)
->
torch
.
Tensor
:
"""
This function computes the multiplication of hidden_states with expert
...
...
@@ -33,10 +43,12 @@ def single_marlin_moe(
- scales (torch.Tensor): The quantization scales.
- gating_output (torch.Tensor): The output of the gating operation
(before softmax).
- g_idx (torch.Tensor): The act_order indices.
- perm (torch.Tensor): The act_order input permutation.
- g_idx (Optional[torch.Tensor]): Optional act_order indices.
- sort_indices (Optional[torch.Tensor]): Optional act_order input
permutation.
- topk (int): The number of top-k experts to select.
- renormalize (bool): If True, renormalize the top-k weights to sum to 1.
- w_zeros (Optional[torch.Tensor]): Optional zero points to be used for w.
- override_config (Optional[Dict[str, Any]]): Optional override
for the kernel configuration.
- num_bits (bool): The number of bits in expert weights quantization.
...
...
@@ -78,16 +90,34 @@ def single_marlin_moe(
max_workspace_size
=
(
N
//
64
)
*
16
workspace
=
torch
.
zeros
(
max_workspace_size
,
dtype
=
torch
.
int
,
device
=
"cuda"
,
device
=
hidden_states
.
device
,
requires_grad
=
False
)
scalar_type
=
(
scalar_types
.
uint4b8
if
num_bits
==
4
else
scalar_types
.
uint8b128
)
has_zero_point
=
w_zeros
is
not
None
if
w_zeros
is
None
:
w_zeros
=
torch
.
empty
((
0
,
0
),
dtype
=
hidden_states
.
dtype
,
device
=
hidden_states
.
device
,
requires_grad
=
False
)
if
g_idx
is
None
:
g_idx
=
torch
.
empty
((
0
,
0
),
dtype
=
torch
.
int32
,
device
=
hidden_states
.
device
,
requires_grad
=
False
)
if
sort_indices
is
None
:
sort_indices
=
torch
.
empty
((
0
),
dtype
=
torch
.
int32
,
device
=
hidden_states
.
device
,
requires_grad
=
False
)
scalar_type
=
get_scalar_type
(
num_bits
,
has_zero_point
)
intermediate_cache
=
torch
.
ops
.
_moe_C
.
marlin_gemm_moe
(
hidden_states
,
w
,
sorted_token_ids
,
topk_weights
,
topk_ids
,
scales
,
g_idx
,
perm
,
workspace
,
scalar_type
,
M
,
N
,
K
,
True
,
E
,
topk
,
block_size_m
,
True
,
False
)
w_zeros
,
g_idx
,
sort_indices
,
workspace
,
scalar_type
,
M
,
N
,
K
,
is_k_full
,
E
,
topk
,
block_size_m
,
True
,
False
)
return
torch
.
sum
(
intermediate_cache
.
view
(
*
intermediate_cache
.
shape
),
dim
=
1
)
...
...
@@ -96,17 +126,20 @@ def fused_marlin_moe(
hidden_states
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
w1_scale
:
torch
.
Tensor
,
w2_scale
:
torch
.
Tensor
,
gating_output
:
torch
.
Tensor
,
g_idx1
:
torch
.
Tensor
,
g_idx2
:
torch
.
Tensor
,
perm1
:
torch
.
Tensor
,
perm2
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
g_idx1
:
Optional
[
torch
.
Tensor
]
=
None
,
g_idx2
:
Optional
[
torch
.
Tensor
]
=
None
,
sort_indices1
:
Optional
[
torch
.
Tensor
]
=
None
,
sort_indices2
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_zeros
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_zeros
:
Optional
[
torch
.
Tensor
]
=
None
,
override_config
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
num_bits
:
int
=
8
,
is_k_full
:
bool
=
True
,
)
->
torch
.
Tensor
:
"""
This function computes a Mixture of Experts (MoE) layer using two sets of
...
...
@@ -116,21 +149,22 @@ def fused_marlin_moe(
- hidden_states (torch.Tensor): The input tensor to the MoE layer.
- w1 (torch.Tensor): The first set of expert weights.
- w2 (torch.Tensor): The second set of expert weights.
- w1_scale (torch.Tensor): Scale to be used for w1.
- w2_scale (torch.Tensor): Scale to be used for w2.
- gating_output (torch.Tensor): The output of the gating operation
(before softmax).
- g_idx1 (torch.Tensor): The first set of act_order indices.
- g_idx2 (torch.Tensor): The second set of act_order indices.
- perm1 (torch.Tensor): The first act_order input permutation.
- perm2 (torch.Tensor): The second act_order input permutation.
- g_idx1 (Optional[torch.Tensor]): The first set of act_order indices.
- g_idx2 (Optional[torch.Tensor]): The second set of act_order indices.
- sort_indices1 (Optional[torch.Tensor]): The first act_order input
permutation.
- sort_indices2 (Optional[torch.Tensor]): The second act_order input
permutation.
- topk_weights (torch.Tensor): Top-k weights.
- topk_ids (torch.Tensor): Indices of topk-k elements.
- renormalize (bool): If True, renormalize the top-k weights to sum to 1.
- override_config (Optional[Dict[str, Any]]): Optional override
for the kernel configuration.
- w1_scale (Optional[torch.Tensor]): Optional scale to be used for
w1.
- w2_scale (Optional[torch.Tensor]): Optional scale to be used for
w2.
- w1_zeros (Optional[torch.Tensor]): Optional zero points to be used for w1.
- w2_zeros (Optional[torch.Tensor]): Optional zero points to be used for w2.
- num_bits (bool): The number of bits in expert weights quantization.
Returns:
...
...
@@ -150,6 +184,20 @@ def fused_marlin_moe(
assert
hidden_states
.
dtype
==
torch
.
float16
assert
num_bits
in
[
4
,
8
]
has_no_act_order
=
(
g_idx1
is
None
and
g_idx2
is
None
and
sort_indices1
is
None
and
sort_indices2
is
None
)
has_all_act_order
=
(
g_idx1
is
not
None
and
g_idx2
is
not
None
and
sort_indices1
is
not
None
and
sort_indices2
is
not
None
)
assert
has_no_act_order
or
has_all_act_order
,
(
"g_idx and sorted_indices "
"must be all not None or must be all None"
)
has_no_zp
=
w1_zeros
is
None
and
w2_zeros
is
None
has_all_zp
=
w1_zeros
is
not
None
and
w2_zeros
is
not
None
assert
has_no_zp
or
has_all_zp
,
(
"zero points must be both not None or "
"must be both None"
)
M
,
K
=
hidden_states
.
shape
E
=
w1
.
shape
[
0
]
N
=
w2
.
shape
[
1
]
*
16
...
...
@@ -170,14 +218,42 @@ def fused_marlin_moe(
sorted_token_ids
,
_
,
_
=
moe_align_block_size
(
topk_ids
,
block_size_m
,
E
)
max_workspace_size
=
((
M
+
255
)
//
256
)
*
(
max
(
2
*
N
,
K
)
//
64
)
*
16
max_workspace_size
=
(
max
(
2
*
N
,
K
)
//
64
)
*
16
workspace
=
torch
.
zeros
(
max_workspace_size
,
dtype
=
torch
.
int
,
device
=
"cuda"
,
requires_grad
=
False
)
scalar_type
=
(
scalar_types
.
uint4b8
if
num_bits
==
4
else
scalar_types
.
uint8b128
)
if
has_no_zp
:
w1_zeros
=
torch
.
empty
((
0
,
0
),
dtype
=
hidden_states
.
dtype
,
device
=
hidden_states
.
device
,
requires_grad
=
False
)
w2_zeros
=
torch
.
empty
((
0
,
0
),
dtype
=
hidden_states
.
dtype
,
device
=
hidden_states
.
device
,
requires_grad
=
False
)
if
has_no_act_order
:
g_idx1
=
torch
.
empty
((
0
,
0
),
dtype
=
torch
.
int32
,
device
=
hidden_states
.
device
,
requires_grad
=
False
)
g_idx2
=
torch
.
empty
((
0
,
0
),
dtype
=
torch
.
int32
,
device
=
hidden_states
.
device
,
requires_grad
=
False
)
sort_indices1
=
torch
.
empty
((
0
),
dtype
=
torch
.
int32
,
device
=
hidden_states
.
device
,
requires_grad
=
False
)
sort_indices2
=
torch
.
empty
((
0
,
0
),
dtype
=
torch
.
int32
,
device
=
hidden_states
.
device
,
requires_grad
=
False
)
scalar_type1
=
get_scalar_type
(
num_bits
,
has_all_zp
)
scalar_type2
=
get_scalar_type
(
num_bits
,
has_all_zp
)
intermediate_cache2
=
torch
.
empty
(
(
M
*
topk_ids
.
shape
[
1
],
N
),
...
...
@@ -192,14 +268,15 @@ def fused_marlin_moe(
topk_weights
,
topk_ids
,
w1_scale
,
w1_zeros
,
g_idx1
,
perm
1
,
sort_indices
1
,
workspace
,
scalar_type
,
scalar_type
1
,
M
,
2
*
N
,
K
,
True
,
is_k_full
,
E
,
topk
,
block_size_m
,
...
...
@@ -216,14 +293,15 @@ def fused_marlin_moe(
topk_weights
,
topk_ids
,
w2_scale
,
w2_zeros
,
g_idx2
,
perm
2
,
sort_indices
2
,
workspace
,
scalar_type
,
scalar_type
2
,
M
,
K
,
N
,
True
,
is_k_full
,
E
,
topk
,
block_size_m
,
...
...
vllm/model_executor/layers/fused_moe/fused_moe.py
View file @
6d2051cc
...
...
@@ -320,6 +320,9 @@ def get_moe_configs(E: int, N: int,
# If no optimized configuration is available, we will use the default
# configuration
logger
.
warning
(
(
"Using default MoE config. Performance might be sub-optimal! "
"Config file not found at %s"
),
config_file_path
)
return
None
...
...
vllm/model_executor/layers/layernorm.py
View file @
6d2051cc
...
...
@@ -19,10 +19,16 @@ class RMSNorm(CustomOp):
self
,
hidden_size
:
int
,
eps
:
float
=
1e-6
,
var_hidden_size
:
Optional
[
int
]
=
None
,
)
->
None
:
super
().
__init__
()
self
.
weight
=
nn
.
Parameter
(
torch
.
ones
(
hidden_size
))
self
.
hidden_size
=
hidden_size
self
.
variance_epsilon
=
eps
self
.
variance_size_override
=
(
None
if
var_hidden_size
==
hidden_size
else
var_hidden_size
)
self
.
weight
=
nn
.
Parameter
(
torch
.
ones
(
hidden_size
))
def
forward_native
(
self
,
...
...
@@ -36,7 +42,23 @@ class RMSNorm(CustomOp):
x
=
x
+
residual
.
to
(
torch
.
float32
)
residual
=
x
.
to
(
orig_dtype
)
variance
=
x
.
pow
(
2
).
mean
(
dim
=-
1
,
keepdim
=
True
)
hidden_size
=
x
.
shape
[
-
1
]
if
hidden_size
!=
self
.
hidden_size
:
raise
ValueError
(
"Expected hidden_size to be "
f
"
{
self
.
hidden_size
}
, but found:
{
hidden_size
}
"
)
if
self
.
variance_size_override
is
None
:
x_var
=
x
else
:
if
hidden_size
<
self
.
variance_size_override
:
raise
ValueError
(
"Expected hidden_size to be at least "
f
"
{
self
.
variance_size_override
}
, but found:
{
hidden_size
}
"
)
x_var
=
x
[:,
:,
:
self
.
variance_size_override
]
variance
=
x_var
.
pow
(
2
).
mean
(
dim
=-
1
,
keepdim
=
True
)
x
=
x
*
torch
.
rsqrt
(
variance
+
self
.
variance_epsilon
)
x
=
x
.
to
(
orig_dtype
)
*
self
.
weight
if
residual
is
None
:
...
...
@@ -49,6 +71,9 @@ class RMSNorm(CustomOp):
x
:
torch
.
Tensor
,
residual
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]]:
if
self
.
variance_size_override
is
not
None
:
return
self
.
forward_native
(
x
,
residual
)
from
vllm
import
_custom_ops
as
ops
if
residual
is
not
None
:
...
...
@@ -89,6 +114,9 @@ class RMSNorm(CustomOp):
x
:
torch
.
Tensor
,
residual
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]]:
if
self
.
variance_size_override
is
not
None
:
return
self
.
forward_native
(
x
,
residual
)
from
vllm._ipex_ops
import
ipex_ops
as
ops
if
residual
is
not
None
:
...
...
vllm/model_executor/layers/linear.py
View file @
6d2051cc
...
...
@@ -30,7 +30,7 @@ WEIGHT_LOADER_V2_SUPPORTED = [
"AWQLinearMethod"
,
"GPTQMarlinLinearMethod"
,
"Fp8LinearMethod"
,
"MarlinLinearMethod"
,
"QQQLinearMethod"
,
"GPTQMarlin24LinearMethod"
,
"TPUInt8LinearMethod"
,
"GPTQLinearMethod"
,
"FBGEMMFp8LinearMethod"
,
"ModelOptFp8LinearMethod"
"ModelOptFp8LinearMethod"
,
"IPEXAWQLinearMethod"
]
...
...
@@ -355,8 +355,12 @@ class ColumnParallelLinear(LinearBase):
if
is_gguf_weight
and
isinstance
(
param
,
UninitializedParameter
):
param
.
materialize
(
loaded_weight
.
shape
,
dtype
=
loaded_weight
.
dtype
)
use_bitsandbytes_4bit
=
getattr
(
param
,
"use_bitsandbytes_4bit"
,
False
)
param_data
=
param
.
data
if
output_dim
is
not
None
:
# bitsandbytes loads the weights of the specific portion
# no need to narrow here
if
output_dim
is
not
None
and
not
use_bitsandbytes_4bit
:
shard_size
=
param_data
.
shape
[
output_dim
]
start_idx
=
tp_rank
*
shard_size
loaded_weight
=
loaded_weight
.
narrow
(
output_dim
,
start_idx
,
...
...
@@ -459,17 +463,23 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
param
.
shard_weight_type
[
loaded_shard_id
]
=
loaded_weight
.
item
()
return
if
is_gguf_weight
and
isinstance
(
param
,
UninitializedParameter
):
from
gguf.constants
import
GGML_QUANT_SIZES
if
is_gguf_weight
:
tp_size
=
get_tensor_model_parallel_world_size
()
tp_rank
=
get_tensor_model_parallel_rank
()
ori_shape
=
param
.
tensor_shape
weight_types
=
self
.
qweight_type
.
shard_weight_type
.
values
()
row_size
=
[]
for
weight_type
in
weight_types
:
block_size
,
type_size
=
GGML_QUANT_SIZES
[
weight_type
]
row_size
.
append
(
ori_shape
[
1
]
//
block_size
*
type_size
)
q_shape
=
(
ori_shape
[
0
],
max
(
row_size
))
param
.
materialize
(
q_shape
,
dtype
=
loaded_weight
.
dtype
)
output_dim
=
getattr
(
param
,
"output_dim"
,
None
)
shard_size
=
loaded_weight
.
size
(
output_dim
)
//
tp_size
start_idx
=
tp_rank
*
shard_size
loaded_weight
=
loaded_weight
.
narrow
(
output_dim
,
start_idx
,
shard_size
)
param
.
shard_id
.
append
(
loaded_shard_id
)
param
.
shard_id_map
[
loaded_shard_id
]
=
len
(
param
.
data_container
)
param
.
data_container
.
append
(
loaded_weight
)
if
len
(
param
.
data_container
)
==
2
:
self
.
qweight
=
param
.
materialize_nested
()
return
param_data
=
param
.
data
output_dim
=
getattr
(
param
,
"output_dim"
,
None
)
...
...
@@ -534,18 +544,6 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
shard_offset
=
loaded_weight
.
shape
[
output_dim
]
*
\
loaded_shard_id
if
is_gguf_weight
:
tp_size
=
get_tensor_model_parallel_world_size
()
output_dim
=
getattr
(
param
,
"output_dim"
,
None
)
shard_shape
=
list
(
loaded_weight
.
shape
)
shard_shape
[
output_dim
]
=
shard_shape
[
output_dim
]
//
tp_size
param
.
shard_id
.
append
(
loaded_shard_id
)
param
.
shard_size
[
loaded_shard_id
]
=
shard_shape
input_dim
=
getattr
(
param
,
"input_dim"
,
None
)
input_size
=
loaded_weight
.
shape
[
input_dim
]
param_data
=
param_data
.
narrow
(
input_dim
,
0
,
input_size
)
param_data
=
param_data
.
narrow
(
output_dim
,
shard_offset
,
shard_size
)
start_idx
=
tp_rank
*
shard_size
...
...
@@ -802,17 +800,23 @@ class QKVParallelLinear(ColumnParallelLinear):
param
.
shard_weight_type
[
loaded_shard_id
]
=
loaded_weight
.
item
()
return
if
is_gguf_weight
and
isinstance
(
param
,
UninitializedParameter
):
from
gguf.constants
import
GGML_QUANT_SIZES
if
is_gguf_weight
:
tp_size
=
get_tensor_model_parallel_world_size
()
tp_rank
=
get_tensor_model_parallel_rank
()
output_dim
=
getattr
(
param
,
"output_dim"
,
None
)
shard_size
=
loaded_weight
.
size
(
output_dim
)
//
tp_size
start_idx
=
tp_rank
*
shard_size
loaded_weight
=
loaded_weight
.
narrow
(
output_dim
,
start_idx
,
shard_size
)
ori_shape
=
param
.
tensor_shape
weight_types
=
self
.
qweight_type
.
shard_weight_type
.
values
()
row_size
=
[]
for
weight_type
in
weight_types
:
block_size
,
type_size
=
GGML_QUANT_SIZES
[
weight_type
]
row_size
.
append
(
ori_shape
[
1
]
//
block_size
*
type_size
)
q_shape
=
(
ori_shape
[
0
],
max
(
row_size
))
param
.
materialize
(
q_shape
,
dtype
=
loaded_weight
.
dtype
)
param
.
shard_id
.
append
(
loaded_shard_id
)
param
.
shard_id_map
[
loaded_shard_id
]
=
len
(
param
.
data_container
)
param
.
data_container
.
append
(
loaded_weight
)
if
len
(
param
.
data_container
)
==
3
:
self
.
qweight
=
param
.
materialize_nested
()
return
param_data
=
param
.
data
output_dim
=
getattr
(
param
,
"output_dim"
,
None
)
...
...
@@ -840,6 +844,9 @@ class QKVParallelLinear(ColumnParallelLinear):
(
"v"
,
(
self
.
total_num_heads
+
self
.
total_num_kv_heads
)
*
self
.
head_size
,
self
.
total_num_kv_heads
*
self
.
head_size
),
]
use_bitsandbytes_4bit
=
getattr
(
param
,
"use_bitsandbytes_4bit"
,
False
)
packed_dim
=
getattr
(
param
,
"packed_dim"
,
None
)
for
shard_id
,
shard_offset
,
shard_size
in
shard_offsets
:
# Special case for Quantized Weights.
...
...
@@ -853,6 +860,23 @@ class QKVParallelLinear(ColumnParallelLinear):
shard_size
,
shard_offset
=
adjust_marlin_shard
(
param
,
shard_size
,
shard_offset
)
if
use_bitsandbytes_4bit
:
orig_qkv_offsets
=
{
"q"
:
(
0
,
self
.
total_num_heads
*
self
.
head_size
),
"k"
:
(
self
.
total_num_heads
*
self
.
head_size
,
self
.
total_num_kv_heads
*
self
.
head_size
),
"v"
:
((
self
.
total_num_heads
+
self
.
total_num_kv_heads
)
*
self
.
head_size
,
self
.
total_num_kv_heads
*
self
.
head_size
),
"total"
:
((
self
.
total_num_heads
+
2
*
self
.
total_num_kv_heads
)
*
self
.
head_size
,
0
)
}
shard_size
,
shard_offset
=
adjust_bitsandbytes_4bit_shard
(
param
,
orig_qkv_offsets
,
shard_id
)
loaded_weight_shard
=
loaded_weight
.
narrow
(
output_dim
,
shard_offset
,
shard_size
)
self
.
weight_loader
(
param
,
loaded_weight_shard
,
shard_id
)
...
...
@@ -902,18 +926,6 @@ class QKVParallelLinear(ColumnParallelLinear):
shard_size
,
shard_offset
=
adjust_bitsandbytes_4bit_shard
(
param
,
orig_qkv_offsets
,
loaded_shard_id
)
if
is_gguf_weight
:
tp_size
=
get_tensor_model_parallel_world_size
()
output_dim
=
getattr
(
param
,
"output_dim"
,
None
)
shard_shape
=
list
(
loaded_weight
.
shape
)
shard_shape
[
output_dim
]
=
shard_shape
[
output_dim
]
//
tp_size
param
.
shard_id
.
append
(
loaded_shard_id
)
param
.
shard_size
[
loaded_shard_id
]
=
shard_shape
input_dim
=
getattr
(
param
,
"input_dim"
,
None
)
input_size
=
loaded_weight
.
shape
[
input_dim
]
param_data
=
param_data
.
narrow
(
input_dim
,
0
,
input_size
)
param_data
=
param_data
.
narrow
(
output_dim
,
shard_offset
,
shard_size
)
if
loaded_shard_id
==
"q"
:
...
...
vllm/model_executor/layers/mamba/ops/causal_conv1d.py
View file @
6d2051cc
...
...
@@ -6,65 +6,57 @@ from typing import Optional
import
torch
from
vllm
import
_custom_ops
as
ops
from
vllm.attention.backends.utils
import
PAD_SLOT_ID
def
causal_conv1d_fn
(
x
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
seq_idx
:
Optional
[
torch
.
Tensor
]
=
None
,
initial_states
:
Optional
[
torch
.
Tensor
]
=
None
,
return_final_states
:
bool
=
False
,
final_states_out
=
None
,
activation
:
str
=
"silu"
,
):
def
causal_conv1d_fn
(
x
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
query_start_loc
:
Optional
[
torch
.
Tensor
]
=
None
,
cache_indices
:
Optional
[
torch
.
Tensor
]
=
None
,
has_initial_state
:
Optional
[
torch
.
Tensor
]
=
None
,
conv_states
:
Optional
[
torch
.
Tensor
]
=
None
,
activation
:
Optional
[
str
]
=
"silu"
,
pad_slot_id
:
int
=
PAD_SLOT_ID
):
"""
x: (batch, dim, seqlen)
x: (batch, dim, seqlen) or (dim,cu_seq_len) for varlen
sequences are concatenated from left to right for varlen
weight: (dim, width)
bias: (dim,)
seq_idx: (batch, seqlen)
initial_states: (batch, dim, width - 1)
final_states_out: (batch, dim, width - 1), to be written to
query_start_loc: (batch + 1) int32
The cumulative sequence lengths of the sequences in
the batch, used to index into sequence. prepended by 0.
for example: query_start_loc = torch.Tensor([0,10,16,17]),
x.shape=(dim,17)
cache_indices: (batch) int32
indicates the corresponding state index,
like so: conv_state = conv_states[cache_indices[batch_id]]
has_initial_state: (batch) bool
indicates whether should the kernel take the current state as initial
state for the calculations
conv_states: (...,dim,width - 1) itype
updated inplace if provided
activation: either None or "silu" or "swish"
pad_slot_id: int
if cache_indices is passed, lets the kernel identify padded
entries that will not be processed,
for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id]
in this case, the kernel will not process entries at
indices 0 and 3
out: (batch, dim, seqlen)
"""
if
activation
not
in
[
None
,
"silu"
,
"swish"
]:
raise
NotImplementedError
(
"activation must be None, silu, or swish"
)
if
x
.
stride
(
2
)
!=
1
and
x
.
stride
(
1
)
!=
1
:
if
x
.
stride
(
-
1
)
!=
1
:
x
=
x
.
contiguous
()
bias
=
bias
.
contiguous
()
if
bias
is
not
None
else
None
if
seq_idx
is
not
None
:
assert
(
initial_states
is
None
),
"initial_states must be None if seq_idx is not None"
assert
(
not
return_final_states
),
"If seq_idx is not None, we don't return final_states_out"
seq_idx
=
seq_idx
.
contiguous
()
if
seq_idx
is
not
None
else
None
if
initial_states
is
not
None
and
(
initial_states
.
stride
(
2
)
!=
1
and
initial_states
.
stride
(
1
)
!=
1
):
initial_states
=
initial_states
.
contiguous
()
if
return_final_states
:
assert
(
x
.
stride
(
1
)
==
1
),
"Only channel-last layout support returning final_states_out"
if
final_states_out
is
not
None
:
assert
(
final_states_out
.
stride
(
2
)
==
1
or
final_states_out
.
stride
(
1
)
==
1
)
else
:
batch
,
dim
,
seqlen
=
x
.
shape
width
=
weight
.
shape
[
1
]
final_states_out
=
torch
.
empty
(
batch
,
width
-
1
,
dim
,
device
=
x
.
device
,
dtype
=
x
.
dtype
).
transpose
(
1
,
2
)
else
:
final_states_out
=
None
out
=
ops
.
causal_conv1d_fwd
(
x
,
weight
,
bias
,
seq_idx
,
initial_states
,
fin
al_state
s_out
,
activation
in
[
"silu"
,
"swish"
])
return
(
out
,
None
)
if
not
return_final_states
else
(
out
,
final_states_out
)
ops
.
causal_conv1d_fwd
(
x
,
weight
,
bias
,
conv_states
,
query_start_loc
,
cache_indices
,
has_initi
al_state
,
activation
in
[
"silu"
,
"swish"
]
,
pad_slot_id
)
return
x
def
causal_conv1d_update
(
x
:
torch
.
Tensor
,
...
...
@@ -72,21 +64,39 @@ def causal_conv1d_update(x: torch.Tensor,
weight
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
activation
:
Optional
[
str
]
=
None
,
conv_state_indices
:
Optional
[
torch
.
Tensor
]
=
None
):
cache_seqlens
:
Optional
[
torch
.
Tensor
]
=
None
,
conv_state_indices
:
Optional
[
torch
.
Tensor
]
=
None
,
pad_slot_id
:
int
=
PAD_SLOT_ID
):
"""
x: (batch, dim)
conv_state: (batch, dim, width
)
x: (batch, dim)
or (batch, dim, seqlen)
conv_state: (batch, dim,
state_len), where state_len >=
width
- 1
weight: (dim, width)
bias: (dim,)
cache_seqlens: (batch,), dtype int32.
If not None, the conv_state is treated as a circular buffer.
The conv_state will be updated by copying x to the conv_state
starting at the index
@cache_seqlens % state_len.
conv_state_indices: (batch,), dtype int32
If not None, the conv_state is a larger tensor along the batch dim,
and we are selecting the batch coords specified by conv_state_indices.
Useful for a continuous batching scenario.
out: (batch, dim)
pad_slot_id: int
if cache_indices is passed, lets the kernel identify padded
entries that will not be processed,
for example: cache_indices = [pad_slot_id, 1 ,20 ,pad_slot_id]
in this case, the kernel will not process entries at
indices 0 and 3
out: (batch, dim) or (batch, dim, seqlen)
"""
if
activation
not
in
[
None
,
"silu"
,
"swish"
]:
raise
NotImplementedError
(
"activation must be None, silu, or swish"
)
activation_bool
=
activation
in
[
"silu"
,
"swish"
]
return
ops
.
causal_conv1d_update
(
x
,
conv_state
,
weight
,
bias
,
activation_bool
,
conv_state_indices
)
activation_val
=
activation
in
[
"silu"
,
"swish"
]
unsqueeze
=
x
.
dim
()
==
2
if
unsqueeze
:
x
=
x
.
unsqueeze
(
-
1
)
ops
.
causal_conv1d_update
(
x
,
conv_state
,
weight
,
bias
,
activation_val
,
cache_seqlens
,
conv_state_indices
,
pad_slot_id
)
if
unsqueeze
:
x
=
x
.
squeeze
(
-
1
)
return
x
vllm/model_executor/layers/mamba/ops/mamba_ssm.py
View file @
6d2051cc
...
...
@@ -7,6 +7,7 @@ import triton.language as tl
from
packaging
import
version
from
vllm
import
_custom_ops
as
ops
from
vllm.attention.backends.utils
import
PAD_SLOT_ID
TRITON3
=
version
.
parse
(
triton
.
__version__
)
>=
version
.
parse
(
"3.0.0"
)
...
...
@@ -48,6 +49,7 @@ def _selective_scan_update_kernel(
z_ptr
,
out_ptr
,
state_batch_indices_ptr
,
pad_slot_id
,
# Matrix dimensions
batch
,
nheads
,
...
...
@@ -141,10 +143,11 @@ def _selective_scan_update_kernel(
if
HAS_Z
:
z_ptrs
=
z_ptr
+
offs_m
*
stride_z_dim
out_ptrs
=
out_ptr
+
offs_m
*
stride_out_dim
mask
=
(
offs_m
[:,
None
]
<
dim
)
&
(
offs_n
[
None
,
:]
<
dstate
)
if
HAS_STATE_BATCH_INDICES
:
mask
&=
(
state_batch_idx
!=
pad_slot_id
)
state
=
tl
.
load
(
state_ptrs
,
mask
=
mask
,
other
=
0.0
)
state
=
tl
.
load
(
state_ptrs
,
mask
=
(
offs_m
[:,
None
]
<
dim
)
&
(
offs_n
[
None
,
:]
<
dstate
),
other
=
0.0
)
x
=
tl
.
load
(
x_ptrs
,
mask
=
offs_m
<
dim
,
other
=
0.0
).
to
(
tl
.
float32
)
if
not
TIE_HDIM
:
dt
=
tl
.
load
(
dt_ptrs
,
mask
=
offs_m
<
dim
,
other
=
0.0
).
to
(
tl
.
float32
)
...
...
@@ -175,9 +178,11 @@ def _selective_scan_update_kernel(
dB
=
B
[
None
,
:]
*
dt
[:,
None
]
if
not
TIE_HDIM
else
B
*
dt
state
=
state
*
dA
+
dB
*
x
[:,
None
]
tl
.
store
(
state_ptrs
,
state
,
mask
=
(
offs_m
[:,
None
]
<
dim
)
&
(
offs_n
[
None
,
:]
<
dstate
))
mask
=
(
offs_m
[:,
None
]
<
dim
)
&
(
offs_n
[
None
,
:]
<
dstate
)
if
HAS_STATE_BATCH_INDICES
:
mask
&=
(
state_batch_idx
!=
pad_slot_id
)
tl
.
store
(
state_ptrs
,
state
,
mask
=
mask
)
out
=
tl
.
sum
(
state
*
C
[
None
,
:],
axis
=
1
)
if
HAS_D
:
out
+=
x
*
D
...
...
@@ -196,7 +201,8 @@ def selective_state_update(state,
z
=
None
,
dt_bias
=
None
,
dt_softplus
=
False
,
state_batch_indices
=
None
):
state_batch_indices
=
None
,
pad_slot_id
=
PAD_SLOT_ID
):
"""
Argument:
state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
...
...
@@ -208,6 +214,12 @@ def selective_state_update(state,
D: (dim,) or (nheads, dim)
z: (batch, dim) or (batch, nheads, dim)
dt_bias: (dim,) or (nheads, dim)
pad_slot_id: int
if cache_indices is passed, lets the kernel identify padded
entries that will not be processed,
for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id]
in this case, the kernel will not process entries at
indices 0 and 3
Return:
out: (batch, dim) or (batch, nheads, dim)
"""
...
...
@@ -274,6 +286,7 @@ def selective_state_update(state,
z
,
out
,
state_batch_indices
,
pad_slot_id
,
batch
,
nheads
,
dim
,
...
...
@@ -318,6 +331,7 @@ def selective_state_update(state,
def
selective_scan_fn
(
u
,
ssm_states
,
delta
,
A
,
B
,
...
...
@@ -326,11 +340,45 @@ def selective_scan_fn(u,
z
=
None
,
delta_bias
=
None
,
delta_softplus
=
False
,
return_last_state
=
False
,
position_indices
=
None
,
prev_state
=
None
):
"""if return_last_state is True, returns (out, last_state)
last_state has shape (batch, dim, dstate).
query_start_loc
=
None
,
cache_indices
=
None
,
has_initial_state
=
None
,
pad_slot_id
=
PAD_SLOT_ID
)
->
torch
.
Tensor
:
"""
u: (dim, total_length) for varlen or (batch, dim, seqlen)
applies changes in place.
ssm_states: (batch, dim, dstate) or (batch, nheads, dim, dstate)
applies changes in place.
delta: (dim, total_length) for varlen or (batch, dim, seqlen)
A: (dim, dstate)
B: (ngroups, dstate, total_length) for varlen or
(batch,ngroups,dstate,seqlen)
C: (ngroups, dstate, total_length) for varlen or
(batch,ngroups,dstate,seqlen)
D: (dim,)
z: (dim, total_length) for varlen or (batch, dim, seqlen)
dt_bias: (dim,) or (dim)
query_start_loc: (batch + 1) int32
The cumulative sequence lengths of the sequences in
the batch, used to index into sequence. prepended with 0.
for example: query_start_loc = torch.Tensor([0,10,16,17]),
x.shape=(dim,17)
cache_indices: (batch) int32
A tensor with each cell is a correspondent
input and output ssm_state index
has_initial_state: (batch) bool
A tensor populated with ones and zeros,
indicate if the ssm_state at the corresponding index should be
used as initial state. Not providing argument assumes
there's no initial state
pad_slot_id: int
if cache_indices is passed, lets the kernel identify padding entries
that will not be processed,
for example: cache_indices = [pad_slot_id, 1 ,20 ,pad_slot_id]
in this case, the kernel will not process entries at indices 0 and 3
returns
output: (dim, total_length) for varlen or (batch, dim, seqlen)
supports inplace replacement
"""
if
u
.
stride
(
-
1
)
!=
1
:
u
=
u
.
contiguous
()
...
...
@@ -344,28 +392,20 @@ def selective_scan_fn(u,
C
=
C
.
contiguous
()
if
z
is
not
None
and
z
.
stride
(
-
1
)
!=
1
:
z
=
z
.
contiguous
()
if
B
.
dim
()
==
3
:
if
B
.
dim
()
==
3
and
query_start_loc
is
None
:
B
=
B
.
unsqueeze
(
1
)
if
C
.
dim
()
==
3
:
if
B
.
dim
()
==
2
and
query_start_loc
is
not
None
:
B
=
B
.
unsqueeze
(
0
)
if
C
.
dim
()
==
3
and
query_start_loc
is
None
:
C
=
C
.
unsqueeze
(
1
)
n_chunks
=
int
((
u
.
shape
[
-
1
]
+
2048
-
1
)
/
2048
)
x
=
torch
.
zeros
((
u
.
shape
[
0
],
u
.
shape
[
1
],
n_chunks
,
int
(
A
.
shape
[
1
]
*
2
),
),
device
=
u
.
device
,
dtype
=
torch
.
float32
,
requires_grad
=
False
)
x
[:,
:,
0
,
0
::
2
]
=
1
if
prev_state
is
not
None
:
x
[:,
:,
0
,
1
::
2
].
copy_
(
prev_state
)
out
,
*
rest
=
ops
.
selective_scan_fwd
(
u
,
delta
,
A
,
B
,
C
,
D
,
z
,
delta_bias
,
delta_softplus
,
position_indices
,
x
)
last_state
=
x
[:,
:,
-
1
,
1
::
2
]
# (batch, dim, dstate)
if
C
.
dim
()
==
2
and
query_start_loc
is
not
None
:
C
=
C
.
unsqueeze
(
0
)
ops
.
selective_scan_fwd
(
u
,
delta
,
A
,
B
,
C
,
D
,
z
,
delta_bias
,
delta_softplus
,
query_start_loc
,
cache_indices
,
has_initial_state
,
ssm_states
,
pad_slot_id
)
if
z
is
None
:
return
out
if
not
return_last_state
else
(
out
,
last_state
)
return
delta
# output written inplace to delta
else
:
out_z
=
rest
[
0
]
return
out_z
if
not
return_last_state
else
(
out_z
,
last_state
)
return
z
# output written inplace to z
vllm/model_executor/layers/pooler.py
View file @
6d2051cc
...
...
@@ -11,6 +11,7 @@ from vllm.sequence import EmbeddingSequenceGroupOutput, PoolerOutput
class
PoolingType
(
IntEnum
):
"""Enumeration for different types of pooling methods."""
LAST
=
0
ALL
=
1
class
Pooler
(
nn
.
Module
):
...
...
@@ -43,6 +44,12 @@ class Pooler(nn.Module):
if
self
.
pooling_type
==
PoolingType
.
LAST
:
last_token_flat_indices
=
torch
.
cumsum
(
prompt_lens
,
dim
=
0
)
-
1
pooled_data
=
hidden_states
[
last_token_flat_indices
]
elif
self
.
pooling_type
==
PoolingType
.
ALL
:
offset
=
0
pooled_data
=
[]
for
prompt_len
in
prompt_lens
:
pooled_data
.
append
(
hidden_states
[
offset
:
offset
+
prompt_len
])
offset
+=
prompt_len
else
:
raise
ValueError
(
f
"Invalid pooling type:
{
self
.
pooling_type
}
"
)
...
...
vllm/model_executor/layers/quantization/__init__.py
View file @
6d2051cc
...
...
@@ -21,6 +21,7 @@ from vllm.model_executor.layers.quantization.gptq_marlin import (
GPTQMarlinConfig
)
from
vllm.model_executor.layers.quantization.gptq_marlin_24
import
(
GPTQMarlin24Config
)
from
vllm.model_executor.layers.quantization.ipex_quant
import
IPEXConfig
from
vllm.model_executor.layers.quantization.marlin
import
MarlinConfig
from
vllm.model_executor.layers.quantization.modelopt
import
ModelOptFp8Config
from
vllm.model_executor.layers.quantization.neuron_quant
import
(
...
...
@@ -49,6 +50,7 @@ QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
"qqq"
:
QQQConfig
,
"experts_int8"
:
ExpertsInt8Config
,
"neuron_quant"
:
NeuronQuantConfig
,
"ipex"
:
IPEXConfig
,
}
...
...
Prev
1
…
15
16
17
18
19
20
21
22
23
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment