norm / vllm · Commit 40542023

merge v0.3.2

Authored Feb 24, 2024 by zhuwenwen
Parents: 5e5b497d 8fbd84bf
Changes: 44

Showing 20 changed files with 948 additions and 41 deletions (+948 -41)
vllm/core/policy.py                             +1    -1
vllm/core/scheduler.py                          +4    -0
vllm/engine/arg_utils.py                        +14   -8
vllm/engine/llm_engine.py                       +11   -3
vllm/entrypoints/api_server.py                  +6    -0
vllm/entrypoints/openai/api_server.py           +23   -1
vllm/entrypoints/openai/protocol.py             +4    -0
vllm/entrypoints/openai/serving_chat.py         +9    -4
vllm/entrypoints/openai/serving_completion.py   +11   -4
vllm/entrypoints/openai/serving_engine.py       +40   -1
vllm/logger.py                                  +7    -3
vllm/model_executor/layers/sampler.py           +22   -7
vllm/model_executor/models/__init__.py          +2    -0
vllm/model_executor/models/gemma.py             +333  -0
vllm/model_executor/models/olmo.py              +378  -0
vllm/model_executor/sampling_metadata.py        +3    -0
vllm/outputs.py                                 +9    -1
vllm/sampling_params.py                         +8    -1
vllm/sequence.py                                +54   -4
vllm/transformers_utils/config.py               +9    -3
vllm/core/policy.py

@@ -33,7 +33,7 @@ class FCFS(Policy):
         now: float,
         seq_group: SequenceGroup,
     ) -> float:
-        return now - seq_group.arrival_time
+        return now - seq_group.metrics.arrival_time


 class PolicyFactory:
vllm/core/scheduler.py

@@ -365,10 +365,13 @@ class Scheduler:
         # This function call changes the internal states of the scheduler
         # such as self.running, self.swapped, and self.waiting.
         scheduler_outputs = self._schedule()
+        now = time.time()

         # Create input data structures.
         seq_group_metadata_list: List[SequenceGroupMetadata] = []
         for seq_group in scheduler_outputs.scheduled_seq_groups:
+            seq_group.maybe_set_first_scheduled_time(now)
+
             seq_data: Dict[int, SequenceData] = {}
             block_tables: Dict[int, List[int]] = {}
             for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
@@ -384,6 +387,7 @@ class Scheduler:
                 block_tables=block_tables,
                 lora_request=seq_group.lora_request,
                 prefix=seq_group.prefix,
+                state=seq_group.state,
             )
             seq_group_metadata_list.append(seq_group_metadata)
         return seq_group_metadata_list, scheduler_outputs
vllm/engine/arg_utils.py

@@ -32,6 +32,7 @@ class EngineArgs:
     max_paddings: int = 256
     disable_log_stats: bool = False
     revision: Optional[str] = None
+    code_revision: Optional[str] = None
     tokenizer_revision: Optional[str] = None
     quantization: Optional[str] = None
     enforce_eager: bool = False

@@ -75,6 +76,13 @@ class EngineArgs:
             help='the specific model version to use. It can be a branch '
             'name, a tag name, or a commit id. If unspecified, will use '
             'the default version.')
+        parser.add_argument(
+            '--code-revision',
+            type=str,
+            default=None,
+            help='the specific revision to use for the model code on '
+            'Hugging Face Hub. It can be a branch name, a tag name, or a '
+            'commit id. If unspecified, will use the default version.')
         parser.add_argument('--tokenizer-revision',
                             type=str,

@@ -165,7 +173,6 @@ class EngineArgs:
                             default=EngineArgs.block_size,
                             choices=[8, 16, 32],
                             help='token block size')
-        # TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
         parser.add_argument('--seed',
                             type=int,
                             default=EngineArgs.seed,

@@ -279,13 +286,12 @@ class EngineArgs:
     ) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig,
               DeviceConfig, Optional[LoRAConfig]]:
         device_config = DeviceConfig(self.device)
-        model_config = ModelConfig(self.model, self.tokenizer,
-                                   self.tokenizer_mode, self.trust_remote_code,
-                                   self.download_dir, self.load_format,
-                                   self.dtype, self.seed, self.revision,
-                                   self.tokenizer_revision, self.max_model_len,
-                                   self.quantization, self.enforce_eager,
-                                   self.max_context_len_to_capture)
+        model_config = ModelConfig(
+            self.model, self.tokenizer, self.tokenizer_mode,
+            self.trust_remote_code, self.download_dir, self.load_format,
+            self.dtype, self.seed, self.revision, self.code_revision,
+            self.tokenizer_revision, self.max_model_len, self.quantization,
+            self.enforce_eager, self.max_context_len_to_capture)
         cache_config = CacheConfig(self.block_size,
                                    self.gpu_memory_utilization,
                                    self.swap_space, self.kv_cache_dtype,
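The new `code_revision` argument pins the revision of remote model code pulled from the Hugging Face Hub, separately from the weights revision. A minimal sketch (not part of this commit), assuming the offline `LLM` entry point forwards extra keyword arguments into `EngineArgs`, and using a placeholder model id:

    # Hedged sketch: pinning remote model code with the new code_revision field.
    # The model id below is a placeholder, not a recommendation.
    from vllm import LLM

    llm = LLM(
        model="some-org/some-remote-code-model",  # hypothetical model id
        trust_remote_code=True,
        revision="main",        # revision of the weights
        code_revision="main",   # revision of the modeling code on the Hub
    )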
vllm/engine/llm_engine.py

@@ -464,6 +464,9 @@ class LLMEngine:
             prompt_token_ids[:prefix_pos], lora_request.lora_int_id
             if lora_request else 0) if prefix_pos is not None else None

+        # Defensive copy of SamplingParams, which are used by the sampler
+        sampling_params = copy.deepcopy(sampling_params)
+
         # Create the sequence group.
         seq_group = SequenceGroup(request_id, [seq], sampling_params,
                                   arrival_time, lora_request, prefix)

@@ -725,6 +728,7 @@ class LLMEngine:
     def _process_model_outputs(
             self, output: SamplerOutput,
             scheduler_outputs: SchedulerOutputs) -> List[RequestOutput]:
+        now = time.time()
         # Update the scheduled sequence groups with the model outputs.
         scheduled_seq_groups = scheduler_outputs.scheduled_seq_groups
         for seq_group, outputs in zip(scheduled_seq_groups, output):

@@ -736,6 +740,7 @@ class LLMEngine:
         # Create the outputs.
         request_outputs: List[RequestOutput] = []
         for seq_group in scheduled_seq_groups:
+            seq_group.maybe_set_first_token_time(now)
             request_output = RequestOutput.from_seq_group(seq_group)
             request_outputs.append(request_output)
         for seq_group in scheduler_outputs.ignored_seq_groups:

@@ -864,18 +869,21 @@ class LLMEngine:
             # Number of Tokens.
             if prompt_run:
-                num_prompt_tokens = scheduler_outputs.num_batched_tokens
+                num_prompt_tokens = sum(
+                    len(seq_group.prompt_token_ids)
+                    for seq_group in scheduler_outputs.scheduled_seq_groups)
             else:
                 num_generation_tokens = scheduler_outputs.num_batched_tokens

             # Latency Timings.
             time_last_iters = []
             for seq_group in scheduler_outputs.scheduled_seq_groups:
-                # Time since last token. (n.b. updates seq_group.last_token_time)
+                # Time since last token. (n.b. updates seq_group.metrics.last_token_time)
                 time_last_iters.append(seq_group.get_last_latency(now))
                 # Time since arrival for all finished requests.
                 if seq_group.is_finished():
-                    time_e2e_requests.append(now - seq_group.arrival_time)
+                    time_e2e_requests.append(now -
+                                             seq_group.metrics.arrival_time)

             time_to_first_tokens = time_last_iters if prompt_run else []
             time_per_output_tokens = [] if prompt_run else time_last_iters
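The stats changes above record per-request timestamps (arrival, first scheduled, first token, last token) and derive time-to-first-token from the prompt iteration and inter-token latency from decode iterations. A minimal standalone sketch of that bookkeeping (not vLLM code; all names are illustrative):

    # Illustrative only: TTFT is the latency of the prompt iteration, every
    # later iteration contributes to time-per-output-token.
    import time

    class RequestTimes:
        def __init__(self):
            self.arrival_time = time.time()
            self.last_token_time = self.arrival_time
            self.first_token_time = None

        def observe_iteration(self, now: float, prompt_run: bool):
            latency = now - self.last_token_time
            self.last_token_time = now
            if prompt_run and self.first_token_time is None:
                self.first_token_time = now
                return ("time_to_first_token", latency)
            return ("time_per_output_token", latency)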
vllm/entrypoints/api_server.py

+"""
+NOTE: This API server is used only for demonstrating usage of AsyncEngine and simple performance benchmarks.
+It is not intended for production use. For production use, we recommend using our OpenAI compatible server.
+We are also not going to accept PRs modifying this file, please change `vllm/entrypoints/openai/api_server.py` instead.
+"""
+
 import argparse
 import json
 from typing import AsyncGenerator
vllm/entrypoints/openai/api_server.py

@@ -23,6 +23,7 @@ from vllm.entrypoints.openai.protocol import CompletionRequest, ChatCompletionRe
 from vllm.logger import init_logger
 from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
 from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
+from vllm.entrypoints.openai.serving_engine import LoRA

 TIMEOUT_KEEP_ALIVE = 5  # seconds

@@ -48,6 +49,16 @@ async def lifespan(app: fastapi.FastAPI):
 app = fastapi.FastAPI(lifespan=lifespan)


+class LoRAParserAction(argparse.Action):
+
+    def __call__(self, parser, namespace, values, option_string=None):
+        lora_list = []
+        for item in values:
+            name, path = item.split('=')
+            lora_list.append(LoRA(name, path))
+        setattr(namespace, self.dest, lora_list)
+
+
 def parse_args():
     parser = argparse.ArgumentParser(
         description="vLLM OpenAI-Compatible RESTful API server.")

@@ -81,6 +92,15 @@ def parse_args():
                         help="The model name used in the API. If not "
                         "specified, the model name will be the same as "
                         "the huggingface name.")
+    parser.add_argument(
+        "--lora-modules",
+        type=str,
+        default=None,
+        nargs='+',
+        action=LoRAParserAction,
+        help=
+        "LoRA module configurations in the format name=path. Multiple modules can be specified."
+    )
     parser.add_argument("--chat-template",
                         type=str,
                         default=None,

@@ -217,8 +237,10 @@ if __name__ == "__main__":
     engine = AsyncLLMEngine.from_engine_args(engine_args)
     openai_serving_chat = OpenAIServingChat(engine, served_model,
                                             args.response_role,
+                                            args.lora_modules,
                                             args.chat_template)
-    openai_serving_completion = OpenAIServingCompletion(engine, served_model)
+    openai_serving_completion = OpenAIServingCompletion(
+        engine, served_model, args.lora_modules)

     # Register labels for metrics
     add_global_metrics_labels(model_name=engine_args.model)
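Once the server is started with `--lora-modules name=path`, each adapter name is exposed as an additional model id alongside the base model. A hedged client-side sketch (not part of this commit), assuming a server on localhost:8000 started with a hypothetical adapter `sql-lora=/path/to/adapter`:

    # Address the LoRA adapter by its name in the "model" field.
    import requests

    resp = requests.post(
        "http://localhost:8000/v1/completions",
        json={
            "model": "sql-lora",          # adapter name, not the base model
            "prompt": "SELECT count(*) FROM",
            "max_tokens": 16,
        },
    )
    print(resp.json())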
vllm/entrypoints/openai/protocol.py

@@ -60,6 +60,7 @@ class ChatCompletionRequest(BaseModel):
     top_p: Optional[float] = 1.0
     n: Optional[int] = 1
     max_tokens: Optional[int] = None
+    seed: Optional[int] = None
     stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
     stream: Optional[bool] = False
     presence_penalty: Optional[float] = 0.0

@@ -90,6 +91,7 @@ class ChatCompletionRequest(BaseModel):
             temperature=self.temperature,
             top_p=self.top_p,
             min_p=self.min_p,
+            seed=self.seed,
             stop=self.stop,
             stop_token_ids=self.stop_token_ids,
             max_tokens=self.max_tokens,

@@ -117,6 +119,7 @@ class CompletionRequest(BaseModel):
     logprobs: Optional[int] = None
     echo: Optional[bool] = False
     stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
+    seed: Optional[int] = None
     presence_penalty: Optional[float] = 0.0
     frequency_penalty: Optional[float] = 0.0
     best_of: Optional[int] = None

@@ -147,6 +150,7 @@ class CompletionRequest(BaseModel):
             top_p=self.top_p,
             top_k=self.top_k,
             min_p=self.min_p,
+            seed=self.seed,
             stop=self.stop,
             stop_token_ids=self.stop_token_ids,
             ignore_eos=self.ignore_eos,
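Both request schemas gain an optional `seed`, which is forwarded into `SamplingParams` for reproducible sampling. A hedged sketch (not part of this commit), assuming a vLLM OpenAI-compatible server on localhost:8000 serving a placeholder model named "my-model":

    # With the same seed, two otherwise-random completions should match.
    import requests

    payload = {
        "model": "my-model",
        "prompt": "Once upon a time",
        "max_tokens": 32,
        "temperature": 1.0,
        "seed": 1234,  # per-request seed added in this merge
    }
    a = requests.post("http://localhost:8000/v1/completions", json=payload).json()
    b = requests.post("http://localhost:8000/v1/completions", json=payload).json()
    print(a["choices"][0]["text"] == b["choices"][0]["text"])  # expected: True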
vllm/entrypoints/openai/serving_chat.py

 import time
 import codecs
 from fastapi import Request
-from typing import AsyncGenerator, AsyncIterator, Union
+from typing import AsyncGenerator, AsyncIterator, Optional, List, Union

 from vllm.logger import init_logger
 from vllm.utils import random_uuid
 from vllm.engine.async_llm_engine import AsyncLLMEngine

@@ -11,7 +11,7 @@ from vllm.entrypoints.openai.protocol import (
     ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse,
     UsageInfo)
 from vllm.outputs import RequestOutput
-from vllm.entrypoints.openai.serving_engine import OpenAIServing
+from vllm.entrypoints.openai.serving_engine import OpenAIServing, LoRA

 logger = init_logger(__name__)

@@ -22,8 +22,11 @@ class OpenAIServingChat(OpenAIServing):
                  engine: AsyncLLMEngine,
                  served_model: str,
                  response_role: str,
+                 lora_modules: Optional[List[LoRA]] = None,
                  chat_template=None):
-        super().__init__(engine=engine, served_model=served_model)
+        super().__init__(engine=engine,
+                         served_model=served_model,
+                         lora_modules=lora_modules)
         self.response_role = response_role
         self._load_chat_template(chat_template)

@@ -64,11 +67,13 @@ class OpenAIServingChat(OpenAIServing):
             token_ids = self._validate_prompt_and_tokenize(request,
                                                            prompt=prompt)
             sampling_params = request.to_sampling_params()
+            lora_request = self._maybe_get_lora(request)
         except ValueError as e:
             return self.create_error_response(str(e))

         result_generator = self.engine.generate(prompt, sampling_params,
-                                                request_id, token_ids)
+                                                request_id, token_ids,
+                                                lora_request)
         # Streaming response
         if request.stream:
             return self.chat_completion_stream_generator(
vllm/entrypoints/openai/serving_completion.py

@@ -15,7 +15,7 @@ from .protocol import (
     UsageInfo,
 )
 from vllm.outputs import RequestOutput
-from vllm.entrypoints.openai.serving_engine import OpenAIServing
+from vllm.entrypoints.openai.serving_engine import OpenAIServing, LoRA

 logger = init_logger(__name__)

@@ -249,8 +249,13 @@ def merge_async_iterators(*iterators):
 class OpenAIServingCompletion(OpenAIServing):

-    def __init__(self, engine: AsyncLLMEngine, served_model: str):
-        super().__init__(engine=engine, served_model=served_model)
+    def __init__(self,
+                 engine: AsyncLLMEngine,
+                 served_model: str,
+                 lora_modules: Optional[List[LoRA]] = None):
+        super().__init__(engine=engine,
+                         served_model=served_model,
+                         lora_modules=lora_modules)

     async def create_completion(self, request: CompletionRequest,
                                 raw_request: Request):

@@ -284,6 +289,7 @@ class OpenAIServingCompletion(OpenAIServing):
         generators = []
         try:
             sampling_params = request.to_sampling_params()
+            lora_request = self._maybe_get_lora(request)
             prompt_is_tokens, prompts = parse_prompt_format(request.prompt)

             for i, prompt in enumerate(prompts):

@@ -298,7 +304,8 @@ class OpenAIServingCompletion(OpenAIServing):
                     self.engine.generate(None,
                                          sampling_params,
                                          f"{request_id}-{i}",
-                                         prompt_token_ids=input_ids))
+                                         prompt_token_ids=input_ids,
+                                         lora_request=lora_request))
         except ValueError as e:
             return self.create_error_response(str(e))
vllm/entrypoints/openai/serving_engine.py

 import asyncio
+from dataclasses import dataclass
 from http import HTTPStatus
 from typing import Dict, List, Optional, Union

 from vllm.logger import init_logger

@@ -9,15 +10,35 @@ from vllm.entrypoints.openai.protocol import (CompletionRequest,
                                               ErrorResponse, LogProbs,
                                               ModelCard, ModelList,
                                               ModelPermission)
+from vllm.lora.request import LoRARequest

 logger = init_logger(__name__)


+@dataclass
+class LoRA:
+    name: str
+    local_path: str
+
+
 class OpenAIServing:

-    def __init__(self, engine: AsyncLLMEngine, served_model: str):
+    def __init__(self,
+                 engine: AsyncLLMEngine,
+                 served_model: str,
+                 lora_modules=Optional[List[LoRA]]):
         self.engine = engine
         self.served_model = served_model
+        if lora_modules is None:
+            self.lora_requests = []
+        else:
+            self.lora_requests = [
+                LoRARequest(
+                    lora_name=lora.name,
+                    lora_int_id=i,
+                    lora_local_path=lora.local_path,
+                ) for i, lora in enumerate(lora_modules, start=1)
+            ]

         self.max_model_len = 0
         self.tokenizer = None

@@ -50,6 +71,13 @@ class OpenAIServing:
                       root=self.served_model,
                       permission=[ModelPermission()])
         ]
+        lora_cards = [
+            ModelCard(id=lora.lora_name,
+                      root=self.served_model,
+                      permission=[ModelPermission()])
+            for lora in self.lora_requests
+        ]
+        model_cards.extend(lora_cards)
         return ModelList(data=model_cards)

     def _create_logprobs(

@@ -99,11 +127,22 @@ class OpenAIServing:
     async def _check_model(self, request) -> Optional[ErrorResponse]:
         if request.model == self.served_model:
             return
+        if request.model in [lora.lora_name for lora in self.lora_requests]:
+            return
         return self.create_error_response(
             message=f"The model `{request.model}` does not exist.",
             err_type="NotFoundError",
             status_code=HTTPStatus.NOT_FOUND)

+    def _maybe_get_lora(self, request) -> Optional[LoRARequest]:
+        if request.model == self.served_model:
+            return
+        for lora in self.lora_requests:
+            if request.model == lora.lora_name:
+                return lora
+        # if _check_model has been called earlier, this will be unreachable
+        raise ValueError("The model `{request.model}` does not exist.")
+
     def _validate_prompt_and_tokenize(
             self,
             request: Union[ChatCompletionRequest, CompletionRequest],
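Each `--lora-modules` entry becomes a `LoRARequest` with a 1-based integer id, and `_check_model` / `_maybe_get_lora` route a request either to the base model (no LoRA) or to a named adapter. A standalone sketch of that lookup, with illustrative stand-in names rather than vLLM internals:

    from dataclasses import dataclass
    from typing import List, Optional, Tuple

    @dataclass
    class AdapterEntry:  # stands in for vllm.lora.request.LoRARequest
        lora_name: str
        lora_int_id: int
        lora_local_path: str

    def build_adapters(modules: List[Tuple[str, str]]) -> List[AdapterEntry]:
        # 1-based ids, mirroring enumerate(lora_modules, start=1) above
        return [AdapterEntry(name, i, path)
                for i, (name, path) in enumerate(modules, start=1)]

    def maybe_get_adapter(model: str, served_model: str,
                          adapters: List[AdapterEntry]) -> Optional[AdapterEntry]:
        if model == served_model:
            return None                      # base model: no adapter
        for adapter in adapters:
            if model == adapter.lora_name:
                return adapter
        raise ValueError(f"The model `{model}` does not exist.")

    adapters = build_adapters([("sql-lora", "/tmp/sql-lora")])
    print(maybe_get_adapter("sql-lora", "base-model", adapters))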
vllm/logger.py

@@ -5,6 +5,8 @@ import logging
 import sys
 import os

+VLLM_CONFIGURE_LOGGING = int(os.getenv("VLLM_CONFIGURE_LOGGING", "1"))
+
 _FORMAT = "%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s"
 _DATE_FORMAT = "%m-%d %H:%M:%S"

@@ -45,13 +47,15 @@ def _setup_logger():
 # The logger is initialized when the module is imported.
 # This is thread-safe as the module is only imported once,
 # guaranteed by the Python GIL.
-_setup_logger()
+if VLLM_CONFIGURE_LOGGING:
+    _setup_logger()


 def init_logger(name: str):
     # Use the same settings as above for root logger
     logger = logging.getLogger(name)
     logger.setLevel(os.getenv("LOG_LEVEL", "DEBUG"))
-    logger.addHandler(_default_handler)
-    logger.propagate = False
+    if VLLM_CONFIGURE_LOGGING:
+        logger.addHandler(_default_handler)
+        logger.propagate = False
     return logger
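Setting `VLLM_CONFIGURE_LOGGING=0` keeps vLLM from attaching its own handler and disabling propagation, so the host application's logging configuration applies. A short sketch (not part of this commit); the environment variable must be set before `vllm.logger` is imported:

    import os
    os.environ["VLLM_CONFIGURE_LOGGING"] = "0"

    import logging
    logging.basicConfig(level=logging.INFO,
                        format="%(asctime)s %(name)s %(levelname)s %(message)s")

    from vllm.logger import init_logger  # picks up VLLM_CONFIGURE_LOGGING=0
    logger = init_logger("my_app")
    logger.info("vLLM loggers now propagate to the root handler above.")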
vllm/model_executor/layers/sampler.py

@@ -342,7 +342,9 @@ def _beam_search_sample(
 def _multinomial(
     probs: torch.Tensor,
     num_samples: int,
-):
+    seq_groups: Optional[List[Tuple[List[int], SamplingParams]]] = None,
+    generators: Optional[List[torch.Generator]] = None,
+) -> torch.Tensor:
     if num_samples > 1:
         # This is equivalent to torch.repeat_interleaved (which also
         # forces a GPU<->CPU sync).

@@ -352,7 +354,15 @@ def _multinomial(
         probs = probs[:, None, :].expand(probs.shape[0], num_samples,
                                          probs.shape[1]).contiguous().view(
                                              -1, probs.shape[1])
-    q = torch.empty_like(probs).exponential_(1)
+    q = torch.empty_like(probs)
+    if seq_groups is None:
+        q.exponential_()
+    else:
+        sample_idx = 0
+        for (seq_ids, _), generator in zip(seq_groups, generators):
+            next_sample_idx = sample_idx + len(seq_ids) * num_samples
+            q[sample_idx:next_sample_idx].exponential_(generator=generator)
+            sample_idx = next_sample_idx
     return probs.div_(q).argmax(dim=1).view(-1, num_samples)

@@ -370,6 +380,7 @@ def _sample(
     sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {}
     sample_metadata = {}
+    multinomial_samples = {}

     # Counterintiutively, having two loops here is actually faster.
     # The first loop can run without waiting on GPU<->CPU sync.

@@ -385,14 +396,18 @@ def _sample(
                                           is_prompts, sample_indices)
         if sampling_type == SamplingType.GREEDY:
             greedy_samples = torch.argmax(logprobs[sample_indices], dim=-1)
-        elif sampling_type == SamplingType.RANDOM:
+        elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
             max_best_of = 1
             for seq_group, is_prompt in zip(seq_groups, is_prompts):
                 if is_prompt:
                     _, sampling_params = seq_group
                     max_best_of = max(max_best_of, sampling_params.best_of)
-            multinomial_samples = _multinomial(probs[sample_indices],
-                                               max_best_of)
+            seeded_args = {} if sampling_type == SamplingType.RANDOM else {
+                "seq_groups": seq_groups,
+                "generators": sampling_metadata.generators,
+            }
+            multinomial_samples[sampling_type] = _multinomial(
+                probs[sample_indices], max_best_of, **seeded_args)
         elif sampling_type == SamplingType.BEAM:
             beam_search_logprobs = logprobs[sample_indices]
         else:

@@ -407,9 +422,9 @@ def _sample(
             sampling_type]
         if sampling_type == SamplingType.GREEDY:
             sample_results = _greedy_sample(seq_groups, greedy_samples)
-        elif sampling_type == SamplingType.RANDOM:
+        elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
             sample_results = _random_sample(seq_groups, is_prompts,
-                                            multinomial_samples)
+                                            multinomial_samples[sampling_type])
         elif sampling_type == SamplingType.BEAM:
             sample_results = _beam_search_sample(seq_groups, is_prompts,
                                                  sampling_metadata.seq_data,
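The seeded path fills the exponential noise per sequence group from that request's own `torch.Generator`, so draws are reproducible per request. A standalone sketch (illustrative, not vLLM internals) of the underlying trick: dividing probabilities by Exponential(1) noise and taking the argmax samples from the categorical distribution (the Gumbel-max family of tricks).

    import torch

    probs = torch.tensor([[0.1, 0.2, 0.7],
                          [0.5, 0.3, 0.2]])

    def seeded_sample(probs: torch.Tensor, seeds):
        q = torch.empty_like(probs)
        for row, seed in enumerate(seeds):
            gen = torch.Generator().manual_seed(seed)
            q[row].exponential_(generator=gen)   # per-request noise
        return probs.div(q).argmax(dim=-1)

    print(seeded_sample(probs, seeds=[1234, 5678]))
    print(seeded_sample(probs, seeds=[1234, 5678]))  # identical to the first call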
vllm/model_executor/models/__init__.py

@@ -20,6 +20,7 @@ _MODELS = {
     "DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"),
     "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
     "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
+    "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
     "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
     "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
     "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),

@@ -35,6 +36,7 @@ _MODELS = {
     # transformers's mpt class has lower case
     "MptForCausalLM": ("mpt", "MPTForCausalLM"),
     "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
+    "OLMoForCausalLM": ("olmo", "OLMoForCausalLM"),
     "OPTForCausalLM": ("opt", "OPTForCausalLM"),
     "PhiForCausalLM": ("phi", "PhiForCausalLM"),
     "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
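Each registry entry maps an architecture name from the Hugging Face config to a (module, class) pair under `vllm.model_executor.models`, so registering Gemma and OLMo only requires the two new lines above plus the model files. A hedged sketch of how such an entry resolves; this mirrors the registry's lazy-import pattern, but the exact vLLM lookup code may differ:

    import importlib

    def resolve(architecture: str, registry: dict):
        # e.g. registry["GemmaForCausalLM"] == ("gemma", "GemmaForCausalLM")
        module_name, class_name = registry[architecture]
        module = importlib.import_module(
            f"vllm.model_executor.models.{module_name}")
        return getattr(module, class_name)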
vllm/model_executor/models/gemma.py (new file, mode 100644)

# coding=utf-8
# Copyright 2023 The vLLM team.
# Copyright (c) Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Gemma model compatible with HuggingFace weights."""
from typing import List, Optional, Tuple

import torch
from torch import nn
from transformers import GemmaConfig

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               LinearMethodBase,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
from vllm.model_executor.parallel_utils.parallel_state import (
    get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
                                              hf_model_weights_iterator)
from vllm.sequence import SamplerOutput

KVCache = Tuple[torch.Tensor, torch.Tensor]


class GemmaRMSNorm(nn.Module):

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.zeros(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float()).type_as(x)
        return output * (1 + self.weight)


class GemmaMLP(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.gate_proj = ColumnParallelLinear(hidden_size,
                                              intermediate_size,
                                              bias=False,
                                              linear_method=linear_method)
        self.up_proj = ColumnParallelLinear(hidden_size,
                                            intermediate_size,
                                            bias=False,
                                            linear_method=linear_method)
        self.down_proj = RowParallelLinear(intermediate_size,
                                           hidden_size,
                                           bias=False,
                                           linear_method=linear_method)
        self.act_fn = nn.GELU()

    def forward(self, x):
        gate, _ = self.gate_proj(x)
        gate = self.act_fn(gate)
        up, _ = self.up_proj(x)
        fuse = gate * up
        outputs, _ = self.down_proj(fuse)
        return outputs


class GemmaAttention(nn.Module):

    def __init__(self,
                 hidden_size: int,
                 num_heads: int,
                 num_kv_heads: int,
                 head_dim: int,
                 max_position_embeddings: int = 8192,
                 rope_theta: float = 10000,
                 linear_method: Optional[LinearMethodBase] = None) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = head_dim
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            linear_method=linear_method,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            linear_method=linear_method,
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=self.rope_theta,
            is_neox_style=True,
        )
        self.attn = PagedAttention(self.num_heads,
                                   self.head_dim,
                                   self.scaling,
                                   num_kv_heads=self.num_kv_heads)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        k_cache, v_cache = kv_cache
        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
        output, _ = self.o_proj(attn_output)
        return output


class GemmaDecoderLayer(nn.Module):

    def __init__(
        self,
        config: GemmaConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = GemmaAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            head_dim=config.head_dim,
            max_position_embeddings=config.max_position_embeddings,
            rope_theta=config.rope_theta,
            linear_method=linear_method,
        )
        self.mlp = GemmaMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            linear_method=linear_method,
        )
        self.input_layernorm = GemmaRMSNorm(config.hidden_size,
                                            eps=config.rms_norm_eps)
        self.post_attention_layernorm = GemmaRMSNorm(config.hidden_size,
                                                     eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Self Attention
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states


class GemmaModel(nn.Module):

    def __init__(
        self,
        config: GemmaConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.config = config

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        self.layers = nn.ModuleList([
            GemmaDecoderLayer(config, linear_method)
            for _ in range(config.num_hidden_layers)
        ])
        self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        hidden_states = self.embed_tokens(input_ids)
        # Normalize the embedding by sqrt(hidden_size)
        hidden_states = hidden_states * (self.config.hidden_size**0.5)

        for i in range(len(self.layers)):
            layer = self.layers[i]
            hidden_states = layer(
                positions,
                hidden_states,
                kv_caches[i],
                input_metadata,
            )
        hidden_states = self.norm(hidden_states)
        return hidden_states


class GemmaForCausalLM(nn.Module):

    def __init__(
        self,
        config: GemmaConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        self.model = GemmaModel(config, linear_method)
        self.sampler = Sampler(config.vocab_size)

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   input_metadata)
        return hidden_states

    def sample(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(self.model.embed_tokens.weight,
                                   hidden_states, sampling_metadata)
        return next_tokens

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto",
                     revision: Optional[str] = None):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params = set()
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, load_format, revision):
            for (param_name, shard_name, shard_id) in stacked_params_mapping:
                if shard_name not in name:
                    continue
                name = name.replace(shard_name, param_name)
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra layer for lora models.
                if "lm_head" in name:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        unloaded_params = params_dict.keys() - loaded_params
        if unloaded_params:
            raise RuntimeError(
                f"Some weights are not initialized from checkpoints: "
                f"{unloaded_params}")
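Two Gemma-specific details in the file above are worth calling out: `GemmaRMSNorm` stores its scale as an offset from 1 (so zero-initialized weights act as an identity gain), and `GemmaModel` multiplies the embedding output by sqrt(hidden_size) before the decoder layers. A tiny standalone check (illustrative only):

    import torch

    hidden_size = 8
    x = torch.randn(2, hidden_size)

    weight = torch.zeros(hidden_size)              # as in GemmaRMSNorm.__init__
    rms = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6)
    out = rms * (1 + weight)                       # zero weight -> plain RMS norm
    assert torch.allclose(out, rms)

    embedding = torch.randn(2, hidden_size)
    scaled = embedding * (hidden_size**0.5)        # normalizer used in GemmaModel
    print(scaled.norm() / embedding.norm())        # ~ sqrt(hidden_size)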
vllm/model_executor/models/olmo.py (new file, mode 100644)

# coding=utf-8
# Adapted from
# https://github.com/allenai/OLMo/blob/v0.2.4/olmo/model.py and
# https://github.com/allenai/OLMo/blob/v0.2.4/hf_olmo/modeling_olmo.py
# Copyright 2023 The vLLM team.
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
#
# BSD 3-Clause License
#
# Copyright (c) 2022, Tri Dao, trid@cs.stanford.edu.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
#   list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimer in the documentation
#   and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
#   contributors may be used to endorse or promote products derived from
#   this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Inference-only OLMo model compatible with HuggingFace weights."""
from typing import List, Optional, Tuple

import torch
import torch.nn.functional as F
from torch import nn

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    LinearMethodBase,
    QKVParallelLinear,
    RowParallelLinear,
)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.model_executor.parallel_utils.parallel_state import (
    get_tensor_model_parallel_world_size, )
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (
    default_weight_loader,
    hf_model_weights_iterator,
)
from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs.olmo import OLMoConfig

KVCache = Tuple[torch.Tensor, torch.Tensor]


class SwiGLU(nn.Module):

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x, gate = x.chunk(2, dim=-1)
        return F.silu(gate) * x

    @property
    def output_multiplier(self) -> float:
        return 0.5


class OlmoAttention(nn.Module):
    """
    This is the attention block where the output is computed as ``Attention(LN(x))`` in ``MLP(LN(x + Attention(LN(x))))``
    (plus another skip connection).
    """

    def __init__(
        self,
        config: OLMoConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.config = config
        self.hidden_size = config.d_model
        assert config.d_model % config.n_heads == 0
        tensor_model_parallel_world_size = get_tensor_model_parallel_world_size(
        )
        self.total_num_heads = self.config.n_heads
        assert self.total_num_heads % tensor_model_parallel_world_size == 0
        self.num_heads = self.total_num_heads // tensor_model_parallel_world_size
        self.head_dim = self.hidden_size // self.total_num_heads

        # Layer norms.
        self.attn_norm = nn.LayerNorm(config.d_model,
                                      elementwise_affine=False,
                                      bias=False)
        # Attention input projection. Projects x -> (q, k, v)
        self.att_proj = QKVParallelLinear(
            config.d_model,
            self.head_dim,
            self.total_num_heads,
            bias=config.include_bias,
            linear_method=linear_method,
        )

        # Rotary embeddings.
        if self.config.rope:
            rope_theta = getattr(config, "rope_theta", 10000)
            max_position_embeddings = getattr(config,
                                              "max_position_embeddings", 8192)
            self.rotary_emb = get_rope(
                self.head_dim,
                rotary_dim=self.head_dim,
                max_position=max_position_embeddings,
                base=rope_theta,
            )
        self.scaling = self.head_dim**-0.5
        self.attn = PagedAttention(self.num_heads,
                                   self.head_dim,
                                   scale=self.scaling)

        # Attention output projection.
        self.attn_out = RowParallelLinear(
            config.d_model,
            config.d_model,
            bias=config.include_bias,
            linear_method=linear_method,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        hidden_states = self.attn_norm(hidden_states)
        qkv, _ = self.att_proj(hidden_states)
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        if self.config.rope:
            q, k = self.rotary_emb(positions, q, k)
        k_cache, v_cache = kv_cache
        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
        output, _ = self.attn_out(attn_output)
        return output


class OlmoMLP(nn.Module):
    """
    This is the MLP block where the output is computed as ``MLP(LN(x))`` in ``MLP(LN(x + Attention(LN(x))))``
    (plus another skip connection).
    """

    def __init__(
        self,
        config: OLMoConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.config = config
        self.hidden_size = (config.mlp_hidden_size if config.mlp_hidden_size
                            is not None else config.mlp_ratio * config.d_model)

        # Layer norms.
        self.ff_norm = nn.LayerNorm(config.d_model,
                                    elementwise_affine=False,
                                    bias=False)

        # Feed-forward input projection.
        self.ff_proj = ColumnParallelLinear(
            config.d_model,
            self.hidden_size,
            bias=config.include_bias,
            linear_method=linear_method,
        )

        # Activation function.
        # self.act = SiluAndMul()
        # self.act.output_multiplier = 0.5
        self.act = SwiGLU()
        assert (self.act.output_multiplier * self.hidden_size) % 1 == 0

        # Feed-forward output projection.
        self.ff_out = RowParallelLinear(
            int(self.act.output_multiplier * self.hidden_size),
            config.d_model,
            bias=config.include_bias,
            linear_method=linear_method,
        )

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        # Add feed-forward projection.
        # shape: (batch_size, seq_len, d_model)
        og_x = x
        x = self.ff_norm(x)
        x, _ = self.ff_proj(x)
        x = self.act(x)
        x, _ = self.ff_out(x)
        x = og_x + x
        return x


class OlmoBlock(nn.Module):
    """
    This is a typical transformer block where the output is computed as ``MLP(LN(x + Attention(LN(x))))``
    (plus another skip connection).
    """

    def __init__(self,
                 config: OLMoConfig,
                 linear_method: Optional[LinearMethodBase] = None):
        super().__init__()
        # Attention block.
        self.attn = OlmoAttention(config, linear_method)

        # MLP block.
        self.mlp = OlmoMLP(config, linear_method)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        # Attention block.
        og_x = hidden_states
        x = self.attn(positions, hidden_states, kv_cache, input_metadata)
        x = x + og_x

        # MLP block.
        hidden_states = self.mlp(x)
        return hidden_states


class OlmoModel(nn.Module):

    def __init__(self,
                 config: OLMoConfig,
                 linear_method: Optional[LinearMethodBase] = None):
        super().__init__()
        self.config = config

        self.transformer = nn.ModuleDict(
            dict(
                wte=VocabParallelEmbedding(
                    config.embedding_size or config.vocab_size,
                    config.d_model,
                ),
                ln_f=nn.LayerNorm(config.d_model,
                                  elementwise_affine=False,
                                  bias=False),
            ))

        blocks = [
            OlmoBlock(config, linear_method) for i in range(config.n_layers)
        ]
        if self.config.block_group_size > 1:
            raise NotImplementedError("Block group size > 1 not supported yet")
        else:
            self.transformer.update({"blocks": nn.ModuleList(blocks)})

        if not config.weight_tying:
            self.transformer.update({
                "ff_out":
                ColumnParallelLinear(
                    config.d_model,
                    config.embedding_size or config.vocab_size,
                    bias=config.include_bias,
                    linear_method=linear_method,
                )
            })

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """
        :param input_ids: A tensor of shape `(batch_size, seq_len)`.
        """
        # Get embeddings of input.
        # shape: (batch_size, seq_len, d_model)
        x = self.transformer.wte(input_ids)  # type: ignore

        # Apply blocks one-by-one.
        for block_idx, block in enumerate(self.transformer.blocks):
            # shape: (batch_size, seq_len, d_model)
            x = block(
                positions,
                x,
                kv_caches[block_idx],
                input_metadata,
            )

        # Apply final layer norm.
        # shape: (batch_size, seq_len or 1, d_model)
        x = self.transformer.ln_f(x)  # type: ignore
        return x


class OLMoForCausalLM(nn.Module):
    """
    Extremely barebones HF model wrapper.
    """

    def __init__(self,
                 config: OLMoConfig,
                 linear_method: Optional[LinearMethodBase] = None):
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        self.model = OlmoModel(config, linear_method)
        self.lm_head_weight = (self.model.transformer.wte.weight
                               if config.weight_tying else
                               self.model.transformer.ff_out.weight)
        self.sampler = Sampler(config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        hidden_states = self.model(
            input_ids=input_ids,
            positions=positions,
            kv_caches=kv_caches,
            input_metadata=input_metadata,
        )
        return hidden_states

    def sample(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(self.lm_head_weight, hidden_states,
                                   sampling_metadata)
        return next_tokens

    def load_weights(
        self,
        model_name_or_path: str,
        cache_dir: Optional[str] = None,
        load_format: str = "auto",
        revision: Optional[str] = None,
    ):
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, load_format, revision):
            # attention
            if ".att" in name:
                name = name.replace(".att", ".attn.att")
            # mlp
            if ".ff" in name and "transformer.ff_out" not in name:
                name = name.replace(".ff", ".mlp.ff")
            # there is no bias in olmo
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
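The custom `SwiGLU` above splits the projected tensor into `(x, gate)` halves and returns `silu(gate) * x`, which appears to be why the commented-out `SiluAndMul` (which, as I understand it, applies SiLU to the first half) is not used. A small standalone sketch (illustrative) contrasting the two orderings:

    import torch
    import torch.nn.functional as F

    y = torch.randn(1, 6)              # pretend output of ff_proj, 2 * hidden
    x_half, gate = y.chunk(2, dim=-1)

    olmo_swiglu = F.silu(gate) * x_half   # ordering used by SwiGLU in this file
    silu_first = F.silu(x_half) * gate    # the other common ordering

    print(torch.allclose(olmo_swiglu, silu_first))  # generally False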
vllm/model_executor/sampling_metadata.py

@@ -19,6 +19,7 @@ class SamplingMetadata:
         prompt_lens: Lengths of prompts.
         selected_token_indices: Token indices selected for sampling.
         categorized_sample_indices: SamplingType -> token indices to sample.
+        generators: List of torch.Generators to use for seeded sampling
         perform_sampling: Whether to perform sampling. This option is used to
             make the sampling only happens in the driver worker, and disable
             sampling in other worker processes.

@@ -31,6 +32,7 @@ class SamplingMetadata:
         prompt_lens: Optional[List[int]],
         selected_token_indices: torch.Tensor,
         categorized_sample_indices: Optional[Dict[SamplingType, torch.Tensor]],
+        generators: Optional[List[torch.Generator]] = None,
         perform_sampling: bool = True,
     ) -> None:
         self.seq_groups = seq_groups

@@ -38,6 +40,7 @@ class SamplingMetadata:
         self.prompt_lens = prompt_lens
         self.selected_token_indices = selected_token_indices
         self.categorized_sample_indices = categorized_sample_indices
+        self.generators = generators
         self.perform_sampling = perform_sampling

         self.num_prompts = len(prompt_lens) if prompt_lens is not None else 0
vllm/outputs.py
View file @ 40542023
from typing import List, Optional
import time

from vllm.sequence import (PromptLogprobs, SampleLogprobs, SequenceGroup,
                           SequenceStatus, RequestMetrics)
from vllm.lora.request import LoRARequest
...
@@ -60,6 +61,7 @@ class RequestOutput:
        prompt_logprobs: The log probabilities to return per prompt token.
        outputs: The output sequences of the request.
        finished: Whether the whole request is finished.
        metrics: Metrics associated with the request.
        lora_request: The LoRA request that was used to generate the output.
    """
...
@@ -71,6 +73,7 @@ class RequestOutput:
        prompt_logprobs: Optional[PromptLogprobs],
        outputs: List[CompletionOutput],
        finished: bool,
        metrics: Optional[RequestMetrics] = None,
        lora_request: Optional[LoRARequest] = None,
    ) -> None:
        self.request_id = request_id
...
@@ -79,6 +82,7 @@ class RequestOutput:
        self.prompt_logprobs = prompt_logprobs
        self.outputs = outputs
        self.finished = finished
        self.metrics = metrics
        self.lora_request = lora_request

    @classmethod
...
@@ -115,12 +119,15 @@ class RequestOutput:
        prompt_token_ids = seq_group.prompt_token_ids
        prompt_logprobs = seq_group.prompt_logprobs
        finished = seq_group.is_finished()
        finished_time = time.time() if finished else None
        seq_group.set_finished_time(finished_time)
        return cls(seq_group.request_id,
                   prompt,
                   prompt_token_ids,
                   prompt_logprobs,
                   outputs,
                   finished,
                   seq_group.metrics,
                   lora_request=seq_group.lora_request)

    def __repr__(self) -> str:
...
@@ -130,4 +137,5 @@ class RequestOutput:
                f"prompt_logprobs={self.prompt_logprobs}, "
                f"outputs={self.outputs}, "
                f"finished={self.finished}, "
                f"metrics={self.metrics}, "
                f"lora_request={self.lora_request})")
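With `metrics` threaded through the classmethod above, callers of the offline API can inspect per-request timings. A hedged usage sketch (the model name is only a placeholder, and `metrics` may be `None` on engines that do not populate it):

```python
# Hedged usage sketch: reading the new `metrics` field from the offline API.
# The model name is a placeholder; use any model you actually have available.
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")  # placeholder model
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(max_tokens=16, seed=42))

for out in outputs:
    m = out.metrics  # RequestMetrics attached by the engine (may be None)
    if m is not None:
        print(f"time_in_queue={m.time_in_queue}, "
              f"first_token_time={m.first_token_time}, "
              f"finished_time={m.finished_time}")
```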
vllm/sampling_params.py
View file @ 40542023
...
@@ -11,7 +11,8 @@ _SAMPLING_EPS = 1e-5
class SamplingType(IntEnum):
    GREEDY = 0
    RANDOM = 1
    RANDOM_SEED = 2
    BEAM = 3


LogitsProcessor = Callable[[List[int], torch.Tensor], torch.Tensor]
...
@@ -56,6 +57,7 @@ class SamplingParams:
        min_p: Float that represents the minimum probability for a token to be
            considered, relative to the probability of the most likely token.
            Must be in [0, 1]. Set to 0 to disable this.
        seed: Random seed to use for the generation.
        use_beam_search: Whether to use beam search instead of sampling.
        length_penalty: Float that penalizes sequences based on their length.
            Used in beam search.
...
@@ -101,6 +103,7 @@ class SamplingParams:
        top_p: float = 1.0,
        top_k: int = -1,
        min_p: float = 0.0,
        seed: Optional[int] = None,
        use_beam_search: bool = False,
        length_penalty: float = 1.0,
        early_stopping: Union[bool, str] = False,
...
@@ -124,6 +127,7 @@ class SamplingParams:
        self.top_p = top_p
        self.top_k = top_k
        self.min_p = min_p
        self.seed = seed
        self.use_beam_search = use_beam_search
        self.length_penalty = length_penalty
        self.early_stopping = early_stopping
...
@@ -229,6 +233,8 @@ class SamplingParams:
            return SamplingType.BEAM
        if self.temperature < _SAMPLING_EPS:
            return SamplingType.GREEDY
        if self.seed is not None:
            return SamplingType.RANDOM_SEED
        return SamplingType.RANDOM

    def __repr__(self) -> str:
...
@@ -242,6 +248,7 @@ class SamplingParams:
                f"top_p={self.top_p}, "
                f"top_k={self.top_k}, "
                f"min_p={self.min_p}, "
                f"seed={self.seed}, "
                f"use_beam_search={self.use_beam_search}, "
                f"length_penalty={self.length_penalty}, "
                f"early_stopping={self.early_stopping}, "
...
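The new `seed` field only changes which sampling path is taken when the request is neither greedy nor beam search. Assuming the hunk above belongs to the `sampling_type` property of `SamplingParams`, a small sketch of the resulting precedence:

```python
# Sketch of the precedence added above (greedy beats seed, seed beats plain
# random), assuming the property shown is SamplingParams.sampling_type.
from vllm import SamplingParams
from vllm.sampling_params import SamplingType

print(SamplingParams(temperature=0.0, seed=42).sampling_type)  # GREEDY
print(SamplingParams(temperature=0.8, seed=42).sampling_type)  # RANDOM_SEED
print(SamplingParams(temperature=0.8).sampling_type)           # RANDOM
```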
vllm/sequence.py
View file @ 40542023
"""Sequence and its related classes."""
import copy
import enum
from dataclasses import dataclass
from typing import Dict, List, Optional, Union

from vllm.block import LogicalTokenBlock
...
@@ -49,6 +50,25 @@ class SequenceStatus(enum.Enum):
        return finish_reason


@dataclass
class RequestMetrics:
    """Metrics associated with a request.

    Args:
        arrival_time: The time when the request arrived.
        first_scheduled_time: The time when the request was first scheduled.
        first_token_time: The time when the first token was generated.
        time_in_queue: The time the request spent in the queue.
        finished_time: The time when the request was finished.
    """
    arrival_time: float
    last_token_time: float
    first_scheduled_time: Optional[float]
    first_token_time: Optional[float]
    time_in_queue: Optional[float]
    finished_time: Optional[float] = None


class SequenceData:
    """Data associated with a sequence.
...
@@ -228,6 +248,14 @@ class Sequence:
                f"num_blocks={len(self.logical_token_blocks)})")


@dataclass
class SequenceGroupState:
    """Mutable state tied to a specific sequence group"""

    # torch.Generator used in seeded sampling
    generator: Optional = None


class SequenceGroup:
    """A group of sequences that are generated from the same prompt.
...
@@ -252,11 +280,15 @@ class SequenceGroup:
        self.request_id = request_id
        self.seqs_dict = {seq.seq_id: seq for seq in seqs}
        self.sampling_params = sampling_params
        self.metrics = RequestMetrics(arrival_time=arrival_time,
                                      last_token_time=arrival_time,
                                      first_scheduled_time=None,
                                      first_token_time=None,
                                      time_in_queue=None)
        self.lora_request = lora_request
        self.prefix: Optional[Prefix] = prefix
        self.prompt_logprobs: Optional[PromptLogprobs] = None
        self.state = SequenceGroupState()

    @property
    def prompt(self) -> str:
...
@@ -276,10 +308,25 @@ class SequenceGroup:
    def get_last_latency(self, now: float) -> float:
        """Gets last token latency for Request level timings."""
        latency = now - self.metrics.last_token_time
        self.metrics.last_token_time = now
        return latency

    def maybe_set_first_token_time(self, time: float) -> None:
        """Sets the first token time for Request level timings."""
        if self.metrics.first_token_time is None:
            self.metrics.first_token_time = time

    def maybe_set_first_scheduled_time(self, time: float) -> None:
        """Sets the first scheduled time and time in queue for Request
        level timings."""
        if self.metrics.first_scheduled_time is None:
            self.metrics.first_scheduled_time = time
            self.metrics.time_in_queue = time - self.metrics.arrival_time

    def set_finished_time(self, time: Optional[float]) -> None:
        """Sets the finished time for Request level timings."""
        self.metrics.finished_time = time

    def get_max_num_running_seqs(self) -> int:
        """The maximum number of sequences running in parallel in the remaining
        lifetime of the request."""
...
@@ -359,6 +406,7 @@ class SequenceGroupMetadata:
        sampling_params: The sampling parameters used to generate the outputs.
        block_tables: The block tables. (Seq id -> list of physical block
            numbers)
        state: Internal state tied to this sequence group.
        lora_request: LoRA request.
        prefix: The prefix of the prompt of the sequence group.
    """
...
@@ -372,6 +420,7 @@ class SequenceGroupMetadata:
        block_tables: Dict[int, List[int]],
        lora_request: Optional[LoRARequest] = None,
        prefix: Optional[Prefix] = None,
        state: Optional[SequenceGroupState] = None,
    ) -> None:
        self.request_id = request_id
        self.is_prompt = is_prompt
...
@@ -380,6 +429,7 @@ class SequenceGroupMetadata:
        self.block_tables = block_tables
        self.lora_request = lora_request
        self.prefix = prefix
        self.state = SequenceGroupState() if state is None else state

    @property
    def lora_int_id(self) -> int:
...
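Putting the new pieces of `SequenceGroup` together: the scheduler and engine stamp `RequestMetrics` at three points, namely first schedule (which also fixes `time_in_queue`), first generated token, and completion. A minimal sketch with made-up timestamps that mirrors the helper methods above:

```python
# Sketch of the request-timing lifecycle with made-up timestamps, using the
# RequestMetrics dataclass introduced above.
from vllm.sequence import RequestMetrics

t_arrival = 100.0
metrics = RequestMetrics(arrival_time=t_arrival,
                         last_token_time=t_arrival,
                         first_scheduled_time=None,
                         first_token_time=None,
                         time_in_queue=None)

# maybe_set_first_scheduled_time(): stamped once by the scheduler
if metrics.first_scheduled_time is None:
    metrics.first_scheduled_time = 100.05
    metrics.time_in_queue = metrics.first_scheduled_time - metrics.arrival_time

# maybe_set_first_token_time(): stamped when the first token is generated
if metrics.first_token_time is None:
    metrics.first_token_time = 100.30

# set_finished_time(): stamped on completion (see RequestOutput above)
metrics.finished_time = 101.20

print(metrics)  # time_in_queue ~= 0.05 (floating-point noise aside)
```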
vllm/transformers_utils/config.py
View file @ 40542023
...
@@ -16,10 +16,14 @@ _CONFIG_REGISTRY = {
def get_config(model: str,
               trust_remote_code: bool,
               revision: Optional[str] = None,
               code_revision: Optional[str] = None) -> PretrainedConfig:
    try:
        config = AutoConfig.from_pretrained(
            model,
            trust_remote_code=trust_remote_code,
            revision=revision,
            code_revision=code_revision)
    except ValueError as e:
        if (not trust_remote_code and
                "requires you to execute the configuration file" in str(e)):
...
@@ -33,5 +37,7 @@ def get_config(model: str,
            raise e
    if config.model_type in _CONFIG_REGISTRY:
        config_class = _CONFIG_REGISTRY[config.model_type]
        config = config_class.from_pretrained(model,
                                              revision=revision,
                                              code_revision=code_revision)
    return config
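The extra `code_revision` argument lets the modeling code fetched via `trust_remote_code` be pinned independently of the weight/config `revision`. A hedged call sketch (the repository id and revisions are placeholders):

```python
# Hedged sketch of the new `code_revision` argument; the repo id and
# revisions below are placeholders.
from vllm.transformers_utils.config import get_config

config = get_config(
    model="some-org/some-remote-code-model",  # placeholder repository id
    trust_remote_code=True,
    revision="main",        # revision of the weights/config (placeholder)
    code_revision="main",   # revision of the remote modeling code (placeholder)
)
print(type(config))
```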