Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
2216a4e5
"vscode:/vscode.git/clone" did not exist on "0313cf854d87a41c84efb69e89a79cd7b5897593"
Commit
2216a4e5
authored
Oct 23, 2024
by
zhuwenwen
Browse files
Merge remote-tracking branch 'mirror/main'
parents
ad385667
51c24c97
Changes
239
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
928 additions
and
86 deletions
+928
-86
vllm/entrypoints/openai/tool_parsers/__init__.py
vllm/entrypoints/openai/tool_parsers/__init__.py
+3
-1
vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py
vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py
+2
-1
vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py
vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py
+300
-0
vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py
vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py
+1
-1
vllm/envs.py
vllm/envs.py
+28
-1
vllm/logger.py
vllm/logger.py
+3
-1
vllm/model_executor/custom_op.py
vllm/model_executor/custom_op.py
+65
-5
vllm/model_executor/layers/activation.py
vllm/model_executor/layers/activation.py
+27
-11
vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
+3
-3
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+1
-4
vllm/model_executor/layers/layernorm.py
vllm/model_executor/layers/layernorm.py
+2
-1
vllm/model_executor/layers/logits_processor.py
vllm/model_executor/layers/logits_processor.py
+6
-4
vllm/model_executor/layers/pooler.py
vllm/model_executor/layers/pooler.py
+10
-2
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
...ers/quantization/compressed_tensors/compressed_tensors.py
+22
-12
vllm/model_executor/layers/rotary_embedding.py
vllm/model_executor/layers/rotary_embedding.py
+2
-1
vllm/model_executor/model_loader/neuron.py
vllm/model_executor/model_loader/neuron.py
+1
-30
vllm/model_executor/models/bert.py
vllm/model_executor/models/bert.py
+419
-0
vllm/model_executor/models/eagle.py
vllm/model_executor/models/eagle.py
+1
-1
vllm/model_executor/models/gemma2.py
vllm/model_executor/models/gemma2.py
+1
-7
vllm/model_executor/models/intern_vit.py
vllm/model_executor/models/intern_vit.py
+31
-0
No files found.
vllm/entrypoints/openai/tool_parsers/__init__.py
View file @
2216a4e5
from
.abstract_tool_parser
import
ToolParser
,
ToolParserManager
from
.hermes_tool_parser
import
Hermes2ProToolParser
from
.internlm2_tool_parser
import
Internlm2ToolParser
from
.jamba_tool_parser
import
JambaToolParser
from
.llama_tool_parser
import
Llama3JsonToolParser
from
.mistral_tool_parser
import
MistralToolParser
# Public names exported by the tool_parsers package. Each parser class is
# listed exactly once; one entry per line with trailing commas so that a
# forgotten comma can never silently concatenate two adjacent names.
__all__ = [
    "ToolParser",
    "ToolParserManager",
    "Hermes2ProToolParser",
    "MistralToolParser",
    "Internlm2ToolParser",
    "Llama3JsonToolParser",
    "JambaToolParser",
]
vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py
View file @
2216a4e5
...
...
@@ -53,7 +53,8 @@ class Hermes2ProToolParser(ToolParser):
self
.
tool_call_start_token_id
=
self
.
vocab
.
get
(
self
.
tool_call_start_token
)
self
.
tool_call_end_token_id
=
self
.
vocab
.
get
(
self
.
tool_call_end_token
)
if
not
self
.
tool_call_start_token_id
or
not
self
.
tool_call_end_token_id
:
if
(
self
.
tool_call_start_token_id
is
None
or
self
.
tool_call_end_token_id
is
None
):
raise
RuntimeError
(
"Hermes 2 Pro Tool parser could not locate tool call start/end "
"tokens in the tokenizer!"
)
...
...
vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py
0 → 100644
View file @
2216a4e5
import
json
import
re
from
typing
import
Dict
,
List
,
Sequence
,
Union
import
partial_json_parser
from
partial_json_parser.core.options
import
Allow
from
vllm.entrypoints.openai.protocol
import
(
ChatCompletionRequest
,
DeltaFunctionCall
,
DeltaMessage
,
DeltaToolCall
,
ExtractedToolCallInformation
,
FunctionCall
,
ToolCall
)
from
vllm.entrypoints.openai.tool_parsers
import
ToolParser
,
ToolParserManager
from
vllm.entrypoints.openai.tool_parsers.utils
import
(
extract_intermediate_diff
)
from
vllm.logger
import
init_logger
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.transformers_utils.tokenizers
import
MistralTokenizer
from
vllm.utils
import
random_uuid
logger
=
init_logger
(
__name__
)
@ToolParserManager.register_module("jamba")
class JambaToolParser(ToolParser):
    """Tool-call parser registered under the name ``"jamba"``.

    Tool calls are expected as a JSON array placed between the
    ``<tool_calls>`` and ``</tool_calls>`` tokens. Each array element is read
    as an object with a ``"name"`` and an ``"arguments"`` entry.

    The instance keeps mutable streaming state (current tool index, whether
    the tool name was already sent, and the argument text streamed so far),
    which is updated by ``extract_tool_calls_streaming``.
    """

    def __init__(self, tokenizer: AnyTokenizer):
        """Set up streaming state and resolve the tool-call token ids.

        Raises:
            ValueError: if the tokenizer is a ``MistralTokenizer`` or if no
                tokenizer is available.
            RuntimeError: if either tool-call token is missing from the
                vocabulary.
        """
        super().__init__(tokenizer)

        # NOTE(review): `model_tokenizer` and `vocab` are presumably provided
        # by the ToolParser base class -- confirm against its definition.
        if isinstance(self.model_tokenizer, MistralTokenizer):
            raise ValueError(
                "Detected a MistralTokenizer tokenizer when using a Jamba model"
            )

        # --- streaming state -------------------------------------------
        self.current_tool_name_sent: bool = False
        self.prev_tool_call_arr: List[Dict] = []
        # -1 means "no tool call started yet"
        self.current_tool_id: int = -1
        # argument text already streamed to the client, one entry per tool
        self.streamed_args_for_tool: List[str] = []

        self.tool_calls_start_token: str = "<tool_calls>"
        self.tool_calls_end_token: str = "</tool_calls>"

        # Escape the delimiters so they are always matched literally, even
        # if they are ever changed to contain regex metacharacters.
        self.tool_calls_regex = re.compile(
            rf"{re.escape(self.tool_calls_start_token)}(.*?)"
            rf"{re.escape(self.tool_calls_end_token)}", re.DOTALL)

        if not self.model_tokenizer:
            raise ValueError(
                "The model tokenizer must be passed to the ToolParser "
                "constructor during construction.")
        self.tool_calls_start_token_id = self.vocab.get(
            self.tool_calls_start_token)
        self.tool_calls_end_token_id = self.vocab.get(
            self.tool_calls_end_token)
        # Compare against None explicitly: a token id of 0 is valid.
        if (self.tool_calls_start_token_id is None
                or self.tool_calls_end_token_id is None):
            raise RuntimeError(
                "Jamba Tool parser could not locate tool calls start/end "
                "tokens in the tokenizer!")

    def adjust_request(
            self, request: ChatCompletionRequest) -> ChatCompletionRequest:
        """Keep special tokens in the output when tools may be called.

        The tool-call delimiters are special tokens, so they must not be
        skipped during detokenization. The request is modified in place and
        returned.
        """
        if request.tools and request.tool_choice != 'none':
            # do not skip special tokens because jamba use the special
            # tokens to indicate the start and end of the tool calls
            # information.
            request.skip_special_tokens = False
        return request

    @staticmethod
    def _no_tool_calls(model_output: str) -> ExtractedToolCallInformation:
        """Build the result used when the output is treated as plain text."""
        return ExtractedToolCallInformation(tools_called=False,
                                            tool_calls=[],
                                            content=model_output)

    def extract_tool_calls(
            self, model_output: str,
            request: ChatCompletionRequest) -> ExtractedToolCallInformation:
        """Extract tool calls from a complete (non-streaming) model output.

        Args:
            model_output: the full generated text.
            request: the originating request (not used here).

        Returns:
            Parsed tool calls plus any text preceding the opening tag. If no
            tool-call section is found, or it cannot be parsed, the whole
            output is returned as content with ``tools_called=False``.
        """
        # sanity check; avoid unnecessary processing
        if self.tool_calls_start_token not in model_output:
            return self._no_tool_calls(model_output)

        # An opening tag without a closing tag (e.g. truncated generation)
        # is an expected situation; handle it explicitly instead of relying
        # on an IndexError being swallowed by the generic handler below.
        match = self.tool_calls_regex.search(model_output)
        if match is None:
            logger.warning("Tool calls start token found without a matching "
                           "end token; returning output as content.")
            return self._no_tool_calls(model_output)

        try:
            # load the JSON, and then use it to build the Function and
            # Tool Call
            raw_function_calls = json.loads(match.group(1))
            tool_calls = [
                ToolCall(
                    type="function",
                    function=FunctionCall(
                        name=function_call["name"],
                        # function call args are JSON but as a string
                        arguments=json.dumps(function_call["arguments"])))
                for function_call in raw_function_calls
            ]

            # everything before the opening tag is regular content
            content = model_output[:model_output.
                                   find(self.tool_calls_start_token)]
            return ExtractedToolCallInformation(
                tools_called=True,
                tool_calls=tool_calls,
                content=content if
                (len(content) > 0 and content != " ") else None)

        except Exception:
            # Best-effort boundary: malformed JSON or unexpected shapes fall
            # back to returning the raw output as content.
            logger.exception("Error in extracting tool call from response.")
            return self._no_tool_calls(model_output)

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> Union[DeltaMessage, None]:
        """Turn one streamed chunk into a delta message, if there is one.

        The text after the last opening tag is partially parsed as a JSON
        array on every call and compared with the state kept on ``self``.

        Returns:
            A ``DeltaMessage`` with plain content (before any tool call), a
            ``DeltaMessage`` carrying a tool name or an argument fragment, or
            ``None`` when nothing should be sent for this chunk (including
            when an unexpected error occurs).
        """
        # if the tool call token is not in the tokens generated so far, append
        # output to contents since it's not a tool
        if self.tool_calls_start_token not in current_text:
            return DeltaMessage(content=delta_text)

        # if the chunk is only the start-of-tool-calls control token, send
        # nothing so the control token never reaches the client
        if (self.tool_calls_start_token_id in delta_token_ids
                and len(delta_token_ids) == 1):
            return None

        # bit mask flags for partial JSON parsing. Until the name has been
        # sent, incomplete strings are not allowed, so the function name is
        # only ever emitted once it is complete.
        flags = Allow.ALL if self.current_tool_name_sent \
            else Allow.ALL & ~Allow.STR
        try:
            # Extract the tool calls between the special tool call tokens
            parsable_arr = current_text.split(
                self.tool_calls_start_token)[-1].split(
                    self.tool_calls_end_token)[0]

            # tool calls are generated in an array, so do partial JSON
            # parsing on the entire array
            try:
                tool_call_arr: List[Dict] = partial_json_parser.loads(
                    parsable_arr, flags)
            except partial_json_parser.core.exceptions.MalformedJSON:
                logger.debug('not enough tokens to parse into JSON yet')
                return None

            # case -- if no tokens have been streamed for the tool, e.g.
            # only the array brackets, stream nothing
            if len(tool_call_arr) == 0:
                return None

            # the tool call the state cursor currently points at (index -1,
            # i.e. the last element, before the first tool has started)
            current_tool_call: Dict = tool_call_arr[self.current_tool_id]

            # case: we are starting a new tool in the array
            #   -> array length has moved past the cursor
            if len(tool_call_arr) > self.current_tool_id + 1:
                delta = None

                # if we're moving on to a new call, first make sure we
                # haven't missed anything in the previous one that was
                # auto-generated due to JSON completions, but wasn't
                # streamed to the client yet.
                if self.current_tool_id >= 0:
                    pending_args = current_tool_call.get("arguments")
                    if pending_args:
                        unsent_text = json.dumps(pending_args).replace(
                            self.streamed_args_for_tool[self.current_tool_id],
                            "")
                        delta = DeltaMessage(tool_calls=[
                            DeltaToolCall(index=self.current_tool_id,
                                          function=DeltaFunctionCall(
                                              arguments=unsent_text).
                                          model_dump(exclude_none=True))
                        ])
                        self.streamed_args_for_tool[
                            self.current_tool_id] += unsent_text

                # re-set stuff pertaining to progress in the current tool
                self.current_tool_id = len(tool_call_arr) - 1
                self.current_tool_name_sent = False
                self.streamed_args_for_tool.append("")
                logger.debug("starting on new tool %d", self.current_tool_id)
                return delta

            # case: update an existing tool - this is handled below

            delta = None
            if not self.current_tool_name_sent:
                # if the current tool name hasn't been sent, send if
                # available - otherwise send nothing
                function_name = current_tool_call.get("name")
                if function_name:
                    delta = DeltaMessage(tool_calls=[
                        DeltaToolCall(index=self.current_tool_id,
                                      type="function",
                                      id=f"chatcmpl-tool-{random_uuid()}",
                                      function=DeltaFunctionCall(
                                          name=function_name).model_dump(
                                              exclude_none=True))
                    ])
                    self.current_tool_name_sent = True

            else:
                # now we know we're on the same tool call and we're
                # streaming arguments
                prev_arguments = self.prev_tool_call_arr[
                    self.current_tool_id].get("arguments")
                cur_arguments = current_tool_call.get("arguments")

                new_text = delta_text.replace("\'", "\"")

                if not cur_arguments and not prev_arguments:
                    # no arguments yet: nothing to stream
                    delta = None

                elif not cur_arguments and prev_arguments:
                    logger.error(
                        "INVARIANT - impossible to have arguments reset "
                        "mid-arguments")
                    delta = None

                elif cur_arguments and not prev_arguments:
                    # first argument tokens: send everything up to and
                    # including the text of this chunk
                    cur_arguments_json = json.dumps(cur_arguments)
                    logger.debug("finding %s in %s", new_text,
                                 cur_arguments_json)

                    arguments_delta = cur_arguments_json[:cur_arguments_json.
                                                         index(new_text) +
                                                         len(new_text)]
                    logger.debug("First tokens in arguments received: %s",
                                 arguments_delta)
                    delta = DeltaMessage(tool_calls=[
                        DeltaToolCall(index=self.current_tool_id,
                                      function=DeltaFunctionCall(
                                          arguments=arguments_delta).
                                      model_dump(exclude_none=True))
                    ])
                    self.streamed_args_for_tool[
                        self.current_tool_id] += arguments_delta

                else:
                    # both previous and current arguments exist: stream only
                    # the difference between the two serializations
                    cur_args_json = json.dumps(cur_arguments)
                    prev_args_json = json.dumps(prev_arguments)
                    logger.debug("Searching for diff between\n%s\n%s",
                                 cur_args_json, prev_args_json)

                    argument_diff = extract_intermediate_diff(
                        cur_args_json, prev_args_json)
                    logger.debug("got arguments diff: %s", argument_diff)
                    delta = DeltaMessage(tool_calls=[
                        DeltaToolCall(index=self.current_tool_id,
                                      function=DeltaFunctionCall(
                                          arguments=argument_diff).model_dump(
                                              exclude_none=True))
                    ])
                    self.streamed_args_for_tool[
                        self.current_tool_id] += argument_diff

            # remember this parse so the next chunk can be diffed against it
            self.prev_tool_call_arr = tool_call_arr
            return delta

        except Exception:
            # Best-effort boundary: a chunk that cannot be processed is
            # dropped rather than aborting the whole stream.
            logger.exception("Error trying to handle streaming tool call.")
            logger.debug(
                "Skipping chunk as a result of tool streaming extraction "
                "error")
            return None
vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py
View file @
2216a4e5
...
...
@@ -63,7 +63,7 @@ class MistralToolParser(ToolParser):
self
.
bot_token
=
"[TOOL_CALLS]"
self
.
bot_token_id
=
self
.
vocab
.
get
(
self
.
bot_token
)
self
.
tool_call_regex
=
re
.
compile
(
r
"\[{.*?}\]"
,
re
.
DOTALL
)
if
not
self
.
bot_token_id
:
if
self
.
bot_token_id
is
None
:
raise
RuntimeError
(
"Mistral Tool Parser could not locate the tool call token in "
"the tokenizer!"
)
...
...
vllm/envs.py
View file @
2216a4e5
...
...
@@ -30,11 +30,13 @@ if TYPE_CHECKING:
VLLM_USAGE_SOURCE
:
str
=
""
VLLM_CONFIGURE_LOGGING
:
int
=
1
VLLM_LOGGING_LEVEL
:
str
=
"INFO"
VLLM_LOGGING_PREFIX
:
str
=
""
VLLM_LOGGING_CONFIG_PATH
:
Optional
[
str
]
=
None
VLLM_TRACE_FUNCTION
:
int
=
0
VLLM_ATTENTION_BACKEND
:
Optional
[
str
]
=
None
VLLM_USE_FLASHINFER_SAMPLER
:
bool
=
False
VLLM_USE_FLASHINFER_REJECTION_SAMPLER
:
bool
=
False
VLLM_FLASHINFER_FORCE_TENSOR_CORES
:
bool
=
False
VLLM_PP_LAYER_PARTITION
:
Optional
[
str
]
=
None
VLLM_CPU_KVCACHE_SPACE
:
int
=
0
VLLM_CPU_OMP_THREADS_BIND
:
str
=
""
...
...
@@ -68,7 +70,9 @@ if TYPE_CHECKING:
VLLM_ALLOW_RUNTIME_LORA_UPDATING
:
bool
=
False
VLLM_SKIP_P2P_CHECK
:
bool
=
False
VLLM_TORCH_COMPILE_LEVEL
:
int
=
0
VLLM_CUSTOM_OPS
:
List
[
str
]
=
[]
VLLM_DISABLED_KERNELS
:
List
[
str
]
=
[]
VLLM_USE_V1
:
bool
=
False
def
get_default_cache_root
():
...
...
@@ -223,7 +227,17 @@ environment_variables: Dict[str, Callable[[], Any]] = {
os
.
environ
.
get
(
"VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"
,
"1"
)
!=
"0"
),
"VLLM_TORCH_COMPILE_LEVEL"
:
lambda
:
int
(
os
.
environ
.
get
(
"VLLM_TORCH_COMPILE_LEVEL"
,
"0"
)),
# Fine-grained control over which custom ops to enable/disable.
# Use 'all' to enable all, 'none' to disable all.
# Also specify a list of custom op names to enable (prefixed with a '+'),
# or disable (prefixed with a '-').
# Examples:
# - 'all,-op1' to enable all except op1
# - 'none,+op1,+op2' to enable only op1 and op2
# By default, all custom ops are enabled when running without Inductor
# and disabled when running with Inductor (compile_level >= Inductor).
"VLLM_CUSTOM_OPS"
:
lambda
:
os
.
environ
.
get
(
"VLLM_CUSTOM_OPS"
,
""
).
replace
(
" "
,
""
).
split
(
","
),
# local rank of the process in the distributed setting, used to determine
# the GPU device id
"LOCAL_RANK"
:
...
...
@@ -273,6 +287,10 @@ environment_variables: Dict[str, Callable[[], Any]] = {
"VLLM_LOGGING_LEVEL"
:
lambda
:
os
.
getenv
(
"VLLM_LOGGING_LEVEL"
,
"INFO"
),
# if set, VLLM_LOGGING_PREFIX will be prepended to all log messages
"VLLM_LOGGING_PREFIX"
:
lambda
:
os
.
getenv
(
"VLLM_LOGGING_PREFIX"
,
""
),
# Trace function calls
# If set to 1, vllm will trace function calls
# Useful for debugging
...
...
@@ -293,6 +311,11 @@ environment_variables: Dict[str, Callable[[], Any]] = {
"VLLM_USE_FLASHINFER_SAMPLER"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_USE_FLASHINFER_SAMPLER"
,
"0"
))),
# If set, vllm will force flashinfer to use tensor cores;
# otherwise will use heuristic based on model architecture.
"VLLM_FLASHINFER_FORCE_TENSOR_CORES"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_FLASHINFER_FORCE_TENSOR_CORES"
,
"0"
))),
# Pipeline stage partition strategy
"VLLM_PP_LAYER_PARTITION"
:
lambda
:
os
.
getenv
(
"VLLM_PP_LAYER_PARTITION"
,
None
),
...
...
@@ -451,6 +474,10 @@ environment_variables: Dict[str, Callable[[], Any]] = {
"VLLM_DISABLED_KERNELS"
:
lambda
:
[]
if
"VLLM_DISABLED_KERNELS"
not
in
os
.
environ
else
os
.
environ
[
"VLLM_DISABLED_KERNELS"
].
split
(
","
),
# If set, use the V1 code path.
"VLLM_USE_V1"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_USE_V1"
,
"0"
))),
}
# end-env-vars-definition
...
...
vllm/logger.py
View file @
2216a4e5
...
...
@@ -15,8 +15,10 @@ import vllm.envs as envs
VLLM_CONFIGURE_LOGGING
=
envs
.
VLLM_CONFIGURE_LOGGING
VLLM_LOGGING_CONFIG_PATH
=
envs
.
VLLM_LOGGING_CONFIG_PATH
VLLM_LOGGING_LEVEL
=
envs
.
VLLM_LOGGING_LEVEL
VLLM_LOGGING_PREFIX
=
envs
.
VLLM_LOGGING_PREFIX
_FORMAT
=
"%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s"
_FORMAT
=
(
f
"
{
VLLM_LOGGING_PREFIX
}
%(levelname)s %(asctime)s "
"%(filename)s:%(lineno)d] %(message)s"
)
_DATE_FORMAT
=
"%m-%d %H:%M:%S"
DEFAULT_LOGGING_CONFIG
=
{
...
...
vllm/model_executor/custom_op.py
View file @
2216a4e5
from
functools
import
lru_cache
from
typing
import
Dict
,
Type
import
torch.nn
as
nn
import
vllm.envs
as
envs
from
vllm.compilation.levels
import
CompilationLevel
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.utils
import
is_cpu
,
is_hip
,
is_xpu
from
vllm.utils
import
is_hip
,
is_xpu
,
print_warning_once
logger
=
init_logger
(
__name__
)
class
CustomOp
(
nn
.
Module
):
"""
Base class for custom ops.
Dispatches the forward method to the appropriate backend.
"""
def
__init__
(
self
,
*
args
,
**
kwargs
):
def
__init__
(
self
):
super
().
__init__
()
self
.
_forward_method
=
self
.
dispatch_forward
()
...
...
@@ -17,7 +27,6 @@ class CustomOp(nn.Module):
def
forward_native
(
self
,
*
args
,
**
kwargs
):
"""PyTorch-native implementation of the forward method.
This method is optional. If implemented, it can be used with compilers
such as torch.compile or PyTorch XLA. Also, it can be used for testing
purposes.
...
...
@@ -56,12 +65,16 @@ class CustomOp(nn.Module):
# NOTE(woosuk): Here we assume that vLLM was built for only one
# specific backend. Currently, we do not support dynamic dispatching.
if
envs
.
VLLM_TORCH_COMPILE_LEVEL
>=
CompilationLevel
.
INDUCTOR
:
enabled
=
self
.
enabled
()
logger
.
debug
(
"custom op %s %s"
,
self
.
__class__
.
name
,
"enabled"
if
enabled
else
"disabled"
)
if
not
enabled
:
return
self
.
forward_native
if
is_hip
():
return
self
.
forward_hip
elif
is_cpu
():
elif
current_platform
.
is_cpu
():
return
self
.
forward_cpu
elif
current_platform
.
is_tpu
():
return
self
.
forward_tpu
...
...
@@ -69,3 +82,50 @@ class CustomOp(nn.Module):
return
self
.
forward_xpu
else
:
return
self
.
forward_cuda
@classmethod
def enabled(cls) -> bool:
    """Report whether this custom op should be used.

    An op without a ``name`` attribute was never registered; a warning is
    emitted once and the global default applies. Otherwise the op is on when
    the global default is on or ``+<name>`` is listed in
    ``VLLM_CUSTOM_OPS``, unless ``-<name>`` is listed as well.

    Raises:
        AssertionError: if the op is both enabled and disabled explicitly.
    """
    # if no name, then it was not registered
    if not hasattr(cls, "name"):
        print_warning_once(
            f"Custom op {cls.__name__} was not registered, "
            f"which means it won't appear in the op registry. "
            f"It will be enabled/disabled based on the global settings.")
        return CustomOp.default_on()

    requested_ops = envs.VLLM_CUSTOM_OPS
    explicitly_on = f"+{cls.name}" in requested_ops
    explicitly_off = f"-{cls.name}" in requested_ops
    assert not (explicitly_on
                and explicitly_off), f"Cannot enable and disable {cls.name}"

    # The global default is always evaluated, so its own validation runs
    # even when the op is switched off explicitly.
    globally_on = CustomOp.default_on()
    return (globally_on or explicitly_on) and not explicitly_off
# Custom ops are on by default while
# VLLM_TORCH_COMPILE_LEVEL < CompilationLevel.INDUCTOR.
# An explicit 'all' or 'none' entry in VLLM_CUSTOM_OPS takes precedence.
@staticmethod
@lru_cache()
def default_on() -> bool:
    """Return the global default for custom ops (result is cached).

    Raises:
        AssertionError: if 'none' and 'all' appear more than once in total
            in ``VLLM_CUSTOM_OPS``.
    """
    requested_ops = envs.VLLM_CUSTOM_OPS
    none_hits = requested_ops.count("none")
    all_hits = requested_ops.count("all")
    assert none_hits + all_hits <= 1, "Can only specify 'none' or 'all'"

    # Below the Inductor level, ops stay on unless 'none' was requested.
    if (envs.VLLM_TORCH_COMPILE_LEVEL < CompilationLevel.INDUCTOR
            and none_hits == 0):
        return True
    # Otherwise only an explicit 'all' turns them on.
    return all_hits > 0
# Registry of every custom op class, keyed by the name it was registered
# under. To find out whether a given op is enabled, call .enabled() on the
# class, e.g.:
#   - MyOp.enabled()
#   - op_registry["my_op"].enabled()
op_registry: Dict[str, Type['CustomOp']] = {}


# Class decorator factory used to register custom ops.
@classmethod
def register(cls, name: str):
    """Return a class decorator that registers an op under ``name``.

    The decorator stores ``name`` on the decorated class, records the class
    in ``cls.op_registry`` and returns the class unchanged.

    Raises:
        AssertionError: (when the decorator is applied) if ``name`` is
            already present in the registry.
    """

    def _bind(op_cls):
        registry = cls.op_registry
        assert name not in registry, f"Duplicate op name: {name}"
        setattr(op_cls, "name", name)
        registry[name] = op_cls
        return op_cls

    return _bind
vllm/model_executor/layers/activation.py
View file @
2216a4e5
...
...
@@ -11,12 +11,14 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
from
vllm.model_executor.custom_op
import
CustomOp
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.utils
import
LazyDict
import
vllm.envs
as
envs
@
CustomOp
.
register
(
"fatrelu_and_mul"
)
class
FatreluAndMul
(
CustomOp
):
"""An activation function for FATReLU.
The function computes x -> FATReLU(x[:d]) * x[d:] where
d = x.shape[-1] // 2.
This is used in openbmb/MiniCPM-S-1B-sft.
...
...
@@ -41,6 +43,7 @@ class FatreluAndMul(CustomOp):
return
self
.
forward_native
(
x
)
@
CustomOp
.
register
(
"silu_and_mul"
)
class
SiluAndMul
(
CustomOp
):
"""An activation function for SwiGLU.
...
...
@@ -78,6 +81,7 @@ class SiluAndMul(CustomOp):
return
out
@
CustomOp
.
register
(
"gelu_and_mul"
)
class
GeluAndMul
(
CustomOp
):
"""An activation function for GeGLU.
...
...
@@ -133,6 +137,7 @@ class GeluAndMul(CustomOp):
return
f
'approximate=
{
repr
(
self
.
approximate
)
}
'
@
CustomOp
.
register
(
"gelu_new"
)
class
NewGELU
(
CustomOp
):
def
forward_native
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
...
...
@@ -154,6 +159,7 @@ class NewGELU(CustomOp):
return
ops
.
gelu_new
(
x
)
@
CustomOp
.
register
(
"gelu_fast"
)
class
FastGELU
(
CustomOp
):
def
forward_native
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
...
...
@@ -174,8 +180,8 @@ class FastGELU(CustomOp):
return
ops
.
gelu_fast
(
x
)
@
CustomOp
.
register
(
"quick_gelu"
)
class
QuickGELU
(
CustomOp
):
# https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py#L90
def
forward_native
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""PyTorch-native implementation equivalent to forward()."""
...
...
@@ -199,6 +205,7 @@ class QuickGELU(CustomOp):
# def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
@
CustomOp
.
register
(
"relu2"
)
class
ReLUSquaredActivation
(
CustomOp
):
"""
Applies the relu^2 activation introduced in https://arxiv.org/abs/2109.08668v2
...
...
@@ -254,15 +261,24 @@ class ScaledActivation(nn.Module):
param_data
.
copy_
(
loaded_weight
)
_ACTIVATION_REGISTRY
=
{
"gelu"
:
nn
.
GELU
(),
"gelu_fast"
:
FastGELU
(),
"gelu_new"
:
NewGELU
(),
"gelu_pytorch_tanh"
:
nn
.
GELU
(
approximate
=
"tanh"
),
"relu"
:
nn
.
ReLU
(),
"relu2"
:
ReLUSquaredActivation
(),
"quick_gelu"
:
QuickGELU
(),
}
_ACTIVATION_REGISTRY
=
LazyDict
({
"gelu"
:
lambda
:
nn
.
GELU
(),
"gelu_fast"
:
lambda
:
FastGELU
(),
"gelu_new"
:
lambda
:
NewGELU
(),
"gelu_pytorch_tanh"
:
lambda
:
nn
.
GELU
(
approximate
=
"tanh"
),
"relu"
:
lambda
:
nn
.
ReLU
(),
"relu2"
:
lambda
:
ReLUSquaredActivation
(),
"silu"
:
lambda
:
nn
.
SiLU
(),
"quick_gelu"
:
lambda
:
QuickGELU
(),
})
def
get_act_fn
(
...
...
vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
View file @
2216a4e5
...
...
@@ -116,7 +116,7 @@ def single_marlin_moe(
intermediate_cache
=
torch
.
ops
.
_moe_C
.
marlin_gemm_moe
(
hidden_states
,
w
,
sorted_token_ids
,
topk_weights
,
topk_ids
,
scales
,
w_zeros
,
g_idx
,
sort_indices
,
workspace
,
scalar_type
,
M
,
N
,
K
,
w_zeros
,
g_idx
,
sort_indices
,
workspace
,
scalar_type
.
id
,
M
,
N
,
K
,
is_k_full
,
E
,
topk
,
block_size_m
,
True
,
False
)
return
torch
.
sum
(
intermediate_cache
.
view
(
*
intermediate_cache
.
shape
),
dim
=
1
)
...
...
@@ -272,7 +272,7 @@ def fused_marlin_moe(
g_idx1
,
sort_indices1
,
workspace
,
scalar_type1
,
scalar_type1
.
id
,
M
,
2
*
N
,
K
,
...
...
@@ -297,7 +297,7 @@ def fused_marlin_moe(
g_idx2
,
sort_indices2
,
workspace
,
scalar_type2
,
scalar_type2
.
id
,
M
,
K
,
N
,
...
...
vllm/model_executor/layers/fused_moe/layer.py
View file @
2216a4e5
...
...
@@ -37,13 +37,13 @@ class FusedMoEMethodBase(QuantizeMethodBase):
raise
NotImplementedError
@
CustomOp
.
register
(
"unquantized_fused_moe"
)
class
UnquantizedFusedMoEMethod
(
FusedMoEMethodBase
,
CustomOp
):
"""MoE method without quantization."""
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
num_experts
:
int
,
hidden_size
:
int
,
intermediate_size
:
int
,
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
):
# Fused gate_up_proj (column parallel)
w13_weight
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
num_experts
,
2
*
intermediate_size
,
...
...
@@ -74,7 +74,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
num_expert_group
:
Optional
[
int
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
)
->
torch
.
Tensor
:
return
self
.
forward
(
x
=
x
,
layer
=
layer
,
router_logits
=
router_logits
,
...
...
@@ -97,7 +96,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
num_expert_group
:
Optional
[
int
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
)
->
torch
.
Tensor
:
from
vllm.model_executor.layers.fused_moe.fused_moe
import
(
fused_experts
)
...
...
@@ -134,7 +132,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
num_expert_group
:
Optional
[
int
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
)
->
torch
.
Tensor
:
from
vllm.model_executor.layers.fused_moe.moe_pallas
import
fused_moe
assert
not
use_grouped_topk
assert
num_expert_group
is
None
...
...
vllm/model_executor/layers/layernorm.py
View file @
2216a4e5
...
...
@@ -8,6 +8,7 @@ from vllm.model_executor.custom_op import CustomOp
import
vllm.envs
as
envs
@
CustomOp
.
register
(
"rms_norm"
)
class
RMSNorm
(
CustomOp
):
"""Root mean square normalization.
...
...
@@ -27,7 +28,6 @@ class RMSNorm(CustomOp):
self
.
variance_epsilon
=
eps
self
.
variance_size_override
=
(
None
if
var_hidden_size
==
hidden_size
else
var_hidden_size
)
self
.
weight
=
nn
.
Parameter
(
torch
.
ones
(
hidden_size
))
def
forward_native
(
...
...
@@ -139,6 +139,7 @@ class RMSNorm(CustomOp):
return
s
@
CustomOp
.
register
(
"gemma_rms_norm"
)
class
GemmaRMSNorm
(
CustomOp
):
"""RMS normalization for Gemma.
...
...
vllm/model_executor/layers/logits_processor.py
View file @
2216a4e5
...
...
@@ -48,14 +48,15 @@ class LogitsProcessor(nn.Module):
self
,
lm_head
:
VocabParallelEmbedding
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
sampling_metadata
:
Optional
[
SamplingMetadata
]
=
None
,
embedding_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Optional
[
torch
.
Tensor
]:
if
self
.
logits_as_input
:
logits
=
hidden_states
else
:
hidden_states
=
_prune_hidden_states
(
hidden_states
,
sampling_metadata
)
if
sampling_metadata
is
not
None
:
hidden_states
=
_prune_hidden_states
(
hidden_states
,
sampling_metadata
)
# Get the logits for the next tokens.
logits
=
self
.
_get_logits
(
hidden_states
,
lm_head
,
embedding_bias
)
...
...
@@ -69,7 +70,8 @@ class LogitsProcessor(nn.Module):
logits
*=
self
.
scale
# Apply logits processors (if any).
logits
=
_apply_logits_processors
(
logits
,
sampling_metadata
)
if
sampling_metadata
is
not
None
:
logits
=
_apply_logits_processors
(
logits
,
sampling_metadata
)
return
logits
...
...
vllm/model_executor/layers/pooler.py
View file @
2216a4e5
...
...
@@ -12,6 +12,7 @@ class PoolingType(IntEnum):
"""Enumeration for different types of pooling methods."""
LAST
=
0
ALL
=
1
CLS
=
2
class
Pooler
(
nn
.
Module
):
...
...
@@ -23,12 +24,13 @@ class Pooler(nn.Module):
3. Returns structured results as `PoolerOutput`.
Attributes:
pooling_type: The type of pooling to use (LAST, A
VERAGE, MAX
).
pooling_type: The type of pooling to use (LAST, A
LL, CLS
).
normalize: Whether to normalize the pooled data.
"""
def
__init__
(
self
,
pooling_type
:
PoolingType
,
normalize
:
bool
):
super
().
__init__
()
self
.
pooling_type
=
pooling_type
self
.
normalize
=
normalize
...
...
@@ -38,10 +40,16 @@ class Pooler(nn.Module):
pooling_metadata
:
PoolingMetadata
,
)
->
PoolerOutput
:
"""Pools specific information from hidden states based on metadata."""
prompt_lens
=
PoolingTensors
.
from_pooling_metadata
(
pooling_metadata
,
hidden_states
.
device
).
prompt_lens
if
self
.
pooling_type
==
PoolingType
.
LAST
:
if
self
.
pooling_type
is
PoolingType
.
CLS
:
first_token_flat_indices
=
torch
.
zeros_like
(
prompt_lens
)
first_token_flat_indices
[
1
:]
+=
torch
.
cumsum
(
prompt_lens
,
dim
=
0
)[:
-
1
]
pooled_data
=
hidden_states
[
first_token_flat_indices
]
elif
self
.
pooling_type
==
PoolingType
.
LAST
:
last_token_flat_indices
=
torch
.
cumsum
(
prompt_lens
,
dim
=
0
)
-
1
pooled_data
=
hidden_states
[
last_token_flat_indices
]
elif
self
.
pooling_type
==
PoolingType
.
ALL
:
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
View file @
2216a4e5
...
...
@@ -100,12 +100,21 @@ class CompressedTensorsConfig(QuantizationConfig):
target_scheme_map
[
target
][
"weights"
]
=
QuantizationArgs
.
parse_obj
(
quant_config
.
get
(
"weights"
))
try
:
target_scheme_map
[
target
][
"input_activations"
]
=
QuantizationArgs
.
parse_obj
(
quant_config
.
get
(
"input_activations"
))
except
Exception
:
target_scheme_map
[
target
][
"input_activations"
]
=
None
target_scheme_map
[
target
][
"input_activations"
]
=
None
if
is_activation_quantization_format
(
quant_format
):
input_activations
=
quant_config
.
get
(
"input_activations"
)
# The only case where we have activation quant supported
# but no input_activations provided in the config
# should be w8a16fp8 w8a16fp8 can also run for cases where
# there is an input_quant but it is ignored
if
not
input_activations
:
assert
target_scheme_map
[
target
][
"weights"
].
type
==
QuantizationType
.
FLOAT
else
:
target_scheme_map
[
target
][
"input_activations"
]
=
QuantizationArgs
.
parse_obj
(
quant_config
.
get
(
"input_activations"
))
return
cls
(
target_scheme_map
=
target_scheme_map
,
ignore
=
ignore
,
...
...
@@ -244,8 +253,6 @@ class CompressedTensorsConfig(QuantizationConfig):
group_size
=
weight_quant
.
group_size
,
actorder
=
weight_quant
.
actorder
)
# Detect If Activation Quantization.
# TODO @dsikka: clean-up conditions
if
is_activation_quantization_format
(
self
.
quant_format
):
if
self
.
_is_fp8_w8a8
(
weight_quant
,
input_quant
):
is_fp8_w8a8_supported
=
self
.
_check_scheme_supported
(
...
...
@@ -256,16 +263,19 @@ class CompressedTensorsConfig(QuantizationConfig):
is_static_input_scheme
=
(
input_quant
and
not
input_quant
.
dynamic
))
else
:
# note: input_quant will be present for converted models;
# will be ignored during inference post loading
return
CompressedTensorsW8A16Fp8
(
strategy
=
weight_quant
.
strategy
,
is_static_input_scheme
=
(
input_quant
and
not
input_quant
.
dynamic
))
is_static_input_scheme
=
not
input_quant
.
dynamic
)
# note: input_quant can be None
if
self
.
_is_fp8_w8a16
(
weight_quant
,
input_quant
):
is_static_input_scheme
=
(
input_quant
and
not
input_quant
.
dynamic
)
return
CompressedTensorsW8A16Fp8
(
strategy
=
weight_quant
.
strategy
,
is_static_input_scheme
=
(
input_quant
and
not
input_quant
.
dynamic
))
is_static_input_scheme
=
is_static_input_scheme
)
if
self
.
_is_static_tensor_w8a8
(
weight_quant
,
input_quant
):
return
CompressedTensorsW8A8Int8
(
...
...
vllm/model_executor/layers/rotary_embedding.py
View file @
2216a4e5
...
...
@@ -72,6 +72,7 @@ def _apply_rotary_emb(
return
torch
.
stack
((
o1
,
o2
),
dim
=-
1
).
flatten
(
-
2
)
@
CustomOp
.
register
(
"rotary_embedding"
)
class
RotaryEmbedding
(
CustomOp
):
"""Original rotary positional embedding."""
...
...
@@ -468,7 +469,7 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
self
.
long_factor
=
long_factor
scale
=
self
.
max_position_embeddings
/
\
self
.
original_max_position_embeddings
self
.
original_max_position_embeddings
if
scale
<=
1.0
:
scaling_factor
=
1.0
else
:
...
...
vllm/model_executor/model_loader/neuron.py
View file @
2216a4e5
...
...
@@ -6,7 +6,6 @@ from typing import Dict, List, Optional, Tuple
import
torch
import
torch.nn
as
nn
import
transformers
from
transformers
import
PretrainedConfig
from
vllm.config
import
ModelConfig
,
ParallelConfig
,
SchedulerConfig
...
...
@@ -108,39 +107,11 @@ class NeuronCasualLM(nn.Module):
neuronx_module
=
importlib
.
import_module
(
neuronx_module_path
)
neuronx_model_cls
=
getattr
(
neuronx_module
,
neuronx_model_cls_name
)
split_model_dir
=
f
"
{
model_name_or_path
}
-split"
if
_is_pretrained_neuron_checkpoint
(
model_name_or_path
):
split_model_dir
=
model_name_or_path
elif
not
os
.
path
.
exists
(
f
"
{
model_name_or_path
}
-split"
):
hf_model_cls
=
getattr
(
transformers
,
hf_model_cls_name
)
from
transformers_neuronx.module
import
save_pretrained_split
hf_model
=
hf_model_cls
.
from_pretrained
(
model_name_or_path
,
low_cpu_mem_usage
=
True
)
save_pretrained_split
(
hf_model
,
f
"
{
model_name_or_path
}
-split"
)
self
.
model
=
neuronx_model_cls
.
from_pretrained
(
split_model_dir
,
self
.
model
=
neuronx_model_cls
.
from_pretrained
(
model_name_or_path
,
**
kwargs
)
self
.
model
.
to_neuron
()
def _is_pretrained_neuron_checkpoint(model_name_or_path: str) -> bool:
    """Return True if the path looks like an already-split neuron checkpoint.

    Two layouts are recognised:

    * old format: ``<path>/pytorch_model.bin`` exists and is a directory;
    * new format: both ``config.json`` and ``generation_config.json`` are
      regular files in ``<path>`` and at least one directory entry ends
      with ``.safetensors``.
    """
    # Old format: the "pytorch_model.bin" entry is a directory, not a file.
    legacy_dir = os.path.join(model_name_or_path, "pytorch_model.bin")
    if os.path.isdir(legacy_dir):
        return True

    # New format: every required JSON file must be present ...
    required_files = ("config.json", "generation_config.json")
    has_all_required = all(
        os.path.isfile(os.path.join(model_name_or_path, required))
        for required in required_files)
    if not has_all_required:
        return False

    # ... together with at least one safetensors shard.
    return any(
        entry.endswith(".safetensors")
        for entry in os.listdir(model_name_or_path))
def
_get_model_architecture
(
config
:
PretrainedConfig
)
->
str
:
architectures
=
getattr
(
config
,
"architectures"
,
[])
for
arch
in
architectures
:
...
...
vllm/model_executor/models/bert.py
0 → 100644
View file @
2216a4e5
from
typing
import
Iterable
,
List
,
Optional
,
Tuple
import
torch
from
torch
import
nn
from
transformers
import
BertConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
,
AttentionType
from
vllm.attention.backends.xformers
import
XFormersImpl
from
vllm.config
import
CacheConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.pooler
import
Pooler
,
PoolingType
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.pooling_metadata
import
PoolingMetadata
from
vllm.sequence
import
IntermediateTensors
,
PoolerOutput
class BertEmbedding(nn.Module):
    """Input embedding block for BERT.

    Adds word, token-type and position embeddings and applies LayerNorm
    to the sum. Only the ``"absolute"`` position embedding type is
    accepted; any other value raises ``ValueError`` at construction.
    """

    def __init__(self, config: BertConfig):
        super().__init__()
        hidden_size = config.hidden_size
        self.size = hidden_size

        # Three lookup tables, all projecting into the hidden dimension.
        self.word_embeddings = VocabParallelEmbedding(config.vocab_size,
                                                      hidden_size)
        self.position_embeddings = VocabParallelEmbedding(
            config.max_position_embeddings, hidden_size)
        self.token_type_embeddings = VocabParallelEmbedding(
            config.type_vocab_size, hidden_size)

        self.LayerNorm = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
        # Uninitialised (1, max_position_embeddings) parameter; it is not
        # read by forward() in this class.
        self.position_ids = nn.Parameter(
            torch.empty((1, config.max_position_embeddings)), )

        self.position_embedding_type = config.position_embedding_type
        if self.position_embedding_type != "absolute":
            raise ValueError("Only 'absolute' position_embedding_type"
                             " is supported")

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and return the layer-normalised sum."""
        ids_shape = input_ids.size()

        word_part = self.word_embeddings(input_ids)
        position_part = self.position_embeddings(position_ids)

        # Token type ids are always zero here. (TODO: move off hotpath?)
        segment_ids = torch.zeros(ids_shape,
                                  dtype=torch.long,
                                  device=word_part.device)
        segment_part = self.token_type_embeddings(segment_ids)

        summed = word_part + segment_part + position_part
        return self.LayerNorm(summed)
class BertEncoder(nn.Module):
    """Stack of ``config.num_hidden_layers`` ``BertLayer`` modules."""

    def __init__(self,
                 config: BertConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()
        stack = []
        for layer_idx in range(config.num_hidden_layers):
            # Each layer gets its own qualified name under "<prefix>.layer".
            stack.append(
                BertLayer(config=config,
                          cache_config=cache_config,
                          quant_config=quant_config,
                          prefix=f"{prefix}.layer.{layer_idx}"))
        self.layer = nn.ModuleList(stack)

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run the layers in order, passing ``kv_caches[i]`` to layer ``i``."""
        for idx, bert_layer in enumerate(self.layer):
            hidden_states = bert_layer(hidden_states, kv_caches[idx],
                                       attn_metadata)
        return hidden_states
class BertLayer(nn.Module):
    """Single encoder layer: attention, intermediate and output blocks."""

    def __init__(self,
                 config: BertConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()

        attention_prefix = f"{prefix}.attention"
        intermediate_prefix = f"{prefix}.intermediate"
        output_prefix = f"{prefix}.output"

        self.attention = BertAttention(
            hidden_size=config.hidden_size,
            num_attention_heads=config.num_attention_heads,
            layer_norm_eps=config.layer_norm_eps,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=attention_prefix)

        self.intermediate = BertIntermediate(
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            prefix=intermediate_prefix)

        self.output = BertOutput(hidden_size=config.hidden_size,
                                 intermediate_size=config.intermediate_size,
                                 layer_norm_eps=config.layer_norm_eps,
                                 quant_config=quant_config,
                                 prefix=output_prefix)

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: Optional[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ):
        """Apply attention, then the intermediate and output blocks.

        The attention output is passed to the output block a second time
        as its ``input_tensor`` argument.
        """
        attended = self.attention(hidden_states, kv_cache, attn_metadata)
        expanded = self.intermediate(attended)
        return self.output(expanded, attended)
class BertAttention(nn.Module):
    """Attention block: self-attention followed by its output projection.

    Holds two sub-modules, ``self`` (``BertSelfAttention``) and ``output``
    (``BertSelfOutput``). Each is given a prefix matching its attribute
    name so that qualified layer names line up with the module tree.
    """

    def __init__(
        self,
        hidden_size: int,
        num_attention_heads: int,
        layer_norm_eps: float,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()

        # The sub-module is stored as ``self.self``, so its prefix must be
        # "<prefix>.self". It was previously "<prefix>.output", which
        # collided with the prefix of the sibling module below.
        self.self = BertSelfAttention(hidden_size=hidden_size,
                                      num_attention_heads=num_attention_heads,
                                      cache_config=cache_config,
                                      quant_config=quant_config,
                                      prefix=f"{prefix}.self")

        self.output = BertSelfOutput(hidden_size=hidden_size,
                                     layer_norm_eps=layer_norm_eps,
                                     quant_config=quant_config,
                                     prefix=f"{prefix}.output")

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run self-attention, then the output block.

        ``hidden_states`` is passed to the output block as its
        ``input_tensor`` argument alongside the self-attention result.
        """
        self_output = self.self(hidden_states, kv_cache, attn_metadata)
        return self.output(self_output, hidden_states)
class BertSelfAttention(nn.Module):
    """Self-attention with a fused QKV projection.

    Heads are divided by the tensor-parallel world size. The number of KV
    heads equals the number of query heads. Construction raises
    ``ValueError`` unless the attention implementation is
    ``XFormersImpl``.
    """

    def __init__(
        self,
        hidden_size: int,
        num_attention_heads: int,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()

        # Head bookkeeping: totals first, then the per-rank share.
        self.total_num_heads = num_attention_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = self.total_num_heads

        self.head_dim = self.hidden_size // self.total_num_heads
        assert self.head_dim * self.total_num_heads == self.hidden_size

        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)

        # Per-rank widths used to split the fused projection output.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5

        self.qkv_proj = QKVParallelLinear(
            hidden_size=self.hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            total_num_kv_heads=self.total_num_kv_heads,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj")

        self.attn = Attention(num_heads=self.num_heads,
                              head_size=self.head_dim,
                              scale=self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn")

        backend_ok = isinstance(self.attn.impl, XFormersImpl)
        if not backend_ok:
            raise ValueError(
                "Encoder-only models currently require XFORMERS attention "
                "backend. Set VLLM_ATTENTION_BACKEND=XFORMERS to use BERT.")

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Project to Q/K/V and run encoder-only attention."""
        fused, _ = self.qkv_proj(hidden_states)
        split_sizes = [self.q_size, self.kv_size, self.kv_size]
        query, key, value = fused.split(split_sizes, dim=-1)
        return self.attn(query,
                         key,
                         value,
                         kv_cache,
                         attn_metadata,
                         attn_type=AttentionType.ENCODER_ONLY)
class BertSelfOutput(nn.Module):
    """Dense projection plus residual LayerNorm after self-attention."""

    def __init__(self,
                 hidden_size: int,
                 layer_norm_eps: float,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()
        dense_prefix = f"{prefix}.dense"
        self.dense = RowParallelLinear(input_size=hidden_size,
                                       output_size=hidden_size,
                                       bias=True,
                                       quant_config=quant_config,
                                       prefix=dense_prefix)
        self.LayerNorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)

    def forward(self, hidden_states: torch.Tensor,
                input_tensor: torch.Tensor) -> torch.Tensor:
        """Return ``LayerNorm(dense(hidden_states) + input_tensor)``."""
        # The linear layer returns a pair; only the first item is used.
        projected, _ = self.dense(hidden_states)
        return self.LayerNorm(projected + input_tensor)
class BertIntermediate(nn.Module):
    """Dense projection to ``intermediate_size`` followed by an activation."""

    def __init__(self,
                 hidden_size: int,
                 intermediate_size: int,
                 hidden_act: str,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()
        dense_prefix = f"{prefix}.dense"
        self.dense = ColumnParallelLinear(input_size=hidden_size,
                                          output_size=intermediate_size,
                                          bias=True,
                                          quant_config=quant_config,
                                          prefix=dense_prefix)
        # Activation is looked up by name via get_act_fn.
        self.intermediate_act_fn = get_act_fn(hidden_act)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return ``act(dense(hidden_states))``."""
        # The linear layer returns a pair; only the first item is used.
        projected, _ = self.dense(hidden_states)
        return self.intermediate_act_fn(projected)
class BertOutput(nn.Module):
    """Dense projection back to ``hidden_size`` plus residual LayerNorm."""

    def __init__(self,
                 hidden_size: int,
                 intermediate_size: int,
                 layer_norm_eps: float,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()
        dense_prefix = f"{prefix}.dense"
        self.dense = RowParallelLinear(input_size=intermediate_size,
                                       output_size=hidden_size,
                                       bias=True,
                                       quant_config=quant_config,
                                       prefix=dense_prefix)
        self.LayerNorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)

    def forward(self, hidden_states: torch.Tensor,
                input_tensor: torch.Tensor) -> torch.Tensor:
        """Return ``LayerNorm(dense(hidden_states) + input_tensor)``."""
        # The linear layer returns a pair; only the first item is used.
        projected, _ = self.dense(hidden_states)
        return self.LayerNorm(projected + input_tensor)
class BertModel(nn.Module):
    """BERT backbone: embedding block followed by the encoder stack."""

    def __init__(self,
                 config: BertConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()
        self.embeddings = BertEmbedding(config)
        encoder_prefix = f"{prefix}.encoder"
        self.encoder = BertEncoder(config,
                                   cache_config,
                                   quant_config,
                                   prefix=encoder_prefix)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Encode the inputs and return the encoder's hidden states.

        When ``inputs_embeds`` is given it is used directly and the
        embedding block is skipped. ``intermediate_tensors`` is accepted
        but not used here.
        """
        if inputs_embeds is None:
            hidden_states = self.embeddings(input_ids=input_ids,
                                            position_ids=position_ids)
        else:
            hidden_states = inputs_embeds
        return self.encoder(hidden_states, kv_caches, attn_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load ``(name, tensor)`` pairs into this module's parameters.

        Names containing ``"pooler"`` are skipped. Names containing
        ``query``/``key``/``value`` are rewritten to ``qkv_proj`` and
        loaded as the matching shard; everything else is loaded by name.
        """
        # (param_name, shard_name, shard_id)
        stacked_params_mapping = [
            ("qkv_proj", "query", "q"),
            ("qkv_proj", "key", "k"),
            ("qkv_proj", "value", "v"),
        ]

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "pooler" in name:
                continue

            for fused_name, shard_name, shard_id in stacked_params_mapping:
                if shard_name not in name:
                    continue
                name = name.replace(shard_name, fused_name)
                # Skip loading extra bias for GPTQ models: on a miss, move
                # on to the next mapping entry with the rewritten name.
                if not name.endswith(".bias") or name in params_dict:
                    target = params_dict[name]
                    target.weight_loader(target, loaded_weight, shard_id)
                    break
            else:
                # No mapping entry completed a load; load by (possibly
                # rewritten) name. Skip extra bias for GPTQ models.
                if not name.endswith(".bias") or name in params_dict:
                    target = params_dict[name]
                    loader = getattr(target, "weight_loader",
                                     default_weight_loader)
                    loader(target, loaded_weight)
class BertEmbeddingModel(nn.Module):
    """Embedding model built on top of ``BertModel``.

    Wraps a ``BertModel`` and exposes a pooling step over its hidden
    states.

    Attributes:
        model: The wrapped ``BertModel`` used for the forward pass.
        _pooler: ``Pooler`` configured with ``PoolingType.CLS`` and
            ``normalize=True``.
    """

    def __init__(
        self,
        config: BertConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.model = BertModel(config, cache_config, quant_config)
        self._pooler = Pooler(pooling_type=PoolingType.CLS, normalize=True)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Delegate to the wrapped model; ``positions`` become position ids."""
        hidden_states = self.model(
            input_ids=input_ids,
            position_ids=positions,
            kv_caches=kv_caches,
            inputs_embeds=inputs_embeds,
            intermediate_tensors=intermediate_tensors,
            attn_metadata=attn_metadata,
        )
        return hidden_states

    def pooler(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> Optional[PoolerOutput]:
        """Apply the configured pooler to ``hidden_states``."""
        pooled = self._pooler(hidden_states, pooling_metadata)
        return pooled

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Forward the weight stream to the wrapped ``BertModel``."""
        self.model.load_weights(weights)
vllm/model_executor/models/eagle.py
View file @
2216a4e5
...
...
@@ -44,7 +44,7 @@ class EAGLE(nn.Module):
self
.
model
=
model_cls
(
self
.
config
.
model
,
*
args
,
**
kwargs
)
self
.
fc
=
nn
.
Linear
(
config
.
model
.
hidden_size
*
2
,
config
.
model
.
hidden_size
,
bias
=
getattr
(
self
.
config
,
"bias"
,
False
))
bias
=
getattr
(
self
.
config
,
"
eagle_fc_
bias"
,
False
))
self
.
orig_vocab_size
=
config
.
vocab_size
self
.
truncated_vocab_size
=
config
.
truncated_vocab_size
...
...
vllm/model_executor/models/gemma2.py
View file @
2216a4e5
...
...
@@ -241,13 +241,7 @@ class Gemma2DecoderLayer(nn.Module):
return
hidden_states
,
residual
@
support_torch_compile
(
dynamic_arg_dims
=
{
"input_ids"
:
0
,
"positions"
:
0
,
"inputs_embeds"
:
0
,
"intermediate_tensors"
:
0
,
})
@
support_torch_compile
class
Gemma2Model
(
nn
.
Module
):
def
__init__
(
...
...
vllm/model_executor/models/intern_vit.py
View file @
2216a4e5
...
...
@@ -97,6 +97,37 @@ class InternVisionEmbeddings(nn.Module):
return
embeddings
class InternVisionPatchModel(nn.Module):
    """Thin wrapper that exposes only the InternViT embedding stage."""

    def __init__(self, config: PretrainedConfig):
        super().__init__()
        self.config = config
        self.embeddings = InternVisionEmbeddings(config)

    def get_input_embeddings(self):
        """Return the embedding sub-module."""
        return self.embeddings

    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        pixel_embeds: Optional[torch.Tensor] = None,
    ) -> torch.FloatTensor:
        """Return ``pixel_embeds`` if given, else embed ``pixel_values``.

        Raises:
            ValueError: if both arguments are ``None``, or if
                ``pixel_values`` is used and is not 4-dimensional.
        """
        # Precomputed embeddings take precedence over raw pixels.
        if pixel_embeds is not None:
            return pixel_embeds

        if pixel_values is None:
            raise ValueError(
                'You have to specify pixel_values or pixel_embeds')

        if pixel_values.ndim != 4:
            raise ValueError(
                f'wrong pixel_values size: {pixel_values.shape}')

        return self.embeddings(pixel_values)
class
InternParallelAttention
(
nn
.
Module
):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
...
...
Prev
1
…
5
6
7
8
9
10
11
12
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment