norm/vllm · Commits · Commit d6fa1be3 (unverified)

[Quality] Add code formatter and linter (#326)

Authored Jul 03, 2023 by Zhuohan Li; committed via GitHub on Jul 03, 2023
Parent: 0ffded81
47 changed files in this commit; 20 changed files are shown below, with 569 additions and 374 deletions (+569 −374).
20 changed files:

  vllm/engine/llm_engine.py                   +18  −15
  vllm/engine/ray_utils.py                     +7   −6
  vllm/entrypoints/api_server.py               +8   −9
  vllm/entrypoints/llm.py                      +1   −2
  vllm/entrypoints/openai/api_server.py       +67  −46
  vllm/entrypoints/openai/protocol.py          +4   −2
  vllm/logger.py                               +3   −3
  vllm/model_executor/__init__.py              +0   −1
  vllm/model_executor/input_metadata.py       +13   −2
  vllm/model_executor/layers/activation.py     +9  −10
  vllm/model_executor/layers/attention.py     +81  −32
  vllm/model_executor/layers/sampler.py       +45  −37
  vllm/model_executor/model_loader.py          +6   −7
  vllm/model_executor/models/__init__.py       +0   −2
  vllm/model_executor/models/gpt2.py          +53  −33
  vllm/model_executor/models/gpt_bigcode.py   +66  −38
  vllm/model_executor/models/gpt_neox.py      +52  −33
  vllm/model_executor/models/llama.py         +57  −39
  vllm/model_executor/models/opt.py           +69  −46
  vllm/model_executor/weight_utils.py         +10  −11
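Most of the churn in the hunks below is mechanical: lines re-wrapped to an 80-column style, single quotes normalized to double quotes, and inline pylint control comments added. The formatter and linter configuration files themselves are among the commit's files not shown on this page, so the following is only an illustrative sketch of the dominant re-wrapping pattern, assuming a yapf-like formatter with an 80-column limit (the exact tool settings are an assumption, not taken from this page):

    # Illustration only: the kind of re-wrapping that accounts for most of the
    # churn below, assuming a yapf-like formatter with an 80-column limit.

    def some_function(a, b, c, keyword_argument=None):
        return (a, b, c, keyword_argument)

    # Before: arguments packed onto shared continuation lines.
    before = some_function("first", "second",
                           "third", keyword_argument="value")

    # After: one argument per line, aligned with the opening parenthesis.
    after = some_function("first",
                          "second",
                          "third",
                          keyword_argument="value")

    assert before == after  # purely a formatting change; behavior is identical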
vllm/engine/llm_engine.py

@@ -67,8 +67,7 @@ class LLMEngine:
             f"download_dir={model_config.download_dir!r}, "
             f"use_np_weights={model_config.use_np_weights}, "
             f"tensor_parallel_size={parallel_config.tensor_parallel_size}, "
-            f"seed={model_config.seed})"
-        )
+            f"seed={model_config.seed})")
         # TODO(woosuk): Print more configs in debug mode.
         self.model_config = model_config

@@ -78,8 +77,8 @@ class LLMEngine:
         self.log_stats = log_stats
         self._verify_args()

-        self.tokenizer = get_tokenizer(model_config.tokenizer,
-                                       model_config.tokenizer_mode)
+        self.tokenizer = get_tokenizer(
+            model_config.tokenizer, tokenizer_mode=model_config.tokenizer_mode)
         self.seq_counter = Counter()

         # Create the parallel GPU workers.

@@ -129,8 +128,8 @@ class LLMEngine:
         num_gpu_blocks = min(b[0] for b in num_blocks)
         num_cpu_blocks = min(b[1] for b in num_blocks)
         # FIXME(woosuk): Change to debug log.
-        logger.info(f'# GPU blocks: {num_gpu_blocks}, '
-                    f'# CPU blocks: {num_cpu_blocks}')
+        logger.info(f"# GPU blocks: {num_gpu_blocks}, "
+                    f"# CPU blocks: {num_cpu_blocks}")
         if num_gpu_blocks <= 0:
             raise ValueError("No available memory for the cache blocks. "

@@ -152,7 +151,9 @@ class LLMEngine:
         # Initialize the cluster.
         distributed_init_method, devices = initialize_cluster(parallel_config)
         # Create the LLM engine.
-        engine = cls(*engine_configs, distributed_init_method, devices,
+        engine = cls(*engine_configs,
+                     distributed_init_method,
+                     devices,
                      log_stats=not engine_args.disable_log_stats)
         return engine

@@ -226,8 +227,10 @@ class LLMEngine:
         and updates the scheduler with the model outputs. Finally, it decodes
         the sequences and returns the newly generated results.
         """
-        seq_group_metadata_list, scheduler_outputs, ignored_seq_groups = self.scheduler.schedule()
-        if (not seq_group_metadata_list) and scheduler_outputs.is_empty() and (not ignored_seq_groups):
+        (seq_group_metadata_list, scheduler_outputs,
+         ignored_seq_groups) = self.scheduler.schedule()
+        if ((not seq_group_metadata_list) and scheduler_outputs.is_empty()
+                and (not ignored_seq_groups)):
             # Nothing to do.
             return []

@@ -281,8 +284,8 @@ class LLMEngine:
                     # Truncate the output text so that the stop string is
                     # not included in the output.
                     seq.output_text = seq.output_text[:-len(stop_str)]
-                    self.scheduler.free_seq(seq,
-                                            SequenceStatus.FINISHED_STOPPED)
+                    self.scheduler.free_seq(
+                        seq, SequenceStatus.FINISHED_STOPPED)
                     stopped = True
                     break
             if stopped:

@@ -290,7 +293,7 @@ class LLMEngine:
             # Check if the sequence has reached max_seq_len.
             if (seq.get_len() >=
                     self.scheduler.scheduler_config.max_seq_len):
-                self.scheduler.free_seq(seq,
-                                        SequenceStatus.FINISHED_LENGTH_CAPPED)
+                self.scheduler.free_seq(
+                    seq, SequenceStatus.FINISHED_LENGTH_CAPPED)
                 continue

@@ -302,15 +305,15 @@ class LLMEngine:
             # Check if the sequence has generated the EOS token.
             if not sampling_params.ignore_eos:
                 if seq.get_last_token_id() == self.tokenizer.eos_token_id:
-                    self.scheduler.free_seq(seq,
-                                            SequenceStatus.FINISHED_STOPPED)
+                    self.scheduler.free_seq(
+                        seq, SequenceStatus.FINISHED_STOPPED)
                     continue

     def _run_workers(
         self,
         method: str,
-        get_all_outputs: bool = False,
         *args,
+        get_all_outputs: bool = False,
         **kwargs,
     ) -> Any:
         """Runs the given method on all workers."""
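One change in this file is more than cosmetic: in `_run_workers`, `get_all_outputs` now comes after `*args`, which turns it into a keyword-only parameter. A minimal, self-contained sketch (not vLLM code) of why that ordering matters:

    # Minimal sketch of the _run_workers parameter reordering: placing a
    # defaulted flag after *args makes it keyword-only, so an extra positional
    # argument can no longer be silently absorbed by the flag.

    def run_old(method, get_all_outputs=False, *args):
        return method, get_all_outputs, args

    def run_new(method, *args, get_all_outputs=False):
        return method, get_all_outputs, args

    # Old ordering: the first extra positional argument lands in the flag.
    assert run_old("init", "worker-arg") == ("init", "worker-arg", ())
    # New ordering: positional extras stay in *args; the flag must be named.
    assert run_new("init", "worker-arg") == ("init", False, ("worker-arg",))
    assert run_new("init", "worker-arg", get_all_outputs=True) == \
        ("init", True, ("worker-arg",))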
vllm/engine/ray_utils.py

@@ -8,7 +8,8 @@ except ImportError:
 from vllm.config import ParallelConfig

-DeviceID = Tuple[int, Optional[str], int]  # rank, node resource (node IP), device id
+# rank, node resource (node IP), device id
+DeviceID = Tuple[int, Optional[str], int]


 def initialize_cluster(

@@ -53,15 +54,15 @@ def initialize_cluster(
     valid_node_resources = []
     num_devices_per_node = None
     for node in ray.nodes():
-        if (not node['Alive']) or node['Resources']['GPU'] <= 0:
+        if (not node["Alive"]) or node["Resources"]["GPU"] <= 0:
             continue
         if num_devices_per_node is None:
-            num_devices_per_node = node['Resources']['GPU']
+            num_devices_per_node = node["Resources"]["GPU"]
         else:
-            assert num_devices_per_node == node['Resources']['GPU'], (
+            assert num_devices_per_node == node["Resources"]["GPU"], (
                 "The number of GPUs per node is not uniform.")
-        for key in node['Resources']:
-            if key.startswith('node:'):
+        for key in node["Resources"]:
+            if key.startswith("node:"):
                 valid_node_resources.append(key)

     # Verify the parallel config.
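The comment that moved onto its own line documents the `DeviceID` layout: a `(rank, node resource, device id)` triple, where the node resource is a Ray resource name such as the `node:<IP>` keys the loop above filters on. A tiny illustration with made-up values (the concrete addresses and ranks are assumptions, not from the diff):

    # Illustration of the DeviceID alias documented above; values are made up.
    from typing import Optional, Tuple

    DeviceID = Tuple[int, Optional[str], int]  # rank, node resource (node IP), device id

    first_gpu_on_head_node: DeviceID = (0, "node:10.0.0.1", 0)
    unbound_device: DeviceID = (1, None, 1)  # Optional[str] allows a missing node resource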
vllm/entrypoints/api_server.py

@@ -11,8 +11,8 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.sampling_params import SamplingParams
 from vllm.utils import random_uuid

 TIMEOUT_KEEP_ALIVE = 5  # seconds.
-TIMEOUT_TO_PREVENT_DEADLOCK = 1  # seconds
+TIMEOUT_TO_PREVENT_DEADLOCK = 1  # seconds.
 app = FastAPI()

@@ -37,8 +37,7 @@ async def generate(request: Request) -> Response:
         async for request_output in results_generator:
             prompt = request_output.prompt
-            text_outputs = [
-                prompt + output.text
-                for output in request_output.outputs
-            ]
+            text_outputs = [
+                prompt + output.text for output in request_output.outputs
+            ]
             ret = {"text": text_outputs}
             yield (json.dumps(ret) + "\0").encode("utf-8")

@@ -63,10 +62,7 @@ async def generate(request: Request) -> Response:
     assert final_output is not None
     prompt = final_output.prompt
-    text_outputs = [
-        prompt + output.text
-        for output in final_output.outputs
-    ]
+    text_outputs = [prompt + output.text for output in final_output.outputs]
     ret = {"text": text_outputs}
     return Response(content=json.dumps(ret))

@@ -81,5 +77,8 @@ if __name__ == "__main__":
     engine_args = AsyncEngineArgs.from_cli_args(args)
     engine = AsyncLLMEngine.from_engine_args(engine_args)

-    uvicorn.run(app, host=args.host, port=args.port, log_level="debug",
-                timeout_keep_alive=TIMEOUT_KEEP_ALIVE)
+    uvicorn.run(app,
+                host=args.host,
+                port=args.port,
+                log_level="debug",
+                timeout_keep_alive=TIMEOUT_KEEP_ALIVE)
vllm/entrypoints/llm.py

@@ -63,8 +63,7 @@ class LLM:
         self.request_counter = Counter()

     def get_tokenizer(
-        self,
-    ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
+            self) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
         return self.llm_engine.tokenizer

     def set_tokenizer(
vllm/entrypoints/openai/api_server.py

-# Adapted from https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/serve/openai_api_server.py
+# Adapted from
+# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/serve/openai_api_server.py

 import argparse
 from http import HTTPStatus

@@ -29,7 +30,7 @@ from vllm.sampling_params import SamplingParams
 from vllm.transformers_utils.tokenizer import get_tokenizer
 from vllm.utils import random_uuid

 TIMEOUT_KEEP_ALIVE = 5  # seconds

 logger = init_logger(__name__)
 served_model = None

@@ -38,14 +39,13 @@ app = fastapi.FastAPI()
 def create_error_response(status_code: HTTPStatus,
                           message: str) -> JSONResponse:
-    return JSONResponse(
-        ErrorResponse(message=message,
-                      type="invalid_request_error").dict(),
-        status_code=status_code.value
-    )
+    return JSONResponse(ErrorResponse(message=message,
+                                      type="invalid_request_error").dict(),
+                        status_code=status_code.value)


 @app.exception_handler(RequestValidationError)
 async def validation_exception_handler(request, exc):
+    # pylint: disable=unused-argument
     return create_error_response(HTTPStatus.BAD_REQUEST, str(exc))

@@ -126,8 +126,11 @@ async def check_length(request, prompt, engine):
 @app.get("/v1/models")
 async def show_available_models():
     """Show available models. Right now we only have one model."""
-    model_cards = [ModelCard(id=served_model, root=served_model,
-                             permission=[ModelPermission()])]
+    model_cards = [
+        ModelCard(id=served_model,
+                  root=served_model,
+                  permission=[ModelPermission()])
+    ]
     return ModelList(data=model_cards)

@@ -144,12 +147,14 @@ def create_logprobs(token_ids: List[int],
         if len(logprobs.text_offset) == 0:
             logprobs.text_offset.append(initial_text_offset)
         else:
             logprobs.text_offset.append(logprobs.text_offset[-1] +
                                         last_token_len)
         last_token_len = len(token)
-        logprobs.top_logprobs.append(
-            {tokenizer.convert_ids_to_tokens(i): p
-             for i, p in id_logprob.items()})
+        logprobs.top_logprobs.append({
+            tokenizer.convert_ids_to_tokens(i): p
+            for i, p in id_logprob.items()
+        })
     return logprobs

@@ -348,7 +353,7 @@ async def create_completion(raw_request: Request):
     if request.suffix is not None:
         # The language models we currently support do not support suffix.
         return create_error_response(HTTPStatus.BAD_REQUEST,
                                      "suffix is not currently supported")

     if request.logit_bias is not None:
         # TODO: support logit_bias in vLLM engine.

@@ -387,22 +392,23 @@ async def create_completion(raw_request: Request):
     except ValueError as e:
         return create_error_response(HTTPStatus.BAD_REQUEST, str(e))

     result_generator = engine.generate(prompt, sampling_params,
                                        request_id)

     # Similar to the OpenAI API, when n != best_of, we do not stream the
     # results. In addition, we do not stream the results when use beam search.
-    stream = (request.stream and
-              (request.best_of is None or request.n == request.best_of) and
-              not request.use_beam_search)
+    stream = (request.stream
+              and (request.best_of is None or request.n == request.best_of)
+              and not request.use_beam_search)

     async def abort_request() -> None:
         await engine.abort(request_id)

-    def create_stream_response_json(index: int,
-                                    text: str,
-                                    logprobs: Optional[LogProbs] = None,
-                                    finish_reason: Optional[str] = None) -> str:
+    def create_stream_response_json(
+        index: int,
+        text: str,
+        logprobs: Optional[LogProbs] = None,
+        finish_reason: Optional[str] = None,
+    ) -> str:
         choice_data = CompletionResponseStreamChoice(
             index=index,
             text=text,

@@ -443,7 +449,8 @@ async def create_completion(raw_request: Request):
                 )
                 yield f"data: {response_json}\n\n"
             if output.finish_reason is not None:
-                logprobs = LogProbs() if request.logprobs is not None else None
+                logprobs = (LogProbs()
+                            if request.logprobs is not None else None)
                 response_json = create_stream_response_json(
                     index=i,
                     text="",

@@ -487,8 +494,8 @@ async def create_completion(raw_request: Request):
         choices.append(choice_data)

     num_prompt_tokens = len(final_res.prompt_token_ids)
-    num_generated_tokens = sum(len(output.token_ids)
-                               for output in final_res.outputs)
+    num_generated_tokens = sum(
+        len(output.token_ids) for output in final_res.outputs)
     usage = UsageInfo(
         prompt_tokens=num_prompt_tokens,
         completion_tokens=num_generated_tokens,

@@ -506,9 +513,11 @@ async def create_completion(raw_request: Request):
     # When user requests streaming but we don't stream, we still need to
     # return a streaming response with a single event.
     response_json = response.json(ensure_ascii=False)

     async def fake_stream_generator() -> AsyncGenerator[str, None]:
         yield f"data: {response_json}\n\n"
         yield "data: [DONE]\n\n"

     return StreamingResponse(fake_stream_generator(),
                              media_type="text/event-stream")

@@ -517,26 +526,34 @@ async def create_completion(raw_request: Request):
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(
         description="vLLM OpenAI-Compatible RESTful API server."
     )
     parser.add_argument("--host",
                         type=str, default="localhost", help="host name")
     parser.add_argument("--port", type=int, default=8000, help="port number")
-    parser.add_argument("--allow-credentials", action="store_true", help="allow credentials")
-    parser.add_argument("--allowed-origins", type=json.loads, default=["*"], help="allowed origins")
-    parser.add_argument("--allowed-methods", type=json.loads, default=["*"], help="allowed methods")
-    parser.add_argument("--allowed-headers", type=json.loads, default=["*"], help="allowed headers")
-    parser.add_argument("--served-model-name", type=str, default=None,
-                        help="The model name used in the API. If not specified, "
-                             "the model name will be the same as the "
-                             "huggingface name.")
+    parser.add_argument("--allow-credentials",
+                        action="store_true",
+                        help="allow credentials")
+    parser.add_argument("--allowed-origins",
+                        type=json.loads,
+                        default=["*"],
+                        help="allowed origins")
+    parser.add_argument("--allowed-methods",
+                        type=json.loads,
+                        default=["*"],
+                        help="allowed methods")
+    parser.add_argument("--allowed-headers",
+                        type=json.loads,
+                        default=["*"],
+                        help="allowed headers")
+    parser.add_argument("--served-model-name",
+                        type=str,
+                        default=None,
+                        help="The model name used in the API. If not specified, "
+                        "the model name will be the same as the "
+                        "huggingface name.")
     parser = AsyncEngineArgs.add_cli_args(parser)
     args = parser.parse_args()

@@ -556,7 +573,11 @@ if __name__ == "__main__":
     engine = AsyncLLMEngine.from_engine_args(engine_args)

     # A separate tokenizer to map token IDs to strings.
-    tokenizer = get_tokenizer(engine_args.tokenizer, engine_args.tokenizer_mode)
+    tokenizer = get_tokenizer(engine_args.tokenizer,
+                              tokenizer_mode=engine_args.tokenizer_mode)

-    uvicorn.run(app, host=args.host, port=args.port, log_level="info",
-                timeout_keep_alive=TIMEOUT_KEEP_ALIVE)
+    uvicorn.run(app,
+                host=args.host,
+                port=args.port,
+                log_level="info",
+                timeout_keep_alive=TIMEOUT_KEEP_ALIVE)
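The `# pylint: disable=unused-argument` added in `validation_exception_handler` is a targeted suppression: the framework dictates the handler's signature, so the `request` parameter must be accepted even though the body only uses `exc`. A generic, self-contained sketch of the same pattern (not the vLLM handler itself):

    # Generic sketch of the pattern behind the added pylint directive: the
    # callback signature is fixed by the caller, so the unused parameter is
    # expected and the warning is silenced locally rather than globally.

    def validation_error_to_message(request, exc):
        # pylint: disable=unused-argument
        # `request` is part of the required handler signature; only `exc` is used.
        return f"Invalid request: {exc}"

    print(validation_error_to_message(object(), ValueError("missing field 'prompt'")))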
vllm/entrypoints/openai/protocol.py

-# Adapted from https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
+# Adapted from
+# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
 import time
 from typing import Dict, List, Literal, Optional, Union

@@ -98,7 +99,8 @@ class LogProbs(BaseModel):
     text_offset: List[int] = Field(default_factory=list)
     token_logprobs: List[Optional[float]] = Field(default_factory=list)
     tokens: List[str] = Field(default_factory=list)
-    top_logprobs: List[Optional[Dict[str, float]]] = Field(default_factory=list)
+    top_logprobs: List[Optional[Dict[str,
+                                     float]]] = Field(default_factory=list)


 class CompletionResponseChoice(BaseModel):
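The `LogProbs` fields above all use `Field(default_factory=list)` so every instance gets its own fresh lists rather than sharing one mutable default. A small stand-in model (not the vLLM class) showing why that matters:

    # Stand-in pydantic model illustrating default_factory=list, the pattern
    # used by LogProbs above: mutable defaults must be created per instance.
    from typing import List
    from pydantic import BaseModel, Field

    class TokenScores(BaseModel):
        tokens: List[str] = Field(default_factory=list)
        logprobs: List[float] = Field(default_factory=list)

    a = TokenScores()
    b = TokenScores()
    a.tokens.append("hello")
    assert b.tokens == []  # each instance owns its own list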
vllm/logger.py

-# Adapted from https://github.com/skypilot-org/skypilot/blob/86dc0f6283a335e4aa37b3c10716f90999f48ab6/sky/sky_logging.py
+# Adapted from
+# https://github.com/skypilot-org/skypilot/blob/86dc0f6283a335e4aa37b3c10716f90999f48ab6/sky/sky_logging.py
+"""Logging configuration for vLLM."""
 import logging
 import sys

 _FORMAT = "%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s"
 _DATE_FORMAT = "%m-%d %H:%M:%S"
vllm/model_executor/__init__.py

@@ -2,7 +2,6 @@ from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.utils import set_random_seed

 __all__ = [
     "InputMetadata",
     "get_model",
vllm/model_executor/input_metadata.py

@@ -8,11 +8,22 @@ from vllm.sequence import SequenceData
 class InputMetadata:
+    """Metadata for input sequences. Used for PagedAttention.
+
+    Args:
+        seq_groups: List of (seq_ids, sampling_params).
+        seq_data: Seq_id -> SequenceData.
+        prompt_lens: Lengths of prompts.
+        slot_mapping: The address to write the new KV to of each token.
+        context_lens: the length of attention context for each generation token.
+        max_context_len: The maximum context length.
+        block_tables: The block tables. (Seq id -> list of physical block)
+    """

     def __init__(
         self,
-        seq_groups: List[Tuple[List[int], SamplingParams]],  # List of (seq_ids, sampling_params).
-        seq_data: Dict[int, SequenceData],  # Seq_id -> SequenceData.
+        seq_groups: List[Tuple[List[int], SamplingParams]],
+        seq_data: Dict[int, SequenceData],
         prompt_lens: List[int],
         slot_mapping: torch.Tensor,
         context_lens: torch.Tensor,
vllm/model_executor/layers/activation.py

@@ -6,9 +6,10 @@ from vllm import activation_ops
 _ACTIVATION_REGISTRY = {
     "gelu": nn.GELU(),
-    "gelu_new": nn.GELU(approximate="tanh"),
-    # NOTE: This may introduce small rounding errors.
-    "gelu_fast": nn.GELU(approximate="tanh"),
-    # NOTE: This may introduce small rounding errors.
-    "gelu_pytorch_tanh": nn.GELU(approximate="tanh"),
+    # NOTE: The following GELU functions may introduce small rounding errors.
+    "gelu_fast": nn.GELU(approximate="tanh"),
+    "gelu_new": nn.GELU(approximate="tanh"),
+    # NOTE: This may introduce small rounding errors.
+    "gelu_pytorch_tanh": nn.GELU(approximate="tanh"),
     "relu": nn.ReLU(),
 }

@@ -25,15 +26,13 @@ class SiluAndMul(nn.Module):
     """An activation function for SwiGLU.

     The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[1] // 2.
+
+    Shapes:
+        x: (num_tokens, 2 * d)
+        return: (num_tokens, d)
     """

-    def __init__(self):
-        super().__init__()
-
-    def forward(
-        self,
-        x: torch.Tensor,  # (num_tokens, 2 * d)
-    ) -> torch.Tensor:  # (num_tokens, d)
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         num_tokens = x.shape[0]
         d = x.shape[1] // 2
         out = torch.empty(num_tokens, d, dtype=x.dtype, device=x.device)
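The new docstring states both the shapes and the formula `x -> silu(x[:d]) * x[d:]`. For reference, an unfused PyTorch sketch of what the fused `activation_ops.silu_and_mul` kernel computes; this is an illustration for checking shapes, not the CUDA kernel vLLM actually calls:

    # Unfused reference for the SwiGLU activation described in the SiluAndMul
    # docstring: silu on the first half of the last dimension, gated by the
    # second half.
    import torch
    import torch.nn.functional as F

    def silu_and_mul_reference(x: torch.Tensor) -> torch.Tensor:
        # x: (num_tokens, 2 * d) -> (num_tokens, d)
        d = x.shape[1] // 2
        return F.silu(x[:, :d]) * x[:, d:]

    x = torch.randn(4, 16)          # num_tokens=4, 2*d=16
    out = silu_and_mul_reference(x)
    assert out.shape == (4, 8)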
vllm/model_executor/layers/attention.py

@@ -14,6 +14,7 @@ _SUPPORTED_HEAD_SIZES = [64, 80, 96, 128]
 class PagedAttention(nn.Module):
+    # pylint: disable=line-too-long
     """GPT-style multi-head PagedAttention.

     This class takes flattened 1D query, key, and value tensors as input. The

@@ -54,12 +55,20 @@ class PagedAttention(nn.Module):
     def multi_query_kv_attention(
         self,
-        output: torch.Tensor,   # [num_prompt_tokens, num_heads, head_size]
-        query: torch.Tensor,    # [num_prompt_tokens, num_heads, head_size]
-        key: torch.Tensor,      # [num_prompt_tokens, num_heads, head_size]
-        value: torch.Tensor,    # [num_prompt_tokens, num_heads, head_size]
+        output: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
         attn_bias: xops.AttentionBias,
     ) -> torch.Tensor:
+        """Normal attention for the prompt tokens.
+
+        Args:
+            output: shape = [num_prompt_tokens, num_heads, head_size]
+            query: shape = [num_prompt_tokens, num_heads, head_size]
+            key: shape = [num_prompt_tokens, num_heads, head_size]
+            value: shape = [num_prompt_tokens, num_heads, head_size]
+        """
         # TODO(woosuk): The unsqueeze op may incur some CPU overhead. Optimize.
         out = xops.memory_efficient_attention_forward(
             query.unsqueeze(0),

@@ -76,12 +85,22 @@ class PagedAttention(nn.Module):
     def single_query_cached_kv_attention(
         self,
-        output: torch.Tensor,       # [num_generation_tokens, num_heads, head_size]
-        query: torch.Tensor,        # [num_generation_tokens, num_heads, head_size]
-        key_cache: torch.Tensor,    # [num_blocks, num_heads, head_size/x, block_size, x]
-        value_cache: torch.Tensor,  # [num_blocks, num_heads, head_size, block_size]
+        output: torch.Tensor,
+        query: torch.Tensor,
+        key_cache: torch.Tensor,
+        value_cache: torch.Tensor,
         input_metadata: InputMetadata,
     ) -> None:
+        """PagedAttention for the generation tokens.
+
+        Args:
+            output: shape = [num_generation_tokens, num_heads, head_size]
+            query: shape = [num_generation_tokens, num_heads, head_size]
+            key_cache: shape = [num_blocks, num_heads, head_size/x,
+                block_size, x]
+            value_cache: shape = [num_blocks, num_heads, head_size, block_size]
+            input_metadata: metadata for paged attention.
+        """
         block_size = value_cache.shape[3]
         attention_ops.single_query_cached_kv_attention(
             output,

@@ -97,16 +116,32 @@ class PagedAttention(nn.Module):
     def forward(
         self,
-        query: torch.Tensor,                    # [num_tokens, num_heads * head_size]
-        key: torch.Tensor,                      # [num_tokens, num_heads * head_size]
-        value: torch.Tensor,                    # [num_tokens, num_heads * head_size]
-        key_cache: Optional[torch.Tensor],      # [num_blocks, num_heads, head_size/x, block_size, x]
-        value_cache: Optional[torch.Tensor],    # [num_blocks, num_heads, head_size, block_size]
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        key_cache: Optional[torch.Tensor],
+        value_cache: Optional[torch.Tensor],
         input_metadata: InputMetadata,
         cache_event: Optional[torch.cuda.Event],
-    ) -> torch.Tensor:                          # [num_tokens, num_heads * head_size]
-        # NOTE: The query, key, and value tensors must be sliced from a qkv
-        # tensor of shape [num_tokens, 3 * num_heads * head_size].
+    ) -> torch.Tensor:
+        """PagedAttention forward pass.
+
+        NOTE: The query, key, and value tensors must be sliced from a qkv
+        tensor of shape [num_tokens, 3 * num_heads * head_size].
+
+        Args:
+            query: shape = [num_tokens, num_heads * head_size]
+            key: shape = [num_tokens, num_heads * head_size]
+            value: shape = [num_tokens, num_heads * head_size]
+            key_cache: shape = [num_blocks, num_heads, head_size/x,
+                block_size, x]
+            value_cache: shape = [num_blocks, num_heads, head_size, block_size]
+            input_metadata: metadata for paged attention.
+            cache_event: event to wait for the cache operations to finish.
+
+        Returns:
+            shape = [num_tokens, num_heads * head_size]
+        """
         # Reshape the query, key, and value tensors.
         query = query.view(-1, self.num_heads, self.head_size)

@@ -136,7 +171,7 @@ class PagedAttention(nn.Module):
         # and value vectors will not be cached.
         num_valid_tokens = input_metadata.num_valid_tokens
         if (num_valid_tokens > 0 and key_cache is not None
                 and value_cache is not None):
             # The stride is 3 because the key and value are sliced from qkv.
             cache_ops.reshape_and_cache(
                 key[:num_valid_tokens],

@@ -149,15 +184,12 @@ class PagedAttention(nn.Module):
         if input_metadata.num_generation_tokens > 0:
             assert key_cache is not None and value_cache is not None, (
                 "key_cache and value_cache must be provided when "
-                "generating tokens."
-            )
+                "generating tokens.")
             # Compute the attention op for generation tokens.
             self.single_query_cached_kv_attention(
                 output[num_prompt_tokens:num_valid_tokens],
-                query[num_prompt_tokens:num_valid_tokens],
-                key_cache,
-                value_cache,
-                input_metadata)
+                query[num_prompt_tokens:num_valid_tokens], key_cache,
+                value_cache, input_metadata)

         # Reshape the output tensor.
         # NOTE(woosuk): The output tensor may include paddings.

@@ -179,9 +211,9 @@ class PagedAttentionWithRoPE(PagedAttention):
         super().__init__(num_heads, head_size, scale)

         # Create the cos and sin cache.
         inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2) / rotary_dim))
         t = torch.arange(max_position).float()
-        freqs = torch.einsum('i,j -> ij', t, inv_freq.float())
+        freqs = torch.einsum("i,j -> ij", t, inv_freq.float())
         cos = freqs.cos()
         sin = freqs.sin()
         cache = torch.cat((cos, sin), dim=-1)

@@ -195,15 +227,32 @@ class PagedAttentionWithRoPE(PagedAttention):
     def forward(
         self,
-        positions: torch.Tensor,    # [num_tokens]
-        query: torch.Tensor,        # [num_tokens, num_heads * head_size]
-        key: torch.Tensor,          # [num_tokens, num_heads * head_size]
-        value: torch.Tensor,        # [num_tokens, num_heads * head_size]
-        key_cache: torch.Tensor,    # [num_blocks, num_heads, head_size/x, block_size, x]
-        value_cache: torch.Tensor,  # [num_blocks, num_heads, head_size, block_size]
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        key_cache: torch.Tensor,
+        value_cache: torch.Tensor,
         input_metadata: InputMetadata,
         cache_event: Optional[torch.cuda.Event],
-    ) -> torch.Tensor:              # [num_tokens, num_heads * head_size]
+    ) -> torch.Tensor:
+        """ PagedAttention forward pass with rotary embedding.
+
+        Args:
+            positions: shape = [num_tokens]
+            query: shape = [num_tokens, num_heads * head_size]
+            key: shape = [num_tokens, num_heads * head_size]
+            value: shape = [num_tokens, num_heads * head_size]
+            key_cache: shape = [num_blocks, num_heads, head_size/x,
+                block_size, x]
+            value_cache: shape = [num_blocks, num_heads, head_size, block_size]
+            input_metadata: metadata for paged attention.
+            cache_event: event to wait for the cache operations to finish.
+
+        Returns:
+            shape = [num_tokens, num_heads * head_size]
+        """
         # Apply rotary embedding to the query and key before passing them
         # to the attention op.
         pos_encoding_ops.rotary_embedding_neox(
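The new docstrings spell out the paged KV cache layouts: `key_cache` is `[num_blocks, num_heads, head_size/x, block_size, x]` and `value_cache` is `[num_blocks, num_heads, head_size, block_size]`. A shape-only sketch of allocating caches with those layouts; the packing factor `x = 8` and all concrete sizes below are illustrative assumptions, not values taken from this commit:

    # Shape-only sketch of the KV cache layout documented in the docstrings
    # above. The sizes and the choice x = 8 are assumptions for illustration.
    import torch

    num_blocks, num_heads, head_size, block_size, x = 128, 12, 64, 16, 8

    key_cache = torch.empty(num_blocks, num_heads, head_size // x, block_size, x)
    value_cache = torch.empty(num_blocks, num_heads, head_size, block_size)

    assert key_cache.shape == (128, 12, 8, 16, 8)
    assert value_cache.shape == (128, 12, 64, 16)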
vllm/model_executor/layers/sampler.py

@@ -13,6 +13,7 @@ from vllm.sequence import SequenceOutputs
 _SAMPLING_EPS = 1e-5


 class Sampler(nn.Module):
     """Samples the next tokens from the model's outputs.

@@ -50,19 +51,20 @@ class Sampler(nn.Module):
         # Apply presence and frequency penalties.
         output_tokens = _get_output_tokens(input_metadata)
         assert len(output_tokens) == logits.shape[0]
         presence_penalties, frequency_penalties = _get_penalties(input_metadata)
         assert len(presence_penalties) == logits.shape[0]
         assert len(frequency_penalties) == logits.shape[0]
         logits = _apply_penalties(logits, output_tokens, presence_penalties,
                                   frequency_penalties, self.vocab_size)

         # Apply temperature scaling.
         temperatures = _get_temperatures(input_metadata)
         assert len(temperatures) == logits.shape[0]
         if any(t != 1.0 for t in temperatures):
-            t = torch.tensor(temperatures,
-                             dtype=logits.dtype, device=logits.device)
+            t = torch.tensor(temperatures,
+                             dtype=logits.dtype,
+                             device=logits.device)
             # Use in-place division to avoid creating a new tensor.
             logits.div_(t.unsqueeze(dim=1))

@@ -75,7 +77,9 @@ class Sampler(nn.Module):
         # Apply top-p and top-k truncation.
         top_ps, top_ks = _get_top_p_top_k(input_metadata, self.vocab_size)
         assert len(top_ps) == len(top_ks) == probs.shape[0]
-        if any(p < 1.0 - _SAMPLING_EPS for p in top_ps) or any(
-                k != self.vocab_size for k in top_ks):
+        do_top_p = any(p < 1.0 - _SAMPLING_EPS for p in top_ps)
+        do_top_k = any(k != self.vocab_size for k in top_ks)
+        if do_top_p or do_top_k:
             probs = _apply_top_p_top_k(probs, top_ps, top_ks)

         # Sample the next tokens.

@@ -97,8 +101,7 @@ def _prune_hidden_states(
 def _get_penalties(
-    input_metadata: InputMetadata,
-) -> Tuple[List[float], List[float]]:
+        input_metadata: InputMetadata) -> Tuple[List[float], List[float]]:
     # Collect the presence and frequency penalties.
     presence_penalties: List[float] = []
     frequency_penalties: List[float] = []

@@ -117,9 +120,7 @@ def _get_penalties(
     return presence_penalties, frequency_penalties


-def _get_output_tokens(
-    input_metadata: InputMetadata,
-) -> List[List[int]]:
+def _get_output_tokens(input_metadata: InputMetadata) -> List[List[int]]:
     output_tokens: List[List[int]] = []
     for i, seq_group in enumerate(input_metadata.seq_groups):
         seq_ids, _ = seq_group

@@ -169,11 +170,13 @@ def _apply_penalties(
         device=logits.device)
     frequency_penalties = [frequency_penalties[i] for i in indices]
-    frequency_penalties = torch.tensor(
-        frequency_penalties, dtype=logits.dtype, device=logits.device)
+    frequency_penalties = torch.tensor(frequency_penalties,
+                                       dtype=logits.dtype,
+                                       device=logits.device)
     presence_penalties = [presence_penalties[i] for i in indices]
-    presence_penalties = torch.tensor(
-        presence_penalties, dtype=logits.dtype, device=logits.device)
+    presence_penalties = torch.tensor(presence_penalties,
+                                      dtype=logits.dtype,
+                                      device=logits.device)

     # We follow the definition in OpenAI API.
     # Refer to https://platform.openai.com/docs/api-reference/parameter-details

@@ -183,9 +186,7 @@ def _apply_penalties(
     return logits


-def _get_temperatures(
-    input_metadata: InputMetadata,
-) -> List[float]:
+def _get_temperatures(input_metadata: InputMetadata) -> List[float]:
     # Collect the temperatures for the logits.
     temperatures: List[float] = []
     for i, seq_group in enumerate(input_metadata.seq_groups):

@@ -252,8 +253,9 @@ def _apply_top_p_top_k(
     probs_sort[top_k_mask] = 0.0

     # Re-sort the probabilities.
-    probs = torch.gather(probs_sort, dim=-1,
-                         index=torch.argsort(probs_idx, dim=-1))
+    probs = torch.gather(probs_sort,
+                         dim=-1,
+                         index=torch.argsort(probs_idx, dim=-1))
     return probs

@@ -296,8 +298,9 @@ def _sample_from_prompt(
         # Random sampling.
         # Sample `best_of` tokens for the prompt.
         num_seqs = sampling_params.best_of
-        next_token_ids = torch.multinomial(prob, num_samples=num_seqs,
-                                           replacement=True)
+        next_token_ids = torch.multinomial(prob,
+                                           num_samples=num_seqs,
+                                           replacement=True)
         next_token_ids = next_token_ids.tolist()
     return next_token_ids

@@ -315,8 +318,9 @@ def _sample_from_generation_tokens(
     if sampling_params.use_beam_search:
         # Beam search.
         # Add cumulative logprobs for the sequences in the group.
-        seq_logprobs = torch.tensor(seq_logprobs, dtype=torch.float,
-                                    device=logprobs.device)
+        seq_logprobs = torch.tensor(seq_logprobs,
+                                    dtype=torch.float,
+                                    device=logprobs.device)
         logprobs = logprobs + seq_logprobs.unsqueeze(dim=1)

         vocab_size = logprobs.size(-1)

@@ -353,8 +357,9 @@ def _sample_from_generation_tokens(
     else:
         # Random sampling.
         # Sample 1 token for each sequence in the group.
-        next_token_ids = torch.multinomial(probs, num_samples=1,
-                                           replacement=True)
+        next_token_ids = torch.multinomial(probs,
+                                           num_samples=1,
+                                           replacement=True)
         next_token_ids = next_token_ids.squeeze(dim=-1).tolist()
         parent_seq_ids = seq_ids
     return parent_seq_ids, next_token_ids

@@ -381,15 +386,16 @@ def _sample(
             # Sample the next tokens.
             next_token_ids = _sample_from_prompt(prob, sampling_params)
             # Get top-k log probabilities for the next tokens.
-            next_logprobs = _get_topk_logprobs(logprob,
-                                               sampling_params.logprobs)
+            next_logprobs = _get_topk_logprobs(
+                logprob, sampling_params.logprobs)
             # Build the output.
             for seq_id, next_token_id in zip(seq_ids, next_token_ids):
                 output_logprobs = next_logprobs.copy()
                 output_logprobs[next_token_id] = logprob[next_token_id].item()
                 seq_outputs[seq_id] = SequenceOutputs(seq_id, seq_id,
                                                       next_token_id,
                                                       output_logprobs)
         else:
             # Generate the next tokens for generation tokens.
             prob = probs[idx:idx + len(seq_ids)]

@@ -399,22 +405,24 @@ def _sample(
             # Sample the next tokens.
             seq_logprobs = [
                 input_metadata.seq_data[seq_id].cumulative_logprob
                 for seq_id in seq_ids]
             parent_seq_ids, next_token_ids = _sample_from_generation_tokens(
                 seq_ids, prob, logprob, seq_logprobs, sampling_params)

             # Get top-k log probabilities for the next tokens.
             next_logprobs: Dict[int, Dict[int, float]] = {}
-            for i, seq_id in enumerate(seq_ids):
+            for j, seq_id in enumerate(seq_ids):
                 next_logprobs[seq_id] = _get_topk_logprobs(
-                    logprob[i], sampling_params.logprobs)
+                    logprob[j], sampling_params.logprobs)

             # Build the output.
             for seq_id, parent_seq_id, next_token_id in zip(
                 seq_ids, parent_seq_ids, next_token_ids):
-                i = seq_ids.index(parent_seq_id)
+                j = seq_ids.index(parent_seq_id)
                 output_logprobs = next_logprobs[parent_seq_id].copy()
-                output_logprobs[next_token_id] = logprob[i, next_token_id].item()
+                output_logprobs[next_token_id] = logprob[j,
+                                                         next_token_id].item()
                 seq_outputs[seq_id] = SequenceOutputs(
                     seq_id,
                     parent_seq_id,
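Beyond re-wrapping, the last hunk renames the inner index in `_sample`'s generation branch from `i` to `j`, so it no longer reuses a name already used by surrounding loops. A generic, self-contained illustration (not the vLLM code path) of the kind of loop-index reuse the rename avoids:

    # Generic illustration of loop-index reuse, the pattern the i -> j rename
    # avoids: reusing `i` in the nested loop overwrites the outer index, so any
    # later use of `i` sees the inner value instead of the group index.
    def outer_indices_buggy(groups):
        collected = []
        for i, group in enumerate(groups):
            for i, _item in enumerate(group):  # clobbers the outer `i`
                pass
            collected.append(i)  # last inner index, not the group index
        return collected

    def outer_indices_fixed(groups):
        collected = []
        for i, group in enumerate(groups):
            for j, _item in enumerate(group):  # distinct name keeps `i` intact
                pass
            collected.append(i)
        return collected

    data = [["a", "b", "c"], ["d"]]
    assert outer_indices_buggy(data) == [2, 0]  # surprising values
    assert outer_indices_fixed(data) == [0, 1]  # the intended group indices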
vllm/model_executor/model_loader.py

@@ -6,8 +6,9 @@ import torch.nn as nn
 from transformers import PretrainedConfig

 from vllm.config import ModelConfig
-from vllm.model_executor.models import (GPT2LMHeadModel, GPTBigCodeForCausalLM, GPTNeoXForCausalLM,
-                                        LlamaForCausalLM, OPTForCausalLM)
+from vllm.model_executor.models import (GPT2LMHeadModel, GPTBigCodeForCausalLM,
+                                        GPTNeoXForCausalLM, LlamaForCausalLM,
+                                        OPTForCausalLM)
 from vllm.model_executor.weight_utils import initialize_dummy_weights

 # TODO(woosuk): Lazy-load the model classes.

@@ -28,8 +29,7 @@ def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
             return _MODEL_REGISTRY[arch]
     raise ValueError(
         f"Model architectures {architectures} are not supported for now. "
-        f"Supported architectures: {list(_MODEL_REGISTRY.keys())}"
-    )
+        f"Supported architectures: {list(_MODEL_REGISTRY.keys())}")


 def get_model(model_config: ModelConfig) -> nn.Module:

@@ -46,8 +46,7 @@ def get_model(model_config: ModelConfig) -> nn.Module:
             initialize_dummy_weights(model)
         else:
             # Load the weights from the cached or downloaded files.
-            model.load_weights(
-                model_config.model, model_config.download_dir,
-                model_config.use_np_weights)
+            model.load_weights(model_config.model, model_config.download_dir,
+                               model_config.use_np_weights)
     model = model.cuda()
     return model.eval()
vllm/model_executor/models/__init__.py

@@ -4,8 +4,6 @@ from vllm.model_executor.models.gpt_neox import GPTNeoXForCausalLM
 from vllm.model_executor.models.llama import LlamaForCausalLM
 from vllm.model_executor.models.opt import OPTForCausalLM

 __all__ = [
     "GPT2LMHeadModel",
     "GPTBigCodeForCausalLM",
vllm/model_executor/models/gpt2.py

 # coding=utf-8
-# Adapted from https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py
+# Adapted from
+# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py
 # Copyright 2023 The vLLM team.
 # Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
 # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.

@@ -47,19 +48,25 @@ class GPT2Attention(nn.Module):
         super().__init__()
         self.hidden_size = config.hidden_size
         total_num_heads = config.num_attention_heads
-        tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
+        tensor_model_parallel_world_size = (
+            get_tensor_model_parallel_world_size())
         assert total_num_heads % tensor_model_parallel_world_size == 0
         self.num_heads = total_num_heads // tensor_model_parallel_world_size
         self.head_dim = self.hidden_size // total_num_heads
-        self.scale = self.head_dim ** -0.5
+        self.scale = self.head_dim**-0.5

-        self.c_attn = ColumnParallelLinear(self.hidden_size, 3 * self.hidden_size,
-                                           bias=True, gather_output=False,
-                                           perform_initialization=False)
+        self.c_attn = ColumnParallelLinear(self.hidden_size,
+                                           3 * self.hidden_size,
+                                           bias=True,
+                                           gather_output=False,
+                                           perform_initialization=False)
-        self.c_proj = RowParallelLinear(self.hidden_size, self.hidden_size,
-                                        bias=True, input_is_parallel=True,
-                                        perform_initialization=False)
+        self.c_proj = RowParallelLinear(self.hidden_size,
+                                        self.hidden_size,
+                                        bias=True,
+                                        input_is_parallel=True,
+                                        perform_initialization=False)
         self.attn = PagedAttention(self.num_heads, self.head_dim,
                                    scale=self.scale)

     def forward(

@@ -72,8 +79,8 @@ class GPT2Attention(nn.Module):
         qkv, _ = self.c_attn(hidden_states)
         q, k, v = qkv.chunk(chunks=3, dim=-1)
         key_cache, value_cache = kv_cache
         attn_output = self.attn(q, k, v, key_cache, value_cache,
                                 input_metadata, cache_event)
         attn_output, _ = self.c_proj(attn_output)
         return attn_output

@@ -87,11 +94,15 @@ class GPT2MLP(nn.Module):
     ):
         super().__init__()
         hidden_size = config.hidden_size
-        self.c_fc = ColumnParallelLinear(hidden_size, intermediate_size,
-                                         bias=True, gather_output=False,
-                                         perform_initialization=False)
+        self.c_fc = ColumnParallelLinear(hidden_size,
+                                         intermediate_size,
+                                         bias=True,
+                                         gather_output=False,
+                                         perform_initialization=False)
-        self.c_proj = RowParallelLinear(intermediate_size, hidden_size,
-                                        bias=True, input_is_parallel=True,
-                                        perform_initialization=False)
+        self.c_proj = RowParallelLinear(intermediate_size,
+                                        hidden_size,
+                                        bias=True,
+                                        input_is_parallel=True,
+                                        perform_initialization=False)
         self.act = get_act_fn(config.activation_function)

@@ -107,7 +118,8 @@ class GPT2Block(nn.Module):
     def __init__(self, config: GPT2Config):
         super().__init__()
         hidden_size = config.hidden_size
-        inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
+        inner_dim = (config.n_inner if config.n_inner is not None else 4 *
+                     hidden_size)

         self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
         self.attn = GPT2Attention(config)

@@ -145,9 +157,9 @@ class GPT2Model(nn.Module):
     def __init__(self, config: GPT2Config):
         super().__init__()
         self.config = config
-        assert config.add_cross_attention == False
-        assert config.scale_attn_by_inverse_layer_idx == False
-        assert config.reorder_and_upcast_attn == False
+        assert not config.add_cross_attention
+        assert not config.scale_attn_by_inverse_layer_idx
+        assert not config.reorder_and_upcast_attn
         self.embed_dim = config.hidden_size

         # Optimization: While the vocab size of GPT-2 is 50257, we extend it

@@ -180,8 +192,8 @@ class GPT2Model(nn.Module):
             else:
                 cache_event = cache_events[i]
             layer = self.h[i]
             hidden_states = layer(hidden_states, kv_caches[i], input_metadata,
                                   cache_event)

         hidden_states = self.ln_f(hidden_states)
         return hidden_states

@@ -206,24 +218,26 @@ class GPT2LMHeadModel(nn.Module):
         input_metadata: InputMetadata,
         cache_events: Optional[List[torch.cuda.Event]],
     ) -> Dict[int, SequenceOutputs]:
-        hidden_states = self.transformer(
-            input_ids, positions, kv_caches, input_metadata, cache_events)
-        next_tokens = self.sampler(
-            self.lm_head_weight, hidden_states, input_metadata)
+        hidden_states = self.transformer(input_ids, positions, kv_caches,
+                                         input_metadata, cache_events)
+        next_tokens = self.sampler(self.lm_head_weight, hidden_states,
+                                   input_metadata)
         return next_tokens

     _column_parallel_weights = ["wte.weight", "c_fc.weight", "c_fc.bias"]
     _row_parallel_weights = ["c_proj.weight"]

     def load_weights(self, model_name_or_path: str,
                      cache_dir: Optional[str] = None,
                      use_np_cache: bool = False):
-        tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
+        tensor_model_parallel_world_size = (
+            get_tensor_model_parallel_world_size())
         tensor_model_parallel_rank = get_tensor_model_parallel_rank()
         state_dict = self.state_dict()

         for name, loaded_weight in hf_model_weights_iterator(
                 model_name_or_path, cache_dir, use_np_cache):
             if "lm_head.weight" in name:
                 # GPT-2 ties the weights of the embedding layer and the final
                 # linear layer.

@@ -248,16 +262,20 @@ class GPT2LMHeadModel(nn.Module):
             if name == "transformer.wte.weight":
                 # Consider padding in the vocab size.
-                padded_vocab_size = param.shape[0] * tensor_model_parallel_world_size
+                padded_vocab_size = (param.shape[0] *
+                                     tensor_model_parallel_world_size)
                 num_extra_rows = padded_vocab_size - self.config.vocab_size
.
vocab_size
extra_rows
=
torch
.
empty
(
num_extra_rows
,
loaded_weight
.
shape
[
1
])
extra_rows
=
torch
.
empty
(
num_extra_rows
,
loaded_weight
.
shape
[
1
])
extra_rows
=
extra_rows
.
to
(
loaded_weight
)
extra_rows
=
extra_rows
.
to
(
loaded_weight
)
loaded_weight
=
torch
.
cat
([
loaded_weight
,
extra_rows
],
dim
=
0
)
loaded_weight
=
torch
.
cat
([
loaded_weight
,
extra_rows
],
dim
=
0
)
# For the fused QKV linear layer, manually shard the weights.
# For the fused QKV linear layer, manually shard the weights.
if
"c_attn"
in
name
:
if
"c_attn"
in
name
:
# GPT-2's fused QKV has the shape of [3 * num_heads * head_size, hidden_size].
# GPT-2's fused QKV has the shape of
# When tensor parallelism is used, we shard the weights along the head dimension.
# [3 * num_heads * head_size, hidden_size].
# When tensor parallelism is used, we shard the weights along
# the head dimension.
total_num_heads
=
self
.
config
.
num_attention_heads
total_num_heads
=
self
.
config
.
num_attention_heads
hidden_size
=
self
.
config
.
hidden_size
hidden_size
=
self
.
config
.
hidden_size
head_size
=
hidden_size
//
total_num_heads
head_size
=
hidden_size
//
total_num_heads
...
@@ -266,11 +284,13 @@ class GPT2LMHeadModel(nn.Module):
...
@@ -266,11 +284,13 @@ class GPT2LMHeadModel(nn.Module):
head_end
=
(
tensor_model_parallel_rank
+
1
)
*
num_heads
head_end
=
(
tensor_model_parallel_rank
+
1
)
*
num_heads
if
name
.
endswith
(
".weight"
):
if
name
.
endswith
(
".weight"
):
loaded_weight
=
loaded_weight
.
view
(
3
,
total_num_heads
,
head_size
,
hidden_size
)
loaded_weight
=
loaded_weight
.
view
(
3
,
total_num_heads
,
head_size
,
hidden_size
)
loaded_weight
=
loaded_weight
[:,
head_start
:
head_end
,
:,
:]
loaded_weight
=
loaded_weight
[:,
head_start
:
head_end
,
:,
:]
loaded_weight
=
loaded_weight
.
reshape
(
-
1
,
hidden_size
)
loaded_weight
=
loaded_weight
.
reshape
(
-
1
,
hidden_size
)
elif
name
.
endswith
(
".bias"
):
elif
name
.
endswith
(
".bias"
):
loaded_weight
=
loaded_weight
.
view
(
3
,
total_num_heads
,
head_size
)
loaded_weight
=
loaded_weight
.
view
(
3
,
total_num_heads
,
head_size
)
loaded_weight
=
loaded_weight
[:,
head_start
:
head_end
,
:]
loaded_weight
=
loaded_weight
[:,
head_start
:
head_end
,
:]
loaded_weight
=
loaded_weight
.
reshape
(
-
1
)
loaded_weight
=
loaded_weight
.
reshape
(
-
1
)
else
:
else
:
...
...
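The c_attn sharding in load_weights above views the fused QKV weight as [3, total_num_heads, head_size, hidden_size], keeps only the heads owned by the current tensor-parallel rank, and flattens the result back to two dimensions. A minimal standalone sketch of that step, using plain PyTorch and hypothetical toy sizes (the names and numbers below are illustrative, not taken from the diff):

import torch

# Hypothetical toy configuration.
total_num_heads = 8
head_size = 4
hidden_size = total_num_heads * head_size      # 32
tensor_parallel_size = 2
rank = 1

# Fused QKV weight as stored in the checkpoint:
# [3 * num_heads * head_size, hidden_size].
qkv_weight = torch.randn(3 * total_num_heads * head_size, hidden_size)

# Shard along the head dimension, mirroring the view/slice/reshape above.
num_heads = total_num_heads // tensor_parallel_size
head_start = rank * num_heads
head_end = (rank + 1) * num_heads

shard = qkv_weight.view(3, total_num_heads, head_size, hidden_size)
shard = shard[:, head_start:head_end, :, :]
shard = shard.reshape(-1, hidden_size)

# Each rank keeps 3 * (total_num_heads // tensor_parallel_size) * head_size
# rows of the fused weight.
assert shard.shape == (3 * num_heads * head_size, hidden_size)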
vllm/model_executor/models/gpt_bigcode.py
View file @ d6fa1be3
# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py
# Copyright 2023 The vLLM team.
# Copyright 2023 CTranslate2, and Michael Feil
# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
...
@@ -49,19 +50,25 @@ class GPTBigCodeAttention(nn.Module):
        super().__init__()
        self.hidden_size = config.hidden_size
        total_num_heads = config.num_attention_heads
        tensor_model_parallel_world_size = (
            get_tensor_model_parallel_world_size())
        assert total_num_heads % tensor_model_parallel_world_size == 0
        self.num_heads = total_num_heads // tensor_model_parallel_world_size
        self.head_dim = self.hidden_size // total_num_heads
        self.scale = self.head_dim**-0.5

        self.c_attn = ColumnParallelLinear(self.hidden_size,
                                           3 * self.hidden_size,
                                           bias=True,
                                           gather_output=False,
                                           perform_initialization=False)
        self.c_proj = RowParallelLinear(self.hidden_size,
                                        self.hidden_size,
                                        bias=True,
                                        input_is_parallel=True,
                                        perform_initialization=False)
        self.attn = PagedAttention(self.num_heads, self.head_dim,
                                   scale=self.scale)

    def forward(
...
@@ -74,8 +81,8 @@ class GPTBigCodeAttention(nn.Module):
        qkv, _ = self.c_attn(hidden_states)
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        key_cache, value_cache = kv_cache
        attn_output = self.attn(q, k, v, key_cache, value_cache,
                                input_metadata, cache_event)
        attn_output, _ = self.c_proj(attn_output)
        return attn_output
...
@@ -89,11 +96,15 @@ class GPTBigMLP(nn.Module):
    ):
        super().__init__()
        hidden_size = config.hidden_size
        self.c_fc = ColumnParallelLinear(hidden_size,
                                         intermediate_size,
                                         bias=True,
                                         gather_output=False,
                                         perform_initialization=False)
        self.c_proj = RowParallelLinear(intermediate_size,
                                        hidden_size,
                                        bias=True,
                                        input_is_parallel=True,
                                        perform_initialization=False)
        self.act = get_act_fn(config.activation_function)
...
@@ -109,7 +120,8 @@ class GPTBigCodeBlock(nn.Module):
    def __init__(self, config: GPTBigCodeConfig):
        super().__init__()
        hidden_size = config.hidden_size
        inner_dim = (config.n_inner
                     if config.n_inner is not None else 4 * hidden_size)

        self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.attn = GPTBigCodeAttention(config)
...
@@ -147,7 +159,7 @@ class GPTBigCodeModel(nn.Module):
    def __init__(self, config: GPTBigCodeConfig):
        super().__init__()
        self.config = config
        assert not config.add_cross_attention

        self.embed_dim = config.hidden_size
...
@@ -181,8 +193,8 @@ class GPTBigCodeModel(nn.Module):
            else:
                cache_event = cache_events[i]
            layer = self.h[i]
            hidden_states = layer(hidden_states, kv_caches[i], input_metadata,
                                  cache_event)

        hidden_states = self.ln_f(hidden_states)
        return hidden_states
...
@@ -207,24 +219,26 @@ class GPTBigCodeForCausalLM(nn.Module):
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> Dict[int, SequenceOutputs]:
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         input_metadata, cache_events)
        next_tokens = self.sampler(self.lm_head_weight, hidden_states,
                                   input_metadata)
        return next_tokens

    _column_parallel_weights = ["wte.weight", "c_fc.weight", "c_fc.bias"]
    _row_parallel_weights = ["c_proj.weight"]

    def load_weights(self, model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     use_np_cache: bool = False):
        tensor_model_parallel_world_size = (
            get_tensor_model_parallel_world_size())
        tensor_model_parallel_rank = get_tensor_model_parallel_rank()
        state_dict = self.state_dict()

        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, use_np_cache):
            if "lm_head.weight" in name:
                # GPT-2 ties the weights of the embedding layer and the final
                # linear layer.
...
@@ -241,9 +255,11 @@ class GPTBigCodeForCausalLM(nn.Module):
            if name == "transformer.wte.weight":
                # Consider padding in the vocab size.
                padded_vocab_size = param.shape[
                    0] * tensor_model_parallel_world_size
                num_extra_rows = padded_vocab_size - self.config.vocab_size
                extra_rows = torch.empty(num_extra_rows,
                                         loaded_weight.shape[1])
                extra_rows = extra_rows.to(loaded_weight)
                loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
...
@@ -258,25 +274,31 @@ class GPTBigCodeForCausalLM(nn.Module):
                qkv_array = qkv_array.numpy()

                dims_q = n_head * head_dim
                # pylint: disable=unbalanced-tuple-unpacking
                q, k, v = np.split(qkv_array, (dims_q, dims_q + head_dim),
                                   axis=0)
                # q is fine, but k & v have not replicated shape along the first
                # axis as long as MQA is not nativly supported, increase memory
                # and replicated (head_dim, hidden_dim) to
                # (n_heads * head_dim, hidden_dim)
                if k.ndim == 2 and v.ndim == 2:
                    replication = (n_head, 1)  # weights
                else:
                    replication = n_head  # biases
                # replicate n_head times for q, v
                k, v = np.tile(k, replication), np.tile(v, replication)
                # concat q, k, v along the first axis
                # (n_heads * head_dim, hidden_dim)
                # to (3 * n_heads * head_dim, hidden_dim)
                qkv_array = np.concatenate((q, k, v), axis=0)
                return torch.from_numpy(qkv_array)

            # For the fused QKV linear layer, manually shard the weights.
            if "c_attn" in name:
                # GPT-2's fused QKV has the shape of
                # [3 * num_heads * head_size, hidden_size].
                # When tensor parallelism is used, we shard the weights along
                # the head dimension.
                total_num_heads = self.config.num_attention_heads
                hidden_size = self.config.hidden_size
                head_size = hidden_size // total_num_heads
...
@@ -285,13 +307,19 @@ class GPTBigCodeForCausalLM(nn.Module):
                head_end = (tensor_model_parallel_rank + 1) * num_heads
                if name.endswith(".weight"):
                    loaded_weight = _expand_mqa_mha(loaded_weight,
                                                    n_head=total_num_heads,
                                                    head_dim=head_size)
                    loaded_weight = loaded_weight.view(3, total_num_heads,
                                                       head_size, hidden_size)
                    loaded_weight = loaded_weight[:, head_start:head_end, :, :]
                    loaded_weight = loaded_weight.reshape(-1, hidden_size)
                elif name.endswith(".bias"):
                    loaded_weight = _expand_mqa_mha(loaded_weight,
                                                    n_head=total_num_heads,
                                                    head_dim=head_size)
                    loaded_weight = loaded_weight.view(3, total_num_heads,
                                                       head_size)
                    loaded_weight = loaded_weight[:, head_start:head_end, :]
                    loaded_weight = loaded_weight.reshape(-1)
                else:
...
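GPT-BigCode checkpoints use multi-query attention: the fused tensor holds n_head query heads but only a single shared key head and value head. The _expand_mqa_mha helper in the hunk above splits the tensor, tiles k and v up to full multi-head shape, and concatenates the pieces back so the regular QKV sharding can run unchanged. A minimal NumPy sketch of that expansion, with hypothetical toy sizes (not taken from the diff):

import numpy as np

# Hypothetical toy multi-query layout.
n_head = 4
head_dim = 3
hidden_dim = n_head * head_dim                # 12

# Fused weight: q has n_head heads, k and v share a single head each.
qkv_array = np.random.randn(n_head * head_dim + 2 * head_dim, hidden_dim)

dims_q = n_head * head_dim
q, k, v = np.split(qkv_array, (dims_q, dims_q + head_dim), axis=0)

# Replicate the shared k/v head n_head times:
# (head_dim, hidden_dim) -> (n_head * head_dim, hidden_dim).
replication = (n_head, 1)                     # 2-D weights; a bias would use n_head
k, v = np.tile(k, replication), np.tile(v, replication)

# Concatenate back into a standard MHA-shaped fused QKV tensor.
qkv_array = np.concatenate((q, k, v), axis=0)
assert qkv_array.shape == (3 * n_head * head_dim, hidden_dim)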
vllm/model_executor/models/gpt_neox.py
View file @ d6fa1be3
# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt_neox/modeling_gpt_neox.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI The HuggingFace Inc. team. All rights reserved.
#
...
@@ -48,19 +49,23 @@ class GPTNeoXAttention(nn.Module):
        self.hidden_size = config.hidden_size
        self.head_size = self.hidden_size // self.total_num_heads

        tensor_model_parallel_world_size = (
            get_tensor_model_parallel_world_size())
        assert self.total_num_heads % tensor_model_parallel_world_size == 0
        self.num_heads = (self.total_num_heads //
                          tensor_model_parallel_world_size)

        self.query_key_value = ColumnParallelLinear(
            config.hidden_size,
            3 * config.hidden_size,
            gather_output=False,
            perform_initialization=False)
        self.dense = RowParallelLinear(config.hidden_size,
                                       config.hidden_size,
                                       input_is_parallel=True,
                                       perform_initialization=False)

        scaling = self.head_size**-0.5
        rotary_dim = int(self.head_size * config.rotary_pct)
        assert rotary_dim % 2 == 0
        self.attn = PagedAttentionWithRoPE(self.num_heads, self.head_size,
...
@@ -78,8 +83,8 @@ class GPTNeoXAttention(nn.Module):
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        k_cache, v_cache = kv_cache
        attn_output = self.attn(position_ids, q, k, v, k_cache, v_cache,
                                input_metadata, cache_event)
        output, _ = self.dense(attn_output)
        return output
...
@@ -92,7 +97,8 @@ class GPTNeoXMLP(nn.Module):
                                                  config.intermediate_size,
                                                  gather_output=False,
                                                  perform_initialization=False)
        self.dense_4h_to_h = RowParallelLinear(config.intermediate_size,
                                               config.hidden_size,
                                               input_is_parallel=True,
                                               perform_initialization=False)
        self.act = get_act_fn(config.hidden_act)
...
@@ -109,8 +115,10 @@ class GPTNeoXLayer(nn.Module):
    def __init__(self, config: GPTNeoXConfig):
        super().__init__()
        self.use_parallel_residual = config.use_parallel_residual
        self.input_layernorm = nn.LayerNorm(config.hidden_size,
                                            eps=config.layer_norm_eps)
        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
                                                     eps=config.layer_norm_eps)
        self.attention = GPTNeoXAttention(config)
        self.mlp = GPTNeoXMLP(config)
...
@@ -154,10 +162,13 @@ class GPTNeoXModel(nn.Module):
        super().__init__()
        self.config = config

        self.embed_in = VocabParallelEmbedding(config.vocab_size,
                                               config.hidden_size,
                                               perform_initialization=False)
        self.layers = nn.ModuleList(
            [GPTNeoXLayer(config) for _ in range(config.num_hidden_layers)])
        self.final_layer_norm = nn.LayerNorm(config.hidden_size,
                                             eps=config.layer_norm_eps)

    def forward(
        self,
...
@@ -191,8 +202,10 @@ class GPTNeoXForCausalLM(nn.Module):
        super().__init__()
        self.config = config
        self.gpt_neox = GPTNeoXModel(config)
        self.embed_out = ColumnParallelLinear(config.hidden_size,
                                              config.vocab_size,
                                              bias=False,
                                              gather_output=False,
                                              perform_initialization=False)
        self.sampler = Sampler(config.vocab_size)
...
@@ -204,24 +217,28 @@ class GPTNeoXForCausalLM(nn.Module):
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> Dict[int, SequenceOutputs]:
        hidden_states = self.gpt_neox(input_ids, positions, kv_caches,
                                      input_metadata, cache_events)
        next_tokens = self.sampler(self.embed_out.weight, hidden_states,
                                   input_metadata)
        return next_tokens

    _column_parallel_weights = [
        "embed_in.weight", "embed_out.weight", "dense_h_to_4h.weight",
        "dense_h_to_4h.bias"
    ]
    _row_parallel_weights = ["dense.weight", "dense_4h_to_h.weight"]

    def load_weights(self, model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     use_np_cache: bool = False):
        tensor_model_parallel_rank = get_tensor_model_parallel_rank()
        state_dict = self.state_dict()

        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, use_np_cache):
            if ("attention.bias" in name or "attention.masked_bias" in name
                    or "rotary_emb.inv_freq" in name):
                continue
            param = state_dict[name]
            if "query_key_value" in name:
...
@@ -230,17 +247,19 @@ class GPTNeoXForCausalLM(nn.Module):
                # required shape is [3 * num_heads * head_size, hidden_size].
                # Thus, we need weight conversion.
                shard_size = param.shape[0]
                loaded_weight = loaded_weight[
                    shard_size * tensor_model_parallel_rank:shard_size *
                    (tensor_model_parallel_rank + 1)]

                num_heads = self.config.num_attention_heads
                hidden_size = self.config.hidden_size
                head_size = hidden_size // num_heads
                if "query_key_value.weight" in name:
                    loaded_weight = loaded_weight.view(-1, 3, head_size,
                                                       hidden_size)
                    loaded_weight = loaded_weight.transpose(0, 1)
                    loaded_weight = loaded_weight.reshape(-1, hidden_size)
                elif "query_key_value.bias" in name:
                    loaded_weight = loaded_weight.view(-1, 3, head_size)
                    loaded_weight = loaded_weight.transpose(0, 1)
                    loaded_weight = loaded_weight.reshape(-1)
...
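GPT-NeoX checkpoints store the fused query_key_value weight with q, k and v interleaved per head, so it can be viewed as [num_heads, 3, head_size, hidden_size], while the sharded layer expects all q rows first, then k, then v. The view/transpose/reshape sequence in load_weights above performs that reordering; a small PyTorch sketch with hypothetical toy sizes (illustrative only):

import torch

# Hypothetical toy sizes.
num_heads = 4
head_size = 2
hidden_size = num_heads * head_size           # 8

# Checkpoint layout: per head, the q, k and v rows are stored back to back,
# i.e. the tensor views as [num_heads, 3, head_size, hidden_size].
loaded_weight = torch.randn(num_heads * 3 * head_size, hidden_size)

# Reorder to [3, num_heads, head_size, hidden_size] and flatten, so that all
# q rows come first, then all k rows, then all v rows.
converted = loaded_weight.view(-1, 3, head_size, hidden_size)
converted = converted.transpose(0, 1)
converted = converted.reshape(-1, hidden_size)

# Same number of rows as before, just reordered.
assert converted.shape == loaded_weight.shape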
vllm/model_executor/models/llama.py
View file @ d6fa1be3
# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
...
@@ -30,7 +31,6 @@ import torch
from torch import nn
from transformers import LlamaConfig

-from vllm.sequence import SequenceOutputs
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
...
@@ -56,15 +56,19 @@ class LlamaMLP(nn.Module):
        hidden_act: str,
    ):
        super().__init__()
        self.gate_up_proj = ColumnParallelLinear(hidden_size,
                                                 2 * intermediate_size,
                                                 bias=False,
                                                 gather_output=False,
                                                 perform_initialization=False)
        self.down_proj = RowParallelLinear(intermediate_size,
                                           hidden_size,
                                           bias=False,
                                           input_is_parallel=True,
                                           perform_initialization=False)
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()

    def forward(self, x):
...
@@ -83,12 +87,14 @@ class LlamaAttention(nn.Module):
    ):
        super().__init__()
        self.hidden_size = hidden_size
        tensor_model_parallel_world_size = (
            get_tensor_model_parallel_world_size())
        self.total_num_heads = num_heads
        assert self.total_num_heads % tensor_model_parallel_world_size == 0
        self.num_heads = (self.total_num_heads //
                          tensor_model_parallel_world_size)
        self.head_dim = hidden_size // self.total_num_heads
        self.scaling = self.head_dim**-0.5

        self.qkv_proj = ColumnParallelLinear(
            hidden_size,
...
@@ -104,8 +110,10 @@ class LlamaAttention(nn.Module):
            input_is_parallel=True,
            perform_initialization=False,
        )
        self.attn = PagedAttentionWithRoPE(self.num_heads,
                                           self.head_dim,
                                           self.scaling,
                                           rotary_dim=self.head_dim)

    def forward(
        self,
...
@@ -118,8 +126,8 @@ class LlamaAttention(nn.Module):
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        k_cache, v_cache = kv_cache
        attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
                                input_metadata, cache_event)
        output, _ = self.o_proj(attn_output)
        return output
...
@@ -138,8 +146,10 @@ class LlamaDecoderLayer(nn.Module):
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
        )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
...
@@ -177,9 +187,13 @@ class LlamaModel(nn.Module):
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
            perform_initialization=False)
        self.layers = nn.ModuleList([
            LlamaDecoderLayer(config)
            for _ in range(config.num_hidden_layers)
        ])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
...
@@ -209,6 +223,7 @@ class LlamaModel(nn.Module):

class LlamaForCausalLM(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.config = config
...
@@ -228,39 +243,42 @@ class LlamaForCausalLM(nn.Module):
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> Dict[int, SequenceOutputs]:
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   input_metadata, cache_events)
        next_tokens = self.sampler(self.lm_head.weight, hidden_states,
                                   input_metadata)
        return next_tokens

    _column_parallel_weights = [
        "embed_tokens.weight", "lm_head.weight", "qkv_proj.weight",
        "gate_proj.weight", "up_proj.weight"
    ]
    _row_parallel_weights = ["o_proj.weight", "down_proj.weight"]

    def load_weights(self, model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     use_np_cache: bool = False):
        tensor_model_parallel_rank = get_tensor_model_parallel_rank()
        state_dict = self.state_dict()

        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, use_np_cache):
            if "rotary_emb.inv_freq" in name:
                continue

            is_attention_weight = False
            for stride_id, att_weight_name in enumerate(
                    ["q_proj", "k_proj", "v_proj"]):
                if att_weight_name not in name:
                    continue
                param = state_dict[name.replace(att_weight_name, "qkv_proj")]
                shard_size = param.shape[0] // 3
                loaded_weight = loaded_weight[
                    shard_size * tensor_model_parallel_rank:shard_size *
                    (tensor_model_parallel_rank + 1)]
                param_slice = param.data[shard_size * stride_id:shard_size *
                                         (stride_id + 1)]
                assert param_slice.shape == loaded_weight.shape
                param_slice.copy_(loaded_weight)
                is_attention_weight = True
...
@@ -275,10 +293,10 @@ class LlamaForCausalLM(nn.Module):
                param = state_dict[name.replace(weight_name, "gate_up_proj")]
                shard_size = param.shape[0] // 2
                loaded_weight = loaded_weight[
                    shard_size * tensor_model_parallel_rank:shard_size *
                    (tensor_model_parallel_rank + 1)]
                param_slice = param.data[shard_size * stride_id:shard_size *
                                         (stride_id + 1)]
                assert param_slice.shape == loaded_weight.shape
                param_slice.copy_(loaded_weight)
                is_gate_up_weight = True
...
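LLaMA (and OPT below) keep q_proj, k_proj and v_proj as separate tensors in the checkpoint, while the model fuses them into a single qkv_proj parameter. The loop above slices this rank's rows out of each checkpoint tensor and copies them into the matching third of the fused parameter, selected by stride_id. A small sketch of the same copy pattern on a single rank, with hypothetical sizes and names (not from the diff):

import torch

# Hypothetical toy sizes.
hidden_size = 8
world_size = 2
rank = 0

# Fused parameter on this rank: three stacked blocks (q, k, v), each sharded
# by rows across ranks.
rows_per_rank = hidden_size // world_size              # 4
qkv_param = torch.zeros(3 * rows_per_rank, hidden_size)

# Stand-in for the checkpoint iterator.
checkpoint = {
    "q_proj.weight": torch.randn(hidden_size, hidden_size),
    "k_proj.weight": torch.randn(hidden_size, hidden_size),
    "v_proj.weight": torch.randn(hidden_size, hidden_size),
}

shard_size = qkv_param.shape[0] // 3                   # rows per block
for stride_id, name in enumerate(
        ["q_proj.weight", "k_proj.weight", "v_proj.weight"]):
    loaded_weight = checkpoint[name]
    # Rows belonging to this tensor-parallel rank.
    loaded_weight = loaded_weight[shard_size * rank:shard_size * (rank + 1)]
    # Matching block inside the fused parameter.
    param_slice = qkv_param.data[shard_size * stride_id:shard_size *
                                 (stride_id + 1)]
    assert param_slice.shape == loaded_weight.shape
    param_slice.copy_(loaded_weight)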
vllm/model_executor/models/opt.py
View file @ d6fa1be3
# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/opt/modeling_opt.py
# Copyright 2023 The vLLM team.
# Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights
# reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
@@ -43,8 +45,9 @@ KVCache = Tuple[torch.Tensor, torch.Tensor]
class OPTLearnedPositionalEmbedding(nn.Embedding):

    def __init__(self, num_embeddings: int, embedding_dim: int):
        # OPT is set up so that if padding_idx is specified then offset the
        # embedding ids by 2 and adjust num_embeddings appropriately. Other
        # models don't have this hack
        self.offset = 2
        super().__init__(num_embeddings + self.offset, embedding_dim)
...
@@ -62,20 +65,26 @@ class OPTAttention(nn.Module):
    ) -> None:
        super().__init__()
        self.embed_dim = embed_dim
        tensor_model_parallel_world_size = (
            get_tensor_model_parallel_world_size())
        total_num_heads = num_heads
        assert num_heads % tensor_model_parallel_world_size == 0
        self.num_heads = total_num_heads // tensor_model_parallel_world_size
        self.head_dim = embed_dim // total_num_heads
        self.scaling = self.head_dim**-0.5

        self.qkv_proj = ColumnParallelLinear(embed_dim,
                                             3 * embed_dim,
                                             bias=bias,
                                             gather_output=False,
                                             perform_initialization=False)
        self.out_proj = RowParallelLinear(embed_dim,
                                          embed_dim,
                                          bias=bias,
                                          input_is_parallel=True,
                                          perform_initialization=False)
        self.attn = PagedAttention(self.num_heads, self.head_dim,
                                   scale=self.scaling)

    def forward(
...
@@ -88,8 +97,8 @@ class OPTAttention(nn.Module):
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        key_cache, value_cache = kv_cache
        attn_output = self.attn(q, k, v, key_cache, value_cache,
                                input_metadata, cache_event)
        output, _ = self.out_proj(attn_output)
        return output
...
@@ -109,17 +118,21 @@ class OPTDecoderLayer(nn.Module):
        self.activation_fn = get_act_fn(config.activation_function)

        self.self_attn_layer_norm = nn.LayerNorm(
            self.embed_dim,
            elementwise_affine=config.layer_norm_elementwise_affine)
        self.fc1 = ColumnParallelLinear(self.embed_dim,
                                        config.ffn_dim,
                                        bias=config.enable_bias,
                                        gather_output=False,
                                        perform_initialization=False)
        self.fc2 = RowParallelLinear(config.ffn_dim,
                                     self.embed_dim,
                                     bias=config.enable_bias,
                                     input_is_parallel=True,
                                     perform_initialization=False)
        self.final_layer_norm = nn.LayerNorm(
            self.embed_dim,
            elementwise_affine=config.layer_norm_elementwise_affine)

    def forward(
        self,
...
@@ -133,11 +146,10 @@ class OPTDecoderLayer(nn.Module):
        # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
        if self.do_layer_norm_before:
            hidden_states = self.self_attn_layer_norm(hidden_states)

        hidden_states = self.self_attn(hidden_states=hidden_states,
                                       kv_cache=kv_cache,
                                       input_metadata=input_metadata,
                                       cache_event=cache_event)
        hidden_states = residual + hidden_states

        # 350m applies layer norm AFTER attention
        if not self.do_layer_norm_before:
...
@@ -167,35 +179,42 @@ class OPTDecoder(nn.Module):
        self.max_target_positions = config.max_position_embeddings
        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.word_embed_proj_dim,
            perform_initialization=False)
        # Positional embeddings are replicated (not sharded).
        self.embed_positions = OPTLearnedPositionalEmbedding(
            config.max_position_embeddings, config.hidden_size)

        # Project out & in will be replicated if they exist.
        if config.word_embed_proj_dim != config.hidden_size:
            self.project_out = nn.Linear(config.hidden_size,
                                         config.word_embed_proj_dim,
                                         bias=False)
        else:
            self.project_out = None

        if config.word_embed_proj_dim != config.hidden_size:
            self.project_in = nn.Linear(config.word_embed_proj_dim,
                                        config.hidden_size,
                                        bias=False)
        else:
            self.project_in = None

        # Note that the only purpose of `config._remove_final_layer_norm` is to
        # keep backward compatibility with checkpoints that have been fine-tuned
        # before transformers v4.20.1
        # see https://github.com/facebookresearch/metaseq/pull/164
        if config.do_layer_norm_before and not config._remove_final_layer_norm:
            self.final_layer_norm = nn.LayerNorm(
                config.hidden_size,
                elementwise_affine=config.layer_norm_elementwise_affine)
        else:
            self.final_layer_norm = None

        self.layers = nn.ModuleList(
            [OPTDecoderLayer(config) for _ in range(config.num_hidden_layers)])

    def forward(
        self,
...
@@ -217,8 +236,8 @@ class OPTDecoder(nn.Module):
            else:
                cache_event = cache_events[i]
            layer = self.layers[i]
            hidden_states = layer(hidden_states, kv_caches[i], input_metadata,
                                  cache_event)

        if self.final_layer_norm is not None:
            hidden_states = self.final_layer_norm(hidden_states)
...
@@ -241,8 +260,8 @@ class OPTModel(nn.Module):
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> torch.Tensor:
        return self.decoder(input_ids, positions, kv_caches, input_metadata,
                            cache_events)


class OPTForCausalLM(nn.Module):
...
@@ -264,23 +283,26 @@ class OPTForCausalLM(nn.Module):
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> Dict[int, SequenceOutputs]:
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   input_metadata, cache_events)
        next_tokens = self.sampler(self.lm_head_weight, hidden_states,
                                   input_metadata)
        return next_tokens

    _column_parallel_weights = [
        "embed_tokens.weight", "fc1.weight", "fc1.bias"
    ]
    _row_parallel_weights = ["out_proj.weight", "fc2.weight"]

    def load_weights(self, model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     use_np_cache: bool = False):
        tensor_model_parallel_rank = get_tensor_model_parallel_rank()
        state_dict = self.state_dict()

        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, use_np_cache):
            if "lm_head.weight" in name:
                continue
...
@@ -288,16 +310,17 @@ class OPTForCausalLM(nn.Module):
                name = "model." + name

            is_attention_weight = False
            for stride_id, att_weight_name in enumerate(
                    ["q_proj", "k_proj", "v_proj"]):
                if att_weight_name not in name:
                    continue
                param = state_dict[name.replace(att_weight_name, "qkv_proj")]
                shard_size = param.shape[0] // 3
                loaded_weight = loaded_weight[
                    shard_size * tensor_model_parallel_rank:shard_size *
                    (tensor_model_parallel_rank + 1)]
                param_slice = param.data[shard_size * stride_id:shard_size *
                                         (stride_id + 1)]
                assert param_slice.shape == loaded_weight.shape
                param_slice.copy_(loaded_weight)
                is_attention_weight = True
...
vllm/model_executor/weight_utils.py
View file @ d6fa1be3
...
@@ -44,9 +44,9 @@ def hf_model_weights_iterator(
    if use_np_cache:
        # Convert the model weights from torch tensors to numpy arrays for
        # faster loading.
        np_folder = os.path.join(hf_folder, "np")
        os.makedirs(np_folder, exist_ok=True)
        weight_names_file = os.path.join(np_folder, "weight_names.json")
        with lock:
            if not os.path.exists(weight_names_file):
                weight_names = []
...
@@ -57,10 +57,10 @@ def hf_model_weights_iterator(
                    with open(param_path, "wb") as f:
                        np.save(f, param.cpu().detach().numpy())
                    weight_names.append(name)
                with open(weight_names_file, "w") as f:
                    json.dump(weight_names, f)

        with open(weight_names_file, "r") as f:
            weight_names = json.load(f)

        for name in weight_names:
...
@@ -86,17 +86,16 @@ def load_tensor_parallel_weights(
    for p in column_parallel_weight_names:
        if p in param_name:
            shard_size = param.shape[0]
            start_idx = tensor_model_parallel_rank * shard_size
            end_idx = (tensor_model_parallel_rank + 1) * shard_size
            loaded_weight = loaded_weight[start_idx:end_idx]
            break
    for p in row_parallel_weight_names:
        if p in param_name:
            shard_size = param.shape[1]
            start_idx = tensor_model_parallel_rank * shard_size
            end_idx = (tensor_model_parallel_rank + 1) * shard_size
            loaded_weight = loaded_weight[:, start_idx:end_idx]
            break
    assert param.shape == loaded_weight.shape, (
        f"{param_name} shape mismatch between model and checkpoint: "
...
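load_tensor_parallel_weights above slices column-parallel weights along dim 0 (the output rows) and row-parallel weights along dim 1 (the input columns), using the same start/end index arithmetic in both branches. A compact sketch of the two cases with hypothetical toy shapes; note that in the loader itself shard_size comes from the already-sharded parameter's shape rather than being derived from the full weight:

import torch

tensor_model_parallel_rank = 1
world_size = 4

# Column-parallel weight [out_features, in_features]: each rank owns a slice
# of the output rows (dim 0).
full_col_weight = torch.randn(16, 8)
shard_size = full_col_weight.shape[0] // world_size    # 4 rows per rank
start_idx = tensor_model_parallel_rank * shard_size
end_idx = (tensor_model_parallel_rank + 1) * shard_size
col_shard = full_col_weight[start_idx:end_idx]         # shape [4, 8]

# Row-parallel weight [out_features, in_features]: each rank owns a slice
# of the input columns (dim 1).
full_row_weight = torch.randn(8, 16)
shard_size = full_row_weight.shape[1] // world_size    # 4 columns per rank
start_idx = tensor_model_parallel_rank * shard_size
end_idx = (tensor_model_parallel_rank + 1) * shard_size
row_shard = full_row_weight[:, start_idx:end_idx]      # shape [8, 4]

assert col_shard.shape == (4, 8) and row_shard.shape == (8, 4)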