norm / vllm · Commits · d6fa1be3

Commit d6fa1be3 (unverified)
[Quality] Add code formatter and linter (#326)

Authored Jul 03, 2023 by Zhuohan Li; committed by GitHub on Jul 03, 2023
Parent: 0ffded81 · Changes: 47

Showing 20 changed files with 569 additions and 374 deletions (+569 / -374)
vllm/engine/llm_engine.py                     +18  -15
vllm/engine/ray_utils.py                       +7   -6
vllm/entrypoints/api_server.py                 +8   -9
vllm/entrypoints/llm.py                        +1   -2
vllm/entrypoints/openai/api_server.py         +67  -46
vllm/entrypoints/openai/protocol.py            +4   -2
vllm/logger.py                                 +3   -3
vllm/model_executor/__init__.py                +0   -1
vllm/model_executor/input_metadata.py         +13   -2
vllm/model_executor/layers/activation.py       +9  -10
vllm/model_executor/layers/attention.py       +81  -32
vllm/model_executor/layers/sampler.py         +45  -37
vllm/model_executor/model_loader.py            +6   -7
vllm/model_executor/models/__init__.py         +0   -2
vllm/model_executor/models/gpt2.py            +53  -33
vllm/model_executor/models/gpt_bigcode.py     +66  -38
vllm/model_executor/models/gpt_neox.py        +52  -33
vllm/model_executor/models/llama.py           +57  -39
vllm/model_executor/models/opt.py             +69  -46
vllm/model_executor/weight_utils.py           +10  -11
vllm/engine/llm_engine.py  (view file @ d6fa1be3)

@@ -67,8 +67,7 @@ class LLMEngine:
             f"download_dir={model_config.download_dir!r}, "
             f"use_np_weights={model_config.use_np_weights}, "
             f"tensor_parallel_size={parallel_config.tensor_parallel_size}, "
-            f"seed={model_config.seed})"
-        )
+            f"seed={model_config.seed})")
         # TODO(woosuk): Print more configs in debug mode.

         self.model_config = model_config

@@ -78,8 +77,8 @@ class LLMEngine:
         self.log_stats = log_stats
         self._verify_args()

-        self.tokenizer = get_tokenizer(model_config.tokenizer,
-                                       model_config.tokenizer_mode)
+        self.tokenizer = get_tokenizer(
+            model_config.tokenizer, tokenizer_mode=model_config.tokenizer_mode)
         self.seq_counter = Counter()

         # Create the parallel GPU workers.

@@ -129,8 +128,8 @@ class LLMEngine:
         num_gpu_blocks = min(b[0] for b in num_blocks)
         num_cpu_blocks = min(b[1] for b in num_blocks)
         # FIXME(woosuk): Change to debug log.
-        logger.info(f'# GPU blocks: {num_gpu_blocks}, '
-                    f'# CPU blocks: {num_cpu_blocks}')
+        logger.info(f"# GPU blocks: {num_gpu_blocks}, "
+                    f"# CPU blocks: {num_cpu_blocks}")
         if num_gpu_blocks <= 0:
             raise ValueError("No available memory for the cache blocks. "

@@ -152,7 +151,9 @@ class LLMEngine:
         # Initialize the cluster.
         distributed_init_method, devices = initialize_cluster(parallel_config)
         # Create the LLM engine.
-        engine = cls(*engine_configs, distributed_init_method, devices,
+        engine = cls(*engine_configs,
+                     distributed_init_method,
+                     devices,
                      log_stats=not engine_args.disable_log_stats)
         return engine

@@ -226,8 +227,10 @@ class LLMEngine:
         and updates the scheduler with the model outputs. Finally, it decodes
         the sequences and returns the newly generated results.
         """
-        seq_group_metadata_list, scheduler_outputs, ignored_seq_groups = self.scheduler.schedule()
-        if (not seq_group_metadata_list) and scheduler_outputs.is_empty() and (not ignored_seq_groups):
+        (seq_group_metadata_list, scheduler_outputs,
+         ignored_seq_groups) = self.scheduler.schedule()
+        if ((not seq_group_metadata_list) and scheduler_outputs.is_empty()
+                and (not ignored_seq_groups)):
             # Nothing to do.
             return []

@@ -281,8 +284,8 @@ class LLMEngine:
                     # Truncate the output text so that the stop string is
                     # not included in the output.
                     seq.output_text = seq.output_text[:-len(stop_str)]
-                    self.scheduler.free_seq(seq,
-                                            SequenceStatus.FINISHED_STOPPED)
+                    self.scheduler.free_seq(
+                        seq, SequenceStatus.FINISHED_STOPPED)
                     stopped = True
                     break
             if stopped:

@@ -290,7 +293,7 @@ class LLMEngine:
             # Check if the sequence has reached max_seq_len.
-            if (seq.get_len()
-                    >= self.scheduler.scheduler_config.max_seq_len):
+            if (seq.get_len() >=
+                    self.scheduler.scheduler_config.max_seq_len):
                 self.scheduler.free_seq(
                     seq, SequenceStatus.FINISHED_LENGTH_CAPPED)
                 continue

@@ -302,15 +305,15 @@ class LLMEngine:
             # Check if the sequence has generated the EOS token.
             if not sampling_params.ignore_eos:
                 if seq.get_last_token_id() == self.tokenizer.eos_token_id:
-                    self.scheduler.free_seq(seq,
-                                            SequenceStatus.FINISHED_STOPPED)
+                    self.scheduler.free_seq(
+                        seq, SequenceStatus.FINISHED_STOPPED)
                     continue

     def _run_workers(
         self,
         method: str,
-        get_all_outputs: bool = False,
         *args,
+        get_all_outputs: bool = False,
         **kwargs,
     ) -> Any:
         """Runs the given method on all workers."""
vllm/engine/ray_utils.py  (view file @ d6fa1be3)

@@ -8,7 +8,8 @@ except ImportError:
 from vllm.config import ParallelConfig

-DeviceID = Tuple[int, Optional[str], int]  # rank, node resource (node IP), device id
+# rank, node resource (node IP), device id
+DeviceID = Tuple[int, Optional[str], int]


 def initialize_cluster(

@@ -53,15 +54,15 @@ def initialize_cluster(
     valid_node_resources = []
     num_devices_per_node = None
     for node in ray.nodes():
-        if (not node['Alive']) or node['Resources']['GPU'] <= 0:
+        if (not node["Alive"]) or node["Resources"]["GPU"] <= 0:
             continue
         if num_devices_per_node is None:
-            num_devices_per_node = node['Resources']['GPU']
+            num_devices_per_node = node["Resources"]["GPU"]
         else:
-            assert num_devices_per_node == node['Resources']['GPU'], (
+            assert num_devices_per_node == node["Resources"]["GPU"], (
                 "The number of GPUs per node is not uniform.")
-        for key in node['Resources']:
-            if key.startswith('node:'):
+        for key in node["Resources"]:
+            if key.startswith("node:"):
                 valid_node_resources.append(key)

     # Verify the parallel config.
vllm/entrypoints/api_server.py  (view file @ d6fa1be3)

@@ -11,8 +11,8 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.sampling_params import SamplingParams
 from vllm.utils import random_uuid

-TIMEOUT_KEEP_ALIVE = 5 # seconds.
-TIMEOUT_TO_PREVENT_DEADLOCK = 1 # seconds
+TIMEOUT_KEEP_ALIVE = 5  # seconds.
+TIMEOUT_TO_PREVENT_DEADLOCK = 1  # seconds.
 app = FastAPI()

@@ -37,8 +37,7 @@ async def generate(request: Request) -> Response:
     async for request_output in results_generator:
         prompt = request_output.prompt
         text_outputs = [
-            prompt + output.text
-            for output in request_output.outputs
+            prompt + output.text for output in request_output.outputs
         ]
         ret = {"text": text_outputs}
         yield (json.dumps(ret) + "\0").encode("utf-8")

@@ -63,10 +62,7 @@ async def generate(request: Request) -> Response:
     assert final_output is not None
     prompt = final_output.prompt
-    text_outputs = [
-        prompt + output.text
-        for output in final_output.outputs
-    ]
+    text_outputs = [prompt + output.text for output in final_output.outputs]
     ret = {"text": text_outputs}
     return Response(content=json.dumps(ret))

@@ -81,5 +77,8 @@ if __name__ == "__main__":
     engine_args = AsyncEngineArgs.from_cli_args(args)
     engine = AsyncLLMEngine.from_engine_args(engine_args)

-    uvicorn.run(app, host=args.host, port=args.port, log_level="debug",
-                timeout_keep_alive=TIMEOUT_KEEP_ALIVE)
+    uvicorn.run(app,
+                host=args.host,
+                port=args.port,
+                log_level="debug",
+                timeout_keep_alive=TIMEOUT_KEEP_ALIVE)
vllm/entrypoints/llm.py  (view file @ d6fa1be3)

@@ -63,8 +63,7 @@ class LLM:
         self.request_counter = Counter()

     def get_tokenizer(
-        self,
-    ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
+            self) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
         return self.llm_engine.tokenizer

     def set_tokenizer(
vllm/entrypoints/openai/api_server.py  (view file @ d6fa1be3)

-# Adapted from https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/serve/openai_api_server.py
+# Adapted from
+# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/serve/openai_api_server.py
 import argparse
 from http import HTTPStatus

@@ -29,7 +30,7 @@ from vllm.sampling_params import SamplingParams
 from vllm.transformers_utils.tokenizer import get_tokenizer
 from vllm.utils import random_uuid

-TIMEOUT_KEEP_ALIVE = 5 # seconds
+TIMEOUT_KEEP_ALIVE = 5  # seconds

 logger = init_logger(__name__)
 served_model = None

@@ -38,14 +39,13 @@ app = fastapi.FastAPI()

 def create_error_response(status_code: HTTPStatus,
                           message: str) -> JSONResponse:
-    return JSONResponse(
-        ErrorResponse(message=message, type="invalid_request_error").dict(),
-        status_code=status_code.value)
+    return JSONResponse(ErrorResponse(message=message,
+                                      type="invalid_request_error").dict(),
+                        status_code=status_code.value)


 @app.exception_handler(RequestValidationError)
-async def validation_exception_handler(request, exc):
+async def validation_exception_handler(request, exc):  # pylint: disable=unused-argument
     return create_error_response(HTTPStatus.BAD_REQUEST, str(exc))

@@ -126,8 +126,11 @@ async def check_length(request, prompt, engine):
 @app.get("/v1/models")
 async def show_available_models():
     """Show available models. Right now we only have one model."""
-    model_cards = [ModelCard(id=served_model, root=served_model,
-                             permission=[ModelPermission()])]
+    model_cards = [
+        ModelCard(id=served_model,
+                  root=served_model,
+                  permission=[ModelPermission()])
+    ]
     return ModelList(data=model_cards)

@@ -144,12 +147,14 @@ def create_logprobs(token_ids: List[int],
         if len(logprobs.text_offset) == 0:
             logprobs.text_offset.append(initial_text_offset)
         else:
-            logprobs.text_offset.append(logprobs.text_offset[-1] + last_token_len)
+            logprobs.text_offset.append(logprobs.text_offset[-1] +
+                                        last_token_len)
         last_token_len = len(token)

-        logprobs.top_logprobs.append(
-            {tokenizer.convert_ids_to_tokens(i): p for i, p in id_logprob.items()})
+        logprobs.top_logprobs.append({
+            tokenizer.convert_ids_to_tokens(i): p
+            for i, p in id_logprob.items()
+        })
     return logprobs

@@ -348,7 +353,7 @@ async def create_completion(raw_request: Request):
     if request.suffix is not None:
         # The language models we currently support do not support suffix.
-        return create_error_response(
-            HTTPStatus.BAD_REQUEST, "suffix is not currently supported")
+        return create_error_response(HTTPStatus.BAD_REQUEST,
+                                     "suffix is not currently supported")

     if request.logit_bias is not None:
         # TODO: support logit_bias in vLLM engine.

@@ -387,22 +392,23 @@ async def create_completion(raw_request: Request):
     except ValueError as e:
         return create_error_response(HTTPStatus.BAD_REQUEST, str(e))

-    result_generator = engine.generate(prompt, sampling_params,
-                                       request_id)
+    result_generator = engine.generate(prompt, sampling_params, request_id)

     # Similar to the OpenAI API, when n != best_of, we do not stream the
     # results. In addition, we do not stream the results when use beam search.
-    stream = (request.stream and
-              (request.best_of is None or request.n == request.best_of) and
-              not request.use_beam_search)
+    stream = (request.stream
+              and (request.best_of is None or request.n == request.best_of)
+              and not request.use_beam_search)

     async def abort_request() -> None:
         await engine.abort(request_id)

-    def create_stream_response_json(index: int,
-                                    text: str,
-                                    logprobs: Optional[LogProbs] = None,
-                                    finish_reason: Optional[str] = None) -> str:
+    def create_stream_response_json(
+        index: int,
+        text: str,
+        logprobs: Optional[LogProbs] = None,
+        finish_reason: Optional[str] = None,
+    ) -> str:
         choice_data = CompletionResponseStreamChoice(
             index=index,
             text=text,

@@ -443,7 +449,8 @@ async def create_completion(raw_request: Request):
                 )
                 yield f"data: {response_json}\n\n"
             if output.finish_reason is not None:
-                logprobs = LogProbs() if request.logprobs is not None else None
+                logprobs = (LogProbs()
+                            if request.logprobs is not None else None)
                 response_json = create_stream_response_json(
                     index=i,
                     text="",

@@ -487,8 +494,8 @@ async def create_completion(raw_request: Request):
         choices.append(choice_data)

     num_prompt_tokens = len(final_res.prompt_token_ids)
-    num_generated_tokens = sum(len(output.token_ids)
-                               for output in final_res.outputs)
+    num_generated_tokens = sum(
+        len(output.token_ids) for output in final_res.outputs)
     usage = UsageInfo(
         prompt_tokens=num_prompt_tokens,
         completion_tokens=num_generated_tokens,

@@ -506,9 +513,11 @@ async def create_completion(raw_request: Request):
     # When user requests streaming but we don't stream, we still need to
     # return a streaming response with a single event.
     response_json = response.json(ensure_ascii=False)

     async def fake_stream_generator() -> AsyncGenerator[str, None]:
         yield f"data: {response_json}\n\n"
         yield "data: [DONE]\n\n"

-    return StreamingResponse(fake_stream_generator(), media_type="text/event-stream")
+    return StreamingResponse(fake_stream_generator(),
+                             media_type="text/event-stream")

@@ -517,26 +526,34 @@ async def create_completion(raw_request: Request):
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(
-        description="vLLM OpenAI-Compatible RESTful API server."
-    )
-    parser.add_argument("--host", type=str, default="localhost",
-                        help="host name")
+        description="vLLM OpenAI-Compatible RESTful API server.")
+    parser.add_argument("--host",
+                        type=str,
+                        default="localhost",
+                        help="host name")
     parser.add_argument("--port", type=int, default=8000, help="port number")
-    parser.add_argument("--allow-credentials", action="store_true",
-                        help="allow credentials")
-    parser.add_argument("--allowed-origins", type=json.loads, default=["*"],
-                        help="allowed origins")
-    parser.add_argument("--allowed-methods", type=json.loads, default=["*"],
-                        help="allowed methods")
-    parser.add_argument("--allowed-headers", type=json.loads, default=["*"],
-                        help="allowed headers")
-    parser.add_argument("--served-model-name", type=str, default=None,
-                        help="The model name used in the API. If not specified, "
-                        "the model name will be the same as the "
-                        "huggingface name.")
+    parser.add_argument("--allow-credentials",
+                        action="store_true",
+                        help="allow credentials")
+    parser.add_argument("--allowed-origins",
+                        type=json.loads,
+                        default=["*"],
+                        help="allowed origins")
+    parser.add_argument("--allowed-methods",
+                        type=json.loads,
+                        default=["*"],
+                        help="allowed methods")
+    parser.add_argument("--allowed-headers",
+                        type=json.loads,
+                        default=["*"],
+                        help="allowed headers")
+    parser.add_argument("--served-model-name",
+                        type=str,
+                        default=None,
+                        help="The model name used in the API. If not specified, "
+                        "the model name will be the same as the "
+                        "huggingface name.")
     parser = AsyncEngineArgs.add_cli_args(parser)
     args = parser.parse_args()

@@ -556,7 +573,11 @@ if __name__ == "__main__":
     engine = AsyncLLMEngine.from_engine_args(engine_args)

     # A separate tokenizer to map token IDs to strings.
-    tokenizer = get_tokenizer(engine_args.tokenizer, engine_args.tokenizer_mode)
+    tokenizer = get_tokenizer(engine_args.tokenizer,
+                              tokenizer_mode=engine_args.tokenizer_mode)

-    uvicorn.run(app, host=args.host, port=args.port, log_level="info",
-                timeout_keep_alive=TIMEOUT_KEEP_ALIVE)
+    uvicorn.run(app,
+                host=args.host,
+                port=args.port,
+                log_level="info",
+                timeout_keep_alive=TIMEOUT_KEEP_ALIVE)
vllm/entrypoints/openai/protocol.py  (view file @ d6fa1be3)

-# Adapted from https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
+# Adapted from
+# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
 import time
 from typing import Dict, List, Literal, Optional, Union

@@ -98,7 +99,8 @@ class LogProbs(BaseModel):
     text_offset: List[int] = Field(default_factory=list)
     token_logprobs: List[Optional[float]] = Field(default_factory=list)
     tokens: List[str] = Field(default_factory=list)
-    top_logprobs: List[Optional[Dict[str, float]]] = Field(default_factory=list)
+    top_logprobs: List[Optional[Dict[str,
+                                     float]]] = Field(default_factory=list)


 class CompletionResponseChoice(BaseModel):
vllm/logger.py  (view file @ d6fa1be3)

-# Adapted from https://github.com/skypilot-org/skypilot/blob/86dc0f6283a335e4aa37b3c10716f90999f48ab6/sky/sky_logging.py
+# Adapted from
+# https://github.com/skypilot-org/skypilot/blob/86dc0f6283a335e4aa37b3c10716f90999f48ab6/sky/sky_logging.py
 """Logging configuration for vLLM."""
 import logging
 import sys

 _FORMAT = "%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s"
 _DATE_FORMAT = "%m-%d %H:%M:%S"
vllm/model_executor/__init__.py  (view file @ d6fa1be3)

@@ -2,7 +2,6 @@ from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.utils import set_random_seed
-

 __all__ = [
     "InputMetadata",
     "get_model",
vllm/model_executor/input_metadata.py  (view file @ d6fa1be3)

@@ -8,11 +8,22 @@ from vllm.sequence import SequenceData
 class InputMetadata:
     """Metadata for input sequences. Used for PagedAttention.

+    Args:
+        seq_groups: List of (seq_ids, sampling_params).
+        seq_data: Seq_id -> SequenceData.
+        prompt_lens: Lengths of prompts.
+        slot_mapping: The address to write the new KV to of each token.
+        context_lens: the length of attention context for each generation token.
+        max_context_len: The maximum context length.
+        block_tables: The block tables. (Seq id -> list of physical block)
+    """

     def __init__(
         self,
-        seq_groups: List[Tuple[List[int], SamplingParams]],  # List of (seq_ids, sampling_params).
-        seq_data: Dict[int, SequenceData],  # Seq_id -> SequenceData.
+        seq_groups: List[Tuple[List[int], SamplingParams]],
+        seq_data: Dict[int, SequenceData],
         prompt_lens: List[int],
         slot_mapping: torch.Tensor,
         context_lens: torch.Tensor,
vllm/model_executor/layers/activation.py  (view file @ d6fa1be3)

@@ -6,9 +6,10 @@ from vllm import activation_ops

 _ACTIVATION_REGISTRY = {
     "gelu": nn.GELU(),
-    "gelu_new": nn.GELU(approximate="tanh"),  # NOTE: This may introduce small rounding errors.
-    "gelu_fast": nn.GELU(approximate="tanh"),  # NOTE: This may introduce small rounding errors.
-    "gelu_pytorch_tanh": nn.GELU(approximate="tanh"),  # NOTE: This may introduce small rounding errors.
+    # NOTE: The following GELU functions may introduce small rounding errors.
+    "gelu_new": nn.GELU(approximate="tanh"),
+    "gelu_fast": nn.GELU(approximate="tanh"),
+    "gelu_pytorch_tanh": nn.GELU(approximate="tanh"),
     "relu": nn.ReLU(),
 }

@@ -25,15 +26,13 @@ class SiluAndMul(nn.Module):
     """An activation function for SwiGLU.

     The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[1] // 2.
-    """

-    def __init__(self):
-        super().__init__()
+    Shapes:
+        x: (num_tokens, 2 * d)
+        return: (num_tokens, d)
+    """

-    def forward(
-        self,
-        x: torch.Tensor,  # (num_tokens, 2 * d)
-    ) -> torch.Tensor:  # (num_tokens, d)
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         num_tokens = x.shape[0]
         d = x.shape[1] // 2
         out = torch.empty(num_tokens, d, dtype=x.dtype, device=x.device)
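The SiluAndMul docstring above states the SwiGLU contract: x -> silu(x[:d]) * x[d:] with d = x.shape[1] // 2, mapping (num_tokens, 2 * d) to (num_tokens, d). The following is a minimal pure-PyTorch sketch of that same contract; it is not the vLLM implementation (which dispatches to the custom activation_ops CUDA kernel), and the helper name silu_and_mul_reference is hypothetical.

import torch
import torch.nn.functional as F

def silu_and_mul_reference(x: torch.Tensor) -> torch.Tensor:
    # x: (num_tokens, 2 * d) -> out: (num_tokens, d)
    d = x.shape[1] // 2
    return F.silu(x[:, :d]) * x[:, d:]

x = torch.randn(4, 8)               # num_tokens=4, 2*d=8
out = silu_and_mul_reference(x)
assert out.shape == (4, 4)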
vllm/model_executor/layers/attention.py  (view file @ d6fa1be3)

@@ -14,6 +14,7 @@ _SUPPORTED_HEAD_SIZES = [64, 80, 96, 128]

 class PagedAttention(nn.Module):
+    # pylint: disable=line-too-long
     """GPT-style multi-head PagedAttention.

     This class takes flattened 1D query, key, and value tensors as input. The

@@ -54,12 +55,20 @@ class PagedAttention(nn.Module):
     def multi_query_kv_attention(
         self,
-        output: torch.Tensor,  # [num_prompt_tokens, num_heads, head_size]
-        query: torch.Tensor,   # [num_prompt_tokens, num_heads, head_size]
-        key: torch.Tensor,     # [num_prompt_tokens, num_heads, head_size]
-        value: torch.Tensor,   # [num_prompt_tokens, num_heads, head_size]
+        output: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
         attn_bias: xops.AttentionBias,
     ) -> torch.Tensor:
+        """Normal attention for the prompt tokens.
+
+        Args:
+            output: shape = [num_prompt_tokens, num_heads, head_size]
+            query: shape = [num_prompt_tokens, num_heads, head_size]
+            key: shape = [num_prompt_tokens, num_heads, head_size]
+            value: shape = [num_prompt_tokens, num_heads, head_size]
+        """
         # TODO(woosuk): The unsqueeze op may incur some CPU overhead. Optimize.
         out = xops.memory_efficient_attention_forward(
             query.unsqueeze(0),

@@ -76,12 +85,22 @@ class PagedAttention(nn.Module):
     def single_query_cached_kv_attention(
         self,
-        output: torch.Tensor,       # [num_generation_tokens, num_heads, head_size]
-        query: torch.Tensor,        # [num_generation_tokens, num_heads, head_size]
-        key_cache: torch.Tensor,    # [num_blocks, num_heads, head_size/x, block_size, x]
-        value_cache: torch.Tensor,  # [num_blocks, num_heads, head_size, block_size]
+        output: torch.Tensor,
+        query: torch.Tensor,
+        key_cache: torch.Tensor,
+        value_cache: torch.Tensor,
         input_metadata: InputMetadata,
     ) -> None:
+        """PagedAttention for the generation tokens.
+
+        Args:
+            output: shape = [num_generation_tokens, num_heads, head_size]
+            query: shape = [num_generation_tokens, num_heads, head_size]
+            key_cache: shape = [num_blocks, num_heads, head_size/x,
+                block_size, x]
+            value_cache: shape = [num_blocks, num_heads, head_size, block_size]
+            input_metadata: metadata for paged attention.
+        """
         block_size = value_cache.shape[3]
         attention_ops.single_query_cached_kv_attention(
             output,

@@ -97,16 +116,32 @@ class PagedAttention(nn.Module):
     def forward(
         self,
-        query: torch.Tensor,                  # [num_tokens, num_heads * head_size]
-        key: torch.Tensor,                    # [num_tokens, num_heads * head_size]
-        value: torch.Tensor,                  # [num_tokens, num_heads * head_size]
-        key_cache: Optional[torch.Tensor],    # [num_blocks, num_heads, head_size/x, block_size, x]
-        value_cache: Optional[torch.Tensor],  # [num_blocks, num_heads, head_size, block_size]
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        key_cache: Optional[torch.Tensor],
+        value_cache: Optional[torch.Tensor],
         input_metadata: InputMetadata,
         cache_event: Optional[torch.cuda.Event],
-    ) -> torch.Tensor:                        # [num_tokens, num_heads * head_size]
-        # NOTE: The query, key, and value tensors must be sliced from a qkv
-        # tensor of shape [num_tokens, 3 * num_heads * head_size].
+    ) -> torch.Tensor:
+        """PagedAttention forward pass.
+
+        NOTE: The query, key, and value tensors must be sliced from a qkv
+        tensor of shape [num_tokens, 3 * num_heads * head_size].
+
+        Args:
+            query: shape = [num_tokens, num_heads * head_size]
+            key: shape = [num_tokens, num_heads * head_size]
+            value: shape = [num_tokens, num_heads * head_size]
+            key_cache: shape = [num_blocks, num_heads, head_size/x,
+                block_size, x]
+            value_cache: shape = [num_blocks, num_heads, head_size, block_size]
+            input_metadata: metadata for paged attention.
+            cache_event: event to wait for the cache operations to finish.
+
+        Returns:
+            shape = [num_tokens, num_heads * head_size]
+        """

         # Reshape the query, key, and value tensors.
         query = query.view(-1, self.num_heads, self.head_size)

@@ -136,7 +171,7 @@ class PagedAttention(nn.Module):
         # and value vectors will not be cached.
         num_valid_tokens = input_metadata.num_valid_tokens
         if (num_valid_tokens > 0 and key_cache is not None
-            and value_cache is not None):
+                and value_cache is not None):
             # The stride is 3 because the key and value are sliced from qkv.
             cache_ops.reshape_and_cache(
                 key[:num_valid_tokens],

@@ -149,15 +184,12 @@ class PagedAttention(nn.Module):
         if input_metadata.num_generation_tokens > 0:
             assert key_cache is not None and value_cache is not None, (
                 "key_cache and value_cache must be provided when "
-                "generating tokens."
-            )
+                "generating tokens.")
             # Compute the attention op for generation tokens.
             self.single_query_cached_kv_attention(
                 output[num_prompt_tokens:num_valid_tokens],
-                query[num_prompt_tokens:num_valid_tokens],
-                key_cache,
-                value_cache,
-                input_metadata)
+                query[num_prompt_tokens:num_valid_tokens], key_cache,
+                value_cache, input_metadata)

         # Reshape the output tensor.
         # NOTE(woosuk): The output tensor may include paddings.

@@ -179,9 +211,9 @@ class PagedAttentionWithRoPE(PagedAttention):
         super().__init__(num_heads, head_size, scale)

         # Create the cos and sin cache.
-        inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2) / rotary_dim))
+        inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2) / rotary_dim))
         t = torch.arange(max_position).float()
-        freqs = torch.einsum('i,j -> ij', t, inv_freq.float())
+        freqs = torch.einsum("i,j -> ij", t, inv_freq.float())
         cos = freqs.cos()
         sin = freqs.sin()
         cache = torch.cat((cos, sin), dim=-1)

@@ -195,15 +227,32 @@ class PagedAttentionWithRoPE(PagedAttention):
     def forward(
         self,
-        positions: torch.Tensor,    # [num_tokens]
-        query: torch.Tensor,        # [num_tokens, num_heads * head_size]
-        key: torch.Tensor,          # [num_tokens, num_heads * head_size]
-        value: torch.Tensor,        # [num_tokens, num_heads * head_size]
-        key_cache: torch.Tensor,    # [num_blocks, num_heads, head_size/x, block_size, x]
-        value_cache: torch.Tensor,  # [num_blocks, num_heads, head_size, block_size]
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        key_cache: torch.Tensor,
+        value_cache: torch.Tensor,
         input_metadata: InputMetadata,
         cache_event: Optional[torch.cuda.Event],
-    ) -> torch.Tensor:              # [num_tokens, num_heads * head_size]
+    ) -> torch.Tensor:
+        """ PagedAttention forward pass with rotary embedding.
+
+        Args:
+            positions: shape = [num_tokens]
+            query: shape = [num_tokens, num_heads * head_size]
+            key: shape = [num_tokens, num_heads * head_size]
+            value: shape = [num_tokens, num_heads * head_size]
+            key_cache: shape = [num_blocks, num_heads, head_size/x,
+                block_size, x]
+            value_cache: shape = [num_blocks, num_heads, head_size, block_size]
+            input_metadata: metadata for paged attention.
+            cache_event: event to wait for the cache operations to finish.
+
+        Returns:
+            shape = [num_tokens, num_heads * head_size]
+        """

         # Apply rotary embedding to the query and key before passing them
         # to the attention op.
         pos_encoding_ops.rotary_embedding_neox(
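The docstrings moved into PagedAttention above pin down the tensor shapes the layer expects. The snippet below is a hedged, standalone shape sketch in plain PyTorch: all sizes are made up for illustration, no vLLM kernels are called, and the packing factor x is only an assumption consistent with the documented key_cache layout.

import torch

# Illustrative sizes only.
num_tokens, num_heads, head_size = 5, 8, 64
num_blocks, block_size = 16, 16
x = 8  # assumed cache packing factor along head_size, per the key_cache shape

query = torch.randn(num_tokens, num_heads * head_size)
key = torch.randn(num_tokens, num_heads * head_size)
value = torch.randn(num_tokens, num_heads * head_size)
key_cache = torch.randn(num_blocks, num_heads, head_size // x, block_size, x)
value_cache = torch.randn(num_blocks, num_heads, head_size, block_size)

# forward() reshapes the flat inputs to a per-head layout before attention,
# and reshapes the result back to [num_tokens, num_heads * head_size].
q = query.view(-1, num_heads, head_size)
assert q.shape == (num_tokens, num_heads, head_size)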
vllm/model_executor/layers/sampler.py  (view file @ d6fa1be3)

@@ -13,6 +13,7 @@ from vllm.sequence import SequenceOutputs

 _SAMPLING_EPS = 1e-5

+
 class Sampler(nn.Module):
     """Samples the next tokens from the model's outputs.

@@ -50,19 +51,20 @@ class Sampler(nn.Module):
         # Apply presence and frequency penalties.
         output_tokens = _get_output_tokens(input_metadata)
         assert len(output_tokens) == logits.shape[0]
-        presence_penalties, frequency_penalties = _get_penalties(input_metadata)
+        presence_penalties, frequency_penalties = _get_penalties(
+            input_metadata)
         assert len(presence_penalties) == logits.shape[0]
         assert len(frequency_penalties) == logits.shape[0]
         logits = _apply_penalties(logits, output_tokens, presence_penalties,
                                   frequency_penalties, self.vocab_size)

         # Apply temperature scaling.
         temperatures = _get_temperatures(input_metadata)
         assert len(temperatures) == logits.shape[0]
         if any(t != 1.0 for t in temperatures):
-            t = torch.tensor(temperatures, dtype=logits.dtype, device=logits.device)
+            t = torch.tensor(temperatures,
+                             dtype=logits.dtype,
+                             device=logits.device)
             # Use in-place division to avoid creating a new tensor.
             logits.div_(t.unsqueeze(dim=1))

@@ -75,7 +77,9 @@ class Sampler(nn.Module):
         # Apply top-p and top-k truncation.
         top_ps, top_ks = _get_top_p_top_k(input_metadata, self.vocab_size)
         assert len(top_ps) == len(top_ks) == probs.shape[0]
-        if any(p < 1.0 - _SAMPLING_EPS for p in top_ps) or any(k != self.vocab_size for k in top_ks):
+        do_top_p = any(p < 1.0 - _SAMPLING_EPS for p in top_ps)
+        do_top_k = any(k != self.vocab_size for k in top_ks)
+        if do_top_p or do_top_k:
             probs = _apply_top_p_top_k(probs, top_ps, top_ks)

         # Sample the next tokens.

@@ -97,8 +101,7 @@ def _prune_hidden_states(


 def _get_penalties(
-    input_metadata: InputMetadata,
-) -> Tuple[List[float], List[float]]:
+        input_metadata: InputMetadata) -> Tuple[List[float], List[float]]:
     # Collect the presence and frequency penalties.
     presence_penalties: List[float] = []
     frequency_penalties: List[float] = []

@@ -117,9 +120,7 @@ def _get_penalties(
     return presence_penalties, frequency_penalties


-def _get_output_tokens(
-    input_metadata: InputMetadata,
-) -> List[List[int]]:
+def _get_output_tokens(input_metadata: InputMetadata) -> List[List[int]]:
     output_tokens: List[List[int]] = []
     for i, seq_group in enumerate(input_metadata.seq_groups):
         seq_ids, _ = seq_group

@@ -169,11 +170,13 @@ def _apply_penalties(
                            device=logits.device)

     frequency_penalties = [frequency_penalties[i] for i in indices]
-    frequency_penalties = torch.tensor(frequency_penalties, dtype=logits.dtype,
-                                       device=logits.device)
+    frequency_penalties = torch.tensor(frequency_penalties,
+                                       dtype=logits.dtype,
+                                       device=logits.device)
     presence_penalties = [presence_penalties[i] for i in indices]
-    presence_penalties = torch.tensor(presence_penalties, dtype=logits.dtype,
-                                      device=logits.device)
+    presence_penalties = torch.tensor(presence_penalties,
+                                      dtype=logits.dtype,
+                                      device=logits.device)

     # We follow the definition in OpenAI API.
     # Refer to https://platform.openai.com/docs/api-reference/parameter-details

@@ -183,9 +186,7 @@ def _apply_penalties(
     return logits


-def _get_temperatures(
-    input_metadata: InputMetadata,
-) -> List[float]:
+def _get_temperatures(input_metadata: InputMetadata) -> List[float]:
     # Collect the temperatures for the logits.
     temperatures: List[float] = []
     for i, seq_group in enumerate(input_metadata.seq_groups):

@@ -252,8 +253,9 @@ def _apply_top_p_top_k(
     probs_sort[top_k_mask] = 0.0

     # Re-sort the probabilities.
-    probs = torch.gather(probs_sort, dim=-1,
-                         index=torch.argsort(probs_idx, dim=-1))
+    probs = torch.gather(probs_sort,
+                         dim=-1,
+                         index=torch.argsort(probs_idx, dim=-1))
     return probs

@@ -296,8 +298,9 @@ def _sample_from_prompt(
         # Random sampling.
         # Sample `best_of` tokens for the prompt.
         num_seqs = sampling_params.best_of
-        next_token_ids = torch.multinomial(prob, num_samples=num_seqs,
-                                           replacement=True)
+        next_token_ids = torch.multinomial(prob,
+                                           num_samples=num_seqs,
+                                           replacement=True)
         next_token_ids = next_token_ids.tolist()
     return next_token_ids

@@ -315,8 +318,9 @@ def _sample_from_generation_tokens(
     if sampling_params.use_beam_search:
         # Beam search.
         # Add cumulative logprobs for the sequences in the group.
-        seq_logprobs = torch.tensor(seq_logprobs, dtype=torch.float,
-                                    device=logprobs.device)
+        seq_logprobs = torch.tensor(seq_logprobs,
+                                    dtype=torch.float,
+                                    device=logprobs.device)
         logprobs = logprobs + seq_logprobs.unsqueeze(dim=1)

         vocab_size = logprobs.size(-1)

@@ -353,8 +357,9 @@ def _sample_from_generation_tokens(
     else:
         # Random sampling.
         # Sample 1 token for each sequence in the group.
-        next_token_ids = torch.multinomial(probs, num_samples=1,
-                                           replacement=True)
+        next_token_ids = torch.multinomial(probs,
+                                           num_samples=1,
+                                           replacement=True)
         next_token_ids = next_token_ids.squeeze(dim=-1).tolist()
         parent_seq_ids = seq_ids
     return parent_seq_ids, next_token_ids

@@ -381,15 +386,16 @@ def _sample(
             # Sample the next tokens.
             next_token_ids = _sample_from_prompt(prob, sampling_params)
             # Get top-k log probabilities for the next tokens.
-            next_logprobs = _get_topk_logprobs(logprob, sampling_params.logprobs)
+            next_logprobs = _get_topk_logprobs(logprob,
+                                               sampling_params.logprobs)

             # Build the output.
             for seq_id, next_token_id in zip(seq_ids, next_token_ids):
                 output_logprobs = next_logprobs.copy()
                 output_logprobs[next_token_id] = logprob[next_token_id].item()
-                seq_outputs[seq_id] = SequenceOutputs(seq_id, seq_id, next_token_id,
-                                                      output_logprobs)
+                seq_outputs[seq_id] = SequenceOutputs(seq_id, seq_id,
+                                                      next_token_id,
+                                                      output_logprobs)
         else:
             # Generate the next tokens for generation tokens.
             prob = probs[idx:idx + len(seq_ids)]

@@ -399,22 +405,24 @@ def _sample(
             # Sample the next tokens.
             seq_logprobs = [
                 input_metadata.seq_data[seq_id].cumulative_logprob
-                for seq_id in seq_ids]
+                for seq_id in seq_ids
+            ]
             parent_seq_ids, next_token_ids = _sample_from_generation_tokens(
                 seq_ids, prob, logprob, seq_logprobs, sampling_params)

             # Get top-k log probabilities for the next tokens.
             next_logprobs: Dict[int, Dict[int, float]] = {}
-            for i, seq_id in enumerate(seq_ids):
+            for j, seq_id in enumerate(seq_ids):
                 next_logprobs[seq_id] = _get_topk_logprobs(
-                    logprob[i], sampling_params.logprobs)
+                    logprob[j], sampling_params.logprobs)

             # Build the output.
             for seq_id, parent_seq_id, next_token_id in zip(
                     seq_ids, parent_seq_ids, next_token_ids):
-                i = seq_ids.index(parent_seq_id)
+                j = seq_ids.index(parent_seq_id)
                 output_logprobs = next_logprobs[parent_seq_id].copy()
-                output_logprobs[next_token_id] = logprob[i, next_token_id].item()
+                output_logprobs[next_token_id] = logprob[
+                    j, next_token_id].item()
                 seq_outputs[seq_id] = SequenceOutputs(seq_id, parent_seq_id,
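In the Sampler hunks above, per-sequence temperatures are applied by building a 1-D tensor and dividing the logits in place. A minimal sketch of that broadcasting step with illustrative values only (the real code derives the temperature list from InputMetadata):

import torch

logits = torch.randn(3, 10)          # [num_sequences, vocab_size]
temperatures = [1.0, 0.7, 1.3]       # one temperature per sequence

t = torch.tensor(temperatures, dtype=logits.dtype, device=logits.device)
# unsqueeze to [num_sequences, 1] so the in-place division broadcasts
# across the vocabulary dimension.
logits.div_(t.unsqueeze(dim=1))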
vllm/model_executor/model_loader.py  (view file @ d6fa1be3)

@@ -6,8 +6,9 @@ import torch.nn as nn
 from transformers import PretrainedConfig

 from vllm.config import ModelConfig
-from vllm.model_executor.models import (GPT2LMHeadModel, GPTBigCodeForCausalLM,
-                                        GPTNeoXForCausalLM, LlamaForCausalLM,
-                                        OPTForCausalLM)
+from vllm.model_executor.models import (GPT2LMHeadModel,
+                                        GPTBigCodeForCausalLM,
+                                        GPTNeoXForCausalLM, LlamaForCausalLM,
+                                        OPTForCausalLM)
 from vllm.model_executor.weight_utils import initialize_dummy_weights

 # TODO(woosuk): Lazy-load the model classes.

@@ -28,8 +29,7 @@ def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
             return _MODEL_REGISTRY[arch]
     raise ValueError(
         f"Model architectures {architectures} are not supported for now. "
-        f"Supported architectures: {list(_MODEL_REGISTRY.keys())}"
-    )
+        f"Supported architectures: {list(_MODEL_REGISTRY.keys())}")


 def get_model(model_config: ModelConfig) -> nn.Module:

@@ -46,8 +46,7 @@ def get_model(model_config: ModelConfig) -> nn.Module:
             initialize_dummy_weights(model)
         else:
             # Load the weights from the cached or downloaded files.
-            model.load_weights(model_config.model,
-                               model_config.download_dir,
-                               model_config.use_np_weights)
+            model.load_weights(model_config.model, model_config.download_dir,
+                               model_config.use_np_weights)
     model = model.cuda()
     return model.eval()
vllm/model_executor/models/__init__.py  (view file @ d6fa1be3)

@@ -4,8 +4,6 @@ from vllm.model_executor.models.gpt_neox import GPTNeoXForCausalLM
 from vllm.model_executor.models.llama import LlamaForCausalLM
 from vllm.model_executor.models.opt import OPTForCausalLM
-
-
 __all__ = [
     "GPT2LMHeadModel",
     "GPTBigCodeForCausalLM",
vllm/model_executor/models/gpt2.py  (view file @ d6fa1be3)

 # coding=utf-8
-# Adapted from https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py
+# Adapted from
+# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py
 # Copyright 2023 The vLLM team.
 # Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
 # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.

@@ -47,19 +48,25 @@ class GPT2Attention(nn.Module):
         super().__init__()
         self.hidden_size = config.hidden_size
         total_num_heads = config.num_attention_heads
-        tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
+        tensor_model_parallel_world_size = (
+            get_tensor_model_parallel_world_size())
         assert total_num_heads % tensor_model_parallel_world_size == 0
         self.num_heads = total_num_heads // tensor_model_parallel_world_size
         self.head_dim = self.hidden_size // total_num_heads
-        self.scale = self.head_dim ** -0.5
+        self.scale = self.head_dim**-0.5

-        self.c_attn = ColumnParallelLinear(self.hidden_size, 3 * self.hidden_size,
-                                           bias=True, gather_output=False,
-                                           perform_initialization=False)
-        self.c_proj = RowParallelLinear(self.hidden_size, self.hidden_size,
-                                        bias=True, input_is_parallel=True,
-                                        perform_initialization=False)
-        self.attn = PagedAttention(self.num_heads, self.head_dim,
-                                   scale=self.scale)
+        self.c_attn = ColumnParallelLinear(self.hidden_size,
+                                           3 * self.hidden_size,
+                                           bias=True,
+                                           gather_output=False,
+                                           perform_initialization=False)
+        self.c_proj = RowParallelLinear(self.hidden_size,
+                                        self.hidden_size,
+                                        bias=True,
+                                        input_is_parallel=True,
+                                        perform_initialization=False)
+        self.attn = PagedAttention(self.num_heads,
+                                   self.head_dim,
+                                   scale=self.scale)

     def forward(

@@ -72,8 +79,8 @@ class GPT2Attention(nn.Module):
         qkv, _ = self.c_attn(hidden_states)
         q, k, v = qkv.chunk(chunks=3, dim=-1)
         key_cache, value_cache = kv_cache
-        attn_output = self.attn(q, k, v, key_cache, value_cache, input_metadata,
-                                cache_event)
+        attn_output = self.attn(q, k, v, key_cache, value_cache,
+                                input_metadata, cache_event)
         attn_output, _ = self.c_proj(attn_output)
         return attn_output

@@ -87,11 +94,15 @@ class GPT2MLP(nn.Module):
     ):
         super().__init__()
         hidden_size = config.hidden_size
-        self.c_fc = ColumnParallelLinear(hidden_size, intermediate_size,
-                                         bias=True, gather_output=False,
-                                         perform_initialization=False)
-        self.c_proj = RowParallelLinear(intermediate_size, hidden_size,
-                                        bias=True, input_is_parallel=True,
-                                        perform_initialization=False)
+        self.c_fc = ColumnParallelLinear(hidden_size,
+                                         intermediate_size,
+                                         bias=True,
+                                         gather_output=False,
+                                         perform_initialization=False)
+        self.c_proj = RowParallelLinear(intermediate_size,
+                                        hidden_size,
+                                        bias=True,
+                                        input_is_parallel=True,
+                                        perform_initialization=False)
         self.act = get_act_fn(config.activation_function)

@@ -107,7 +118,8 @@ class GPT2Block(nn.Module):
     def __init__(self, config: GPT2Config):
         super().__init__()
         hidden_size = config.hidden_size
-        inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
+        inner_dim = (config.n_inner
+                     if config.n_inner is not None else 4 * hidden_size)

         self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
         self.attn = GPT2Attention(config)

@@ -145,9 +157,9 @@ class GPT2Model(nn.Module):
     def __init__(self, config: GPT2Config):
         super().__init__()
         self.config = config
-        assert config.add_cross_attention == False
-        assert config.scale_attn_by_inverse_layer_idx == False
-        assert config.reorder_and_upcast_attn == False
+        assert not config.add_cross_attention
+        assert not config.scale_attn_by_inverse_layer_idx
+        assert not config.reorder_and_upcast_attn
         self.embed_dim = config.hidden_size

         # Optimization: While the vocab size of GPT-2 is 50257, we extend it

@@ -180,8 +192,8 @@ class GPT2Model(nn.Module):
             else:
                 cache_event = cache_events[i]
             layer = self.h[i]
-            hidden_states = layer(
-                hidden_states, kv_caches[i], input_metadata, cache_event)
+            hidden_states = layer(hidden_states, kv_caches[i], input_metadata,
+                                  cache_event)

         hidden_states = self.ln_f(hidden_states)
         return hidden_states

@@ -206,24 +218,26 @@ class GPT2LMHeadModel(nn.Module):
         input_metadata: InputMetadata,
         cache_events: Optional[List[torch.cuda.Event]],
     ) -> Dict[int, SequenceOutputs]:
-        hidden_states = self.transformer(
-            input_ids, positions, kv_caches, input_metadata, cache_events)
-        next_tokens = self.sampler(
-            self.lm_head_weight, hidden_states, input_metadata)
+        hidden_states = self.transformer(input_ids, positions, kv_caches,
+                                         input_metadata, cache_events)
+        next_tokens = self.sampler(self.lm_head_weight, hidden_states,
+                                   input_metadata)
         return next_tokens

     _column_parallel_weights = ["wte.weight", "c_fc.weight", "c_fc.bias"]
     _row_parallel_weights = ["c_proj.weight"]

-    def load_weights(self, model_name_or_path: str,
+    def load_weights(self,
+                     model_name_or_path: str,
                      cache_dir: Optional[str] = None,
                      use_np_cache: bool = False):
-        tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
+        tensor_model_parallel_world_size = (
+            get_tensor_model_parallel_world_size())
         tensor_model_parallel_rank = get_tensor_model_parallel_rank()
         state_dict = self.state_dict()

         for name, loaded_weight in hf_model_weights_iterator(
                 model_name_or_path, cache_dir, use_np_cache):
             if "lm_head.weight" in name:
                 # GPT-2 ties the weights of the embedding layer and the final
                 # linear layer.

@@ -248,16 +262,20 @@ class GPT2LMHeadModel(nn.Module):
             if name == "transformer.wte.weight":
                 # Consider padding in the vocab size.
-                padded_vocab_size = param.shape[0] * tensor_model_parallel_world_size
+                padded_vocab_size = (param.shape[0] *
+                                     tensor_model_parallel_world_size)
                 num_extra_rows = padded_vocab_size - self.config.vocab_size
-                extra_rows = torch.empty(num_extra_rows, loaded_weight.shape[1])
+                extra_rows = torch.empty(num_extra_rows,
+                                         loaded_weight.shape[1])
                 extra_rows = extra_rows.to(loaded_weight)
                 loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)

             # For the fused QKV linear layer, manually shard the weights.
             if "c_attn" in name:
-                # GPT-2's fused QKV has the shape of [3 * num_heads * head_size, hidden_size].
-                # When tensor parallelism is used, we shard the weights along the head dimension.
+                # GPT-2's fused QKV has the shape of
+                # [3 * num_heads * head_size, hidden_size].
+                # When tensor parallelism is used, we shard the weights along
+                # the head dimension.
                 total_num_heads = self.config.num_attention_heads
                 hidden_size = self.config.hidden_size
                 head_size = hidden_size // total_num_heads

@@ -266,11 +284,13 @@ class GPT2LMHeadModel(nn.Module):
                 head_end = (tensor_model_parallel_rank + 1) * num_heads
                 if name.endswith(".weight"):
-                    loaded_weight = loaded_weight.view(3, total_num_heads, head_size, hidden_size)
+                    loaded_weight = loaded_weight.view(
+                        3, total_num_heads, head_size, hidden_size)
                     loaded_weight = loaded_weight[:, head_start:head_end, :, :]
                     loaded_weight = loaded_weight.reshape(-1, hidden_size)
                 elif name.endswith(".bias"):
-                    loaded_weight = loaded_weight.view(3, total_num_heads, head_size)
+                    loaded_weight = loaded_weight.view(
+                        3, total_num_heads, head_size)
                     loaded_weight = loaded_weight[:, head_start:head_end, :]
                     loaded_weight = loaded_weight.reshape(-1)
                 else:
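The load_weights comments above describe sharding GPT-2's fused QKV weight of shape [3 * num_heads * head_size, hidden_size] along the head dimension. Below is a hedged, standalone sketch of that view / slice / reshape step with made-up sizes and a hypothetical tensor-parallel rank; it only illustrates the shape bookkeeping, not vLLM's loader.

import torch

total_num_heads, head_size, hidden_size = 12, 64, 768
world_size, rank = 2, 0                       # hypothetical tensor-parallel setup
num_heads = total_num_heads // world_size
head_start = rank * num_heads
head_end = (rank + 1) * num_heads

c_attn_weight = torch.randn(3 * total_num_heads * head_size, hidden_size)

# Expose the head dimension, keep only this rank's heads, then flatten back.
w = c_attn_weight.view(3, total_num_heads, head_size, hidden_size)
w = w[:, head_start:head_end, :, :]
shard = w.reshape(-1, hidden_size)
assert shard.shape == (3 * num_heads * head_size, hidden_size)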
vllm/model_executor/models/gpt_bigcode.py  (view file @ d6fa1be3)

 # coding=utf-8
-# Adapted from https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py
+# Adapted from
+# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py
 # Copyright 2023 The vLLM team.
 # Copyright 2023 CTranslate2, and Michael Feil
 # Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.

@@ -49,19 +50,25 @@ class GPTBigCodeAttention(nn.Module):
         super().__init__()
         self.hidden_size = config.hidden_size
         total_num_heads = config.num_attention_heads
-        tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
+        tensor_model_parallel_world_size = (
+            get_tensor_model_parallel_world_size())
         assert total_num_heads % tensor_model_parallel_world_size == 0
         self.num_heads = total_num_heads // tensor_model_parallel_world_size
         self.head_dim = self.hidden_size // total_num_heads
-        self.scale = self.head_dim ** -0.5
+        self.scale = self.head_dim**-0.5

-        self.c_attn = ColumnParallelLinear(self.hidden_size, 3 * self.hidden_size,
-                                           bias=True, gather_output=False,
-                                           perform_initialization=False)
-        self.c_proj = RowParallelLinear(self.hidden_size, self.hidden_size,
-                                        bias=True, input_is_parallel=True,
-                                        perform_initialization=False)
-        self.attn = PagedAttention(self.num_heads, self.head_dim,
-                                   scale=self.scale)
+        self.c_attn = ColumnParallelLinear(self.hidden_size,
+                                           3 * self.hidden_size,
+                                           bias=True,
+                                           gather_output=False,
+                                           perform_initialization=False)
+        self.c_proj = RowParallelLinear(self.hidden_size,
+                                        self.hidden_size,
+                                        bias=True,
+                                        input_is_parallel=True,
+                                        perform_initialization=False)
+        self.attn = PagedAttention(self.num_heads,
+                                   self.head_dim,
+                                   scale=self.scale)

     def forward(

@@ -74,8 +81,8 @@ class GPTBigCodeAttention(nn.Module):
         qkv, _ = self.c_attn(hidden_states)
         q, k, v = qkv.chunk(chunks=3, dim=-1)
         key_cache, value_cache = kv_cache
-        attn_output = self.attn(q, k, v, key_cache, value_cache, input_metadata,
-                                cache_event)
+        attn_output = self.attn(q, k, v, key_cache, value_cache,
+                                input_metadata, cache_event)
         attn_output, _ = self.c_proj(attn_output)
         return attn_output

@@ -89,11 +96,15 @@ class GPTBigMLP(nn.Module):
     ):
         super().__init__()
         hidden_size = config.hidden_size
-        self.c_fc = ColumnParallelLinear(hidden_size, intermediate_size,
-                                         bias=True, gather_output=False,
-                                         perform_initialization=False)
-        self.c_proj = RowParallelLinear(intermediate_size, hidden_size,
-                                        bias=True, input_is_parallel=True,
-                                        perform_initialization=False)
+        self.c_fc = ColumnParallelLinear(hidden_size,
+                                         intermediate_size,
+                                         bias=True,
+                                         gather_output=False,
+                                         perform_initialization=False)
+        self.c_proj = RowParallelLinear(intermediate_size,
+                                        hidden_size,
+                                        bias=True,
+                                        input_is_parallel=True,
+                                        perform_initialization=False)
         self.act = get_act_fn(config.activation_function)

@@ -109,7 +120,8 @@ class GPTBigCodeBlock(nn.Module):
     def __init__(self, config: GPTBigCodeConfig):
         super().__init__()
         hidden_size = config.hidden_size
-        inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
+        inner_dim = (config.n_inner
+                     if config.n_inner is not None else 4 * hidden_size)

         self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
         self.attn = GPTBigCodeAttention(config)

@@ -147,7 +159,7 @@ class GPTBigCodeModel(nn.Module):
     def __init__(self, config: GPTBigCodeConfig):
         super().__init__()
         self.config = config
-        assert config.add_cross_attention == False
+        assert not config.add_cross_attention

         self.embed_dim = config.hidden_size

@@ -181,8 +193,8 @@ class GPTBigCodeModel(nn.Module):
             else:
                 cache_event = cache_events[i]
             layer = self.h[i]
-            hidden_states = layer(
-                hidden_states, kv_caches[i], input_metadata, cache_event)
+            hidden_states = layer(hidden_states, kv_caches[i], input_metadata,
+                                  cache_event)

         hidden_states = self.ln_f(hidden_states)
         return hidden_states

@@ -207,24 +219,26 @@ class GPTBigCodeForCausalLM(nn.Module):
         input_metadata: InputMetadata,
         cache_events: Optional[List[torch.cuda.Event]],
     ) -> Dict[int, SequenceOutputs]:
-        hidden_states = self.transformer(
-            input_ids, positions, kv_caches, input_metadata, cache_events)
-        next_tokens = self.sampler(
-            self.lm_head_weight, hidden_states, input_metadata)
+        hidden_states = self.transformer(input_ids, positions, kv_caches,
+                                         input_metadata, cache_events)
+        next_tokens = self.sampler(self.lm_head_weight, hidden_states,
+                                   input_metadata)
        return next_tokens

     _column_parallel_weights = ["wte.weight", "c_fc.weight", "c_fc.bias"]
     _row_parallel_weights = ["c_proj.weight"]

-    def load_weights(self, model_name_or_path: str,
+    def load_weights(self,
+                     model_name_or_path: str,
                      cache_dir: Optional[str] = None,
                      use_np_cache: bool = False):
-        tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
+        tensor_model_parallel_world_size = (
+            get_tensor_model_parallel_world_size())
         tensor_model_parallel_rank = get_tensor_model_parallel_rank()
         state_dict = self.state_dict()

         for name, loaded_weight in hf_model_weights_iterator(
                 model_name_or_path, cache_dir, use_np_cache):
             if "lm_head.weight" in name:
                 # GPT-2 ties the weights of the embedding layer and the final
                 # linear layer.

@@ -241,9 +255,11 @@ class GPTBigCodeForCausalLM(nn.Module):
             if name == "transformer.wte.weight":
                 # Consider padding in the vocab size.
-                padded_vocab_size = param.shape[0] * tensor_model_parallel_world_size
+                padded_vocab_size = param.shape[
+                    0] * tensor_model_parallel_world_size
                 num_extra_rows = padded_vocab_size - self.config.vocab_size
-                extra_rows = torch.empty(num_extra_rows, loaded_weight.shape[1])
+                extra_rows = torch.empty(num_extra_rows,
+                                         loaded_weight.shape[1])
                 extra_rows = extra_rows.to(loaded_weight)
                 loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)

@@ -258,25 +274,31 @@ class GPTBigCodeForCausalLM(nn.Module):
                 qkv_array = qkv_array.numpy()

                 dims_q = n_head * head_dim
-                q, k, v = np.split(qkv_array, (dims_q, dims_q + head_dim), axis=0)
-                # q is fine, but k & v have not replicated shape along the first axis
-                # as long as MQA is not nativly supported, increase memory and replicated
-                # (head_dim, hidden_dim) to (n_heads * head_dim, hidden_dim)
+                # pylint: disable=unbalanced-tuple-unpacking
+                q, k, v = np.split(qkv_array, (dims_q, dims_q + head_dim),
+                                   axis=0)
+                # q is fine, but k & v have not replicated shape along the first
+                # axis as long as MQA is not nativly supported, increase memory
+                # and replicated (head_dim, hidden_dim) to
+                # (n_heads * head_dim, hidden_dim)
                 if k.ndim == 2 and v.ndim == 2:
                     replication = (n_head, 1)  # weights
                 else:
                     replication = n_head  # biases
                 # replicate n_head times for q, v
                 k, v = np.tile(k, replication), np.tile(v, replication)
-                # concat q, k, v along the first axis (n_heads * head_dim, hidden_dim)
+                # concat q, k, v along the first axis
+                # (n_heads * head_dim, hidden_dim)
                 # to (3 * n_heads * head_dim, hidden_dim)
                 qkv_array = np.concatenate((q, k, v), axis=0)
                 return torch.from_numpy(qkv_array)

             # For the fused QKV linear layer, manually shard the weights.
             if "c_attn" in name:
-                # GPT-2's fused QKV has the shape of [3 * num_heads * head_size, hidden_size].
-                # When tensor parallelism is used, we shard the weights along the head dimension.
+                # GPT-2's fused QKV has the shape of
+                # [3 * num_heads * head_size, hidden_size].
+                # When tensor parallelism is used, we shard the weights along
+                # the head dimension.
                 total_num_heads = self.config.num_attention_heads
                 hidden_size = self.config.hidden_size
                 head_size = hidden_size // total_num_heads

@@ -285,13 +307,19 @@ class GPTBigCodeForCausalLM(nn.Module):
                 head_end = (tensor_model_parallel_rank + 1) * num_heads
                 if name.endswith(".weight"):
-                    loaded_weight = _expand_mqa_mha(loaded_weight, n_head=total_num_heads, head_dim=head_size)
-                    loaded_weight = loaded_weight.view(3, total_num_heads, head_size, hidden_size)
+                    loaded_weight = _expand_mqa_mha(loaded_weight,
+                                                    n_head=total_num_heads,
+                                                    head_dim=head_size)
+                    loaded_weight = loaded_weight.view(
+                        3, total_num_heads, head_size, hidden_size)
                     loaded_weight = loaded_weight[:, head_start:head_end, :, :]
                     loaded_weight = loaded_weight.reshape(-1, hidden_size)
                 elif name.endswith(".bias"):
-                    loaded_weight = _expand_mqa_mha(loaded_weight, n_head=total_num_heads, head_dim=head_size)
-                    loaded_weight = loaded_weight.view(3, total_num_heads, head_size)
+                    loaded_weight = _expand_mqa_mha(loaded_weight,
+                                                    n_head=total_num_heads,
+                                                    head_dim=head_size)
+                    loaded_weight = loaded_weight.view(
+                        3, total_num_heads, head_size)
                     loaded_weight = loaded_weight[:, head_start:head_end, :]
                     loaded_weight = loaded_weight.reshape(-1)
                 else:
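The _expand_mqa_mha comments above explain how a multi-query (MQA) checkpoint's single key/value head is replicated to n_head heads so the fused QKV matches a multi-head layout. A standalone NumPy sketch of that split / tile / concatenate step, with toy sizes chosen only for illustration:

import numpy as np

n_head, head_dim, hidden_dim = 4, 8, 32
dims_q = n_head * head_dim

# MQA fused QKV weight: n_head query heads but only one key and one value head.
qkv = np.random.randn(dims_q + 2 * head_dim, hidden_dim)

q, k, v = np.split(qkv, (dims_q, dims_q + head_dim), axis=0)
# Replicate k and v n_head times so every query head gets its own copy.
k, v = np.tile(k, (n_head, 1)), np.tile(v, (n_head, 1))
qkv_mha = np.concatenate((q, k, v), axis=0)
assert qkv_mha.shape == (3 * n_head * head_dim, hidden_dim)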
vllm/model_executor/models/gpt_neox.py
View file @
d6fa1be3
# coding=utf-8
# Adapted from https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt_neox/modeling_gpt_neox.py
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt_neox/modeling_gpt_neox.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI The HuggingFace Inc. team. All rights reserved.
#
...
...
@@ -48,19 +49,23 @@ class GPTNeoXAttention(nn.Module):
        self.hidden_size = config.hidden_size
        self.head_size = self.hidden_size // self.total_num_heads
        tensor_model_parallel_world_size = (
            get_tensor_model_parallel_world_size())
        assert self.total_num_heads % tensor_model_parallel_world_size == 0
        self.num_heads = (self.total_num_heads //
                          tensor_model_parallel_world_size)
        self.query_key_value = ColumnParallelLinear(
            config.hidden_size,
            3 * config.hidden_size,
            gather_output=False,
            perform_initialization=False)
        self.dense = RowParallelLinear(config.hidden_size,
                                       config.hidden_size,
                                       input_is_parallel=True,
                                       perform_initialization=False)
        scaling = self.head_size**-0.5
        rotary_dim = int(self.head_size * config.rotary_pct)
        assert rotary_dim % 2 == 0
        self.attn = PagedAttentionWithRoPE(self.num_heads, self.head_size,
...
@@ -78,8 +83,8 @@ class GPTNeoXAttention(nn.Module):
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        k_cache, v_cache = kv_cache
        attn_output = self.attn(position_ids, q, k, v, k_cache, v_cache,
                                input_metadata, cache_event)
        output, _ = self.dense(attn_output)
        return output
...
@@ -92,7 +97,8 @@ class GPTNeoXMLP(nn.Module):
                                                  config.intermediate_size,
                                                  gather_output=False,
                                                  perform_initialization=False)
        self.dense_4h_to_h = RowParallelLinear(config.intermediate_size,
                                               config.hidden_size,
                                               input_is_parallel=True,
                                               perform_initialization=False)
        self.act = get_act_fn(config.hidden_act)
...
@@ -109,8 +115,10 @@ class GPTNeoXLayer(nn.Module):
    def __init__(self, config: GPTNeoXConfig):
        super().__init__()
        self.use_parallel_residual = config.use_parallel_residual
        self.input_layernorm = nn.LayerNorm(config.hidden_size,
                                            eps=config.layer_norm_eps)
        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
                                                     eps=config.layer_norm_eps)
        self.attention = GPTNeoXAttention(config)
        self.mlp = GPTNeoXMLP(config)
...
@@ -154,10 +162,13 @@ class GPTNeoXModel(nn.Module):
        super().__init__()
        self.config = config
        self.embed_in = VocabParallelEmbedding(config.vocab_size,
                                               config.hidden_size,
                                               perform_initialization=False)
        self.layers = nn.ModuleList(
            [GPTNeoXLayer(config) for _ in range(config.num_hidden_layers)])
        self.final_layer_norm = nn.LayerNorm(config.hidden_size,
                                             eps=config.layer_norm_eps)

    def forward(
        self,
...
@@ -191,8 +202,10 @@ class GPTNeoXForCausalLM(nn.Module):
        super().__init__()
        self.config = config
        self.gpt_neox = GPTNeoXModel(config)
        self.embed_out = ColumnParallelLinear(config.hidden_size,
                                              config.vocab_size,
                                              bias=False,
                                              gather_output=False,
                                              perform_initialization=False)
        self.sampler = Sampler(config.vocab_size)
...
@@ -204,24 +217,28 @@ class GPTNeoXForCausalLM(nn.Module):
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> Dict[int, SequenceOutputs]:
        hidden_states = self.gpt_neox(input_ids, positions, kv_caches,
                                      input_metadata, cache_events)
        next_tokens = self.sampler(self.embed_out.weight, hidden_states,
                                   input_metadata)
        return next_tokens

    _column_parallel_weights = [
        "embed_in.weight", "embed_out.weight", "dense_h_to_4h.weight",
        "dense_h_to_4h.bias"
    ]
    _row_parallel_weights = ["dense.weight", "dense_4h_to_h.weight"]

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     use_np_cache: bool = False):
        tensor_model_parallel_rank = get_tensor_model_parallel_rank()
        state_dict = self.state_dict()
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, use_np_cache):
            if ("attention.bias" in name or "attention.masked_bias" in name
                    or "rotary_emb.inv_freq" in name):
                continue
            param = state_dict[name]
            if "query_key_value" in name:
...
@@ -230,17 +247,19 @@ class GPTNeoXForCausalLM(nn.Module):
                # required shape is [3 * num_heads * head_size, hidden_size].
                # Thus, we need weight conversion.
                shard_size = param.shape[0]
                loaded_weight = loaded_weight[
                    shard_size * tensor_model_parallel_rank:shard_size *
                    (tensor_model_parallel_rank + 1)]
                num_heads = self.config.num_attention_heads
                hidden_size = self.config.hidden_size
                head_size = hidden_size // num_heads
                if "query_key_value.weight" in name:
                    loaded_weight = loaded_weight.view(-1, 3, head_size,
                                                       hidden_size)
                    loaded_weight = loaded_weight.transpose(0, 1)
                    loaded_weight = loaded_weight.reshape(-1, hidden_size)
                elif "query_key_value.bias" in name:
                    loaded_weight = loaded_weight.view(-1, 3, head_size)
                    loaded_weight = loaded_weight.transpose(0, 1)
                    loaded_weight = loaded_weight.reshape(-1)
...
...
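The load_weights hunk above reorders the fused query_key_value tensor from the GPT-NeoX checkpoint layout, which groups weights per attention head, into the per-projection layout that the parallel linear layer expects. The snippet below is a minimal standalone sketch of that view/transpose/reshape step; the tensor sizes and variable names are illustrative, not taken from the commit.

import torch

num_heads, head_size, hidden_size = 4, 8, 32
# Checkpoint layout: [num_heads * (q, k, v) * head_size, hidden_size]
ckpt_qkv = torch.randn(num_heads * 3 * head_size, hidden_size)

# Target layout: [(q, k, v) * num_heads * head_size, hidden_size]
converted = (ckpt_qkv.view(num_heads, 3, head_size, hidden_size)
             .transpose(0, 1)
             .reshape(3 * num_heads * head_size, hidden_size))
assert converted.shape == ckpt_qkv.shape  # same size, different row order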
vllm/model_executor/models/llama.py
View file @
d6fa1be3
# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
...
...
@@ -30,7 +31,6 @@ import torch
from torch import nn
from transformers import LlamaConfig

from vllm.sequence import SequenceOutputs
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
...
@@ -56,15 +56,19 @@ class LlamaMLP(nn.Module):
        hidden_act: str,
    ):
        super().__init__()
        self.gate_up_proj = ColumnParallelLinear(hidden_size,
                                                 2 * intermediate_size,
                                                 bias=False,
                                                 gather_output=False,
                                                 perform_initialization=False)
        self.down_proj = RowParallelLinear(intermediate_size,
                                           hidden_size,
                                           bias=False,
                                           input_is_parallel=True,
                                           perform_initialization=False)
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()

    def forward(self, x):
...
@@ -83,12 +87,14 @@ class LlamaAttention(nn.Module):
    ):
        super().__init__()
        self.hidden_size = hidden_size
        tensor_model_parallel_world_size = (
            get_tensor_model_parallel_world_size())
        self.total_num_heads = num_heads
        assert self.total_num_heads % tensor_model_parallel_world_size == 0
        self.num_heads = (self.total_num_heads //
                          tensor_model_parallel_world_size)
        self.head_dim = hidden_size // self.total_num_heads
        self.scaling = self.head_dim**-0.5
        self.qkv_proj = ColumnParallelLinear(
            hidden_size,
...
@@ -104,8 +110,10 @@ class LlamaAttention(nn.Module):
            input_is_parallel=True,
            perform_initialization=False,
        )
        self.attn = PagedAttentionWithRoPE(self.num_heads,
                                           self.head_dim,
                                           self.scaling,
                                           rotary_dim=self.head_dim)

    def forward(
        self,
...
@@ -118,8 +126,8 @@ class LlamaAttention(nn.Module):
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        k_cache, v_cache = kv_cache
        attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
                                input_metadata, cache_event)
        output, _ = self.o_proj(attn_output)
        return output
...
@@ -138,8 +146,10 @@ class LlamaDecoderLayer(nn.Module):
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
        )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
...
@@ -177,9 +187,13 @@ class LlamaModel(nn.Module):
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
                                                   config.hidden_size,
                                                   perform_initialization=False)
        self.layers = nn.ModuleList([
            LlamaDecoderLayer(config)
            for _ in range(config.num_hidden_layers)
        ])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
...
@@ -209,6 +223,7 @@ class LlamaModel(nn.Module):
class LlamaForCausalLM(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.config = config
...
@@ -228,39 +243,42 @@ class LlamaForCausalLM(nn.Module):
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> Dict[int, SequenceOutputs]:
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   input_metadata, cache_events)
        next_tokens = self.sampler(self.lm_head.weight, hidden_states,
                                   input_metadata)
        return next_tokens

    _column_parallel_weights = [
        "embed_tokens.weight", "lm_head.weight", "qkv_proj.weight",
        "gate_proj.weight", "up_proj.weight"
    ]
    _row_parallel_weights = ["o_proj.weight", "down_proj.weight"]

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     use_np_cache: bool = False):
        tensor_model_parallel_rank = get_tensor_model_parallel_rank()
        state_dict = self.state_dict()
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, use_np_cache):
            if "rotary_emb.inv_freq" in name:
                continue

            is_attention_weight = False
            for stride_id, att_weight_name in enumerate(
                    ["q_proj", "k_proj", "v_proj"]):
                if att_weight_name not in name:
                    continue
                param = state_dict[name.replace(att_weight_name, "qkv_proj")]
                shard_size = param.shape[0] // 3
                loaded_weight = loaded_weight[
                    shard_size * tensor_model_parallel_rank:shard_size *
                    (tensor_model_parallel_rank + 1)]
                param_slice = param.data[shard_size * stride_id:shard_size *
                                         (stride_id + 1)]
                assert param_slice.shape == loaded_weight.shape
                param_slice.copy_(loaded_weight)
                is_attention_weight = True
...
@@ -275,10 +293,10 @@ class LlamaForCausalLM(nn.Module):
                param = state_dict[name.replace(weight_name, "gate_up_proj")]
                shard_size = param.shape[0] // 2
                loaded_weight = loaded_weight[
                    shard_size * tensor_model_parallel_rank:shard_size *
                    (tensor_model_parallel_rank + 1)]
                param_slice = param.data[shard_size * stride_id:shard_size *
                                         (stride_id + 1)]
                assert param_slice.shape == loaded_weight.shape
                param_slice.copy_(loaded_weight)
                is_gate_up_weight = True
...
...
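The LLaMA load_weights changes above keep the same packing logic: each of q_proj, k_proj, and v_proj from the checkpoint is sliced down to the current tensor-parallel rank's rows and copied into its stride of the fused qkv_proj parameter. A small self-contained sketch of that copy, with made-up sizes and a stand-in checkpoint dict rather than the real weight iterator:

import torch

hidden_size, tp_rank, tp_size = 16, 0, 2
# Fused QKV parameter owned by this rank: 3 projections, each sharded by rows.
qkv_param = torch.empty(3 * hidden_size // tp_size, hidden_size)
shard_size = qkv_param.shape[0] // 3

checkpoint = {name: torch.randn(hidden_size, hidden_size)
              for name in ("q_proj", "k_proj", "v_proj")}
for stride_id, name in enumerate(["q_proj", "k_proj", "v_proj"]):
    # Rows of this projection that belong to tp_rank.
    loaded = checkpoint[name][shard_size * tp_rank:shard_size * (tp_rank + 1)]
    # Copy into the matching stride of the fused parameter.
    qkv_param[shard_size * stride_id:shard_size * (stride_id + 1)].copy_(loaded)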
vllm/model_executor/models/opt.py
View file @
d6fa1be3
# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/opt/modeling_opt.py
# Copyright 2023 The vLLM team.
# Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights
# reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -43,8 +45,9 @@ KVCache = Tuple[torch.Tensor, torch.Tensor]
class OPTLearnedPositionalEmbedding(nn.Embedding):

    def __init__(self, num_embeddings: int, embedding_dim: int):
        # OPT is set up so that if padding_idx is specified then offset the
        # embedding ids by 2 and adjust num_embeddings appropriately. Other
        # models don't have this hack
        self.offset = 2
        super().__init__(num_embeddings + self.offset, embedding_dim)
...
@@ -62,20 +65,26 @@ class OPTAttention(nn.Module):
    ) -> None:
        super().__init__()
        self.embed_dim = embed_dim
        tensor_model_parallel_world_size = (
            get_tensor_model_parallel_world_size())
        total_num_heads = num_heads
        assert num_heads % tensor_model_parallel_world_size == 0
        self.num_heads = total_num_heads // tensor_model_parallel_world_size
        self.head_dim = embed_dim // total_num_heads
        self.scaling = self.head_dim**-0.5

        self.qkv_proj = ColumnParallelLinear(embed_dim,
                                             3 * embed_dim,
                                             bias=bias,
                                             gather_output=False,
                                             perform_initialization=False)
        self.out_proj = RowParallelLinear(embed_dim,
                                          embed_dim,
                                          bias=bias,
                                          input_is_parallel=True,
                                          perform_initialization=False)
        self.attn = PagedAttention(self.num_heads,
                                   self.head_dim,
                                   scale=self.scaling)

    def forward(
...
@@ -88,8 +97,8 @@ class OPTAttention(nn.Module):
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        key_cache, value_cache = kv_cache
        attn_output = self.attn(q, k, v, key_cache, value_cache,
                                input_metadata, cache_event)
        output, _ = self.out_proj(attn_output)
        return output
...
@@ -109,17 +118,21 @@ class OPTDecoderLayer(nn.Module):
        self.activation_fn = get_act_fn(config.activation_function)

        self.self_attn_layer_norm = nn.LayerNorm(
            self.embed_dim,
            elementwise_affine=config.layer_norm_elementwise_affine)
        self.fc1 = ColumnParallelLinear(self.embed_dim,
                                        config.ffn_dim,
                                        bias=config.enable_bias,
                                        gather_output=False,
                                        perform_initialization=False)
        self.fc2 = RowParallelLinear(config.ffn_dim,
                                     self.embed_dim,
                                     bias=config.enable_bias,
                                     input_is_parallel=True,
                                     perform_initialization=False)
        self.final_layer_norm = nn.LayerNorm(
            self.embed_dim,
            elementwise_affine=config.layer_norm_elementwise_affine)

    def forward(
        self,
...
@@ -133,11 +146,10 @@ class OPTDecoderLayer(nn.Module):
        # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
        if self.do_layer_norm_before:
            hidden_states = self.self_attn_layer_norm(hidden_states)
        hidden_states = self.self_attn(hidden_states=hidden_states,
                                       kv_cache=kv_cache,
                                       input_metadata=input_metadata,
                                       cache_event=cache_event)
        hidden_states = residual + hidden_states
        # 350m applies layer norm AFTER attention
        if not self.do_layer_norm_before:
...
@@ -167,35 +179,42 @@ class OPTDecoder(nn.Module):
        self.max_target_positions = config.max_position_embeddings
        self.vocab_size = config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
                                                   config.word_embed_proj_dim,
                                                   perform_initialization=False)
        # Positional embeddings are replicated (not sharded).
        self.embed_positions = OPTLearnedPositionalEmbedding(
            config.max_position_embeddings, config.hidden_size)

        # Project out & in will be replicated if they exist.
        if config.word_embed_proj_dim != config.hidden_size:
            self.project_out = nn.Linear(config.hidden_size,
                                         config.word_embed_proj_dim,
                                         bias=False)
        else:
            self.project_out = None

        if config.word_embed_proj_dim != config.hidden_size:
            self.project_in = nn.Linear(config.word_embed_proj_dim,
                                        config.hidden_size,
                                        bias=False)
        else:
            self.project_in = None

        # Note that the only purpose of `config._remove_final_layer_norm` is to
        # keep backward compatibility with checkpoints that have been fine-tuned
        # before transformers v4.20.1
        # see https://github.com/facebookresearch/metaseq/pull/164
        if config.do_layer_norm_before and not config._remove_final_layer_norm:
            self.final_layer_norm = nn.LayerNorm(
                config.hidden_size,
                elementwise_affine=config.layer_norm_elementwise_affine)
        else:
            self.final_layer_norm = None

        self.layers = nn.ModuleList(
            [OPTDecoderLayer(config) for _ in range(config.num_hidden_layers)])

    def forward(
        self,
...
@@ -217,8 +236,8 @@ class OPTDecoder(nn.Module):
            else:
                cache_event = cache_events[i]
            layer = self.layers[i]
            hidden_states = layer(hidden_states, kv_caches[i], input_metadata,
                                  cache_event)

        if self.final_layer_norm is not None:
            hidden_states = self.final_layer_norm(hidden_states)
...
@@ -241,8 +260,8 @@ class OPTModel(nn.Module):
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> torch.Tensor:
        return self.decoder(input_ids, positions, kv_caches, input_metadata,
                            cache_events)


class OPTForCausalLM(nn.Module):
...
@@ -264,23 +283,26 @@ class OPTForCausalLM(nn.Module):
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> Dict[int, SequenceOutputs]:
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   input_metadata, cache_events)
        next_tokens = self.sampler(self.lm_head_weight, hidden_states,
                                   input_metadata)
        return next_tokens

    _column_parallel_weights = ["embed_tokens.weight", "fc1.weight", "fc1.bias"]
    _row_parallel_weights = ["out_proj.weight", "fc2.weight"]

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     use_np_cache: bool = False):
        tensor_model_parallel_rank = get_tensor_model_parallel_rank()
        state_dict = self.state_dict()
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, use_np_cache):
            if "lm_head.weight" in name:
                continue
...
@@ -288,16 +310,17 @@ class OPTForCausalLM(nn.Module):
                name = "model." + name

            is_attention_weight = False
            for stride_id, att_weight_name in enumerate(
                    ["q_proj", "k_proj", "v_proj"]):
                if att_weight_name not in name:
                    continue
                param = state_dict[name.replace(att_weight_name, "qkv_proj")]
                shard_size = param.shape[0] // 3
                loaded_weight = loaded_weight[
                    shard_size * tensor_model_parallel_rank:shard_size *
                    (tensor_model_parallel_rank + 1)]
                param_slice = param.data[shard_size * stride_id:shard_size *
                                         (stride_id + 1)]
                assert param_slice.shape == loaded_weight.shape
                param_slice.copy_(loaded_weight)
                is_attention_weight = True
...
...
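The OPTLearnedPositionalEmbedding hunk above only rewraps the comment, but the quirk it describes is worth spelling out: OPT checkpoints reserve the first two rows of the position-embedding table, so position ids are offset by 2 and the table is allocated with two extra rows. A self-contained sketch of that behaviour (the class and variable names here are illustrative, not the commit's):

import torch
from torch import nn

class LearnedPositionalEmbedding(nn.Embedding):

    def __init__(self, num_embeddings: int, embedding_dim: int):
        # Allocate two extra rows so that position 0 maps to row 2.
        self.offset = 2
        super().__init__(num_embeddings + self.offset, embedding_dim)

    def forward(self, positions: torch.Tensor) -> torch.Tensor:
        return super().forward(positions + self.offset)

emb = LearnedPositionalEmbedding(num_embeddings=2048, embedding_dim=64)
out = emb(torch.arange(4))  # positions 0..3 read rows 2..5 of the table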
vllm/model_executor/weight_utils.py
View file @
d6fa1be3
...
...
@@ -44,9 +44,9 @@ def hf_model_weights_iterator(
    if use_np_cache:
        # Convert the model weights from torch tensors to numpy arrays for
        # faster loading.
        np_folder = os.path.join(hf_folder, "np")
        os.makedirs(np_folder, exist_ok=True)
        weight_names_file = os.path.join(np_folder, "weight_names.json")
        with lock:
            if not os.path.exists(weight_names_file):
                weight_names = []
...
@@ -57,10 +57,10 @@ def hf_model_weights_iterator(
                    with open(param_path, "wb") as f:
                        np.save(f, param.cpu().detach().numpy())
                    weight_names.append(name)
                with open(weight_names_file, "w") as f:
                    json.dump(weight_names, f)

        with open(weight_names_file, "r") as f:
            weight_names = json.load(f)

        for name in weight_names:
...
@@ -86,17 +86,16 @@ def load_tensor_parallel_weights(
    for p in column_parallel_weight_names:
        if p in param_name:
            shard_size = param.shape[0]
            start_idx = tensor_model_parallel_rank * shard_size
            end_idx = (tensor_model_parallel_rank + 1) * shard_size
            loaded_weight = loaded_weight[start_idx:end_idx]
            break
    for p in row_parallel_weight_names:
        if p in param_name:
            shard_size = param.shape[1]
            start_idx = tensor_model_parallel_rank * shard_size
            end_idx = (tensor_model_parallel_rank + 1) * shard_size
            loaded_weight = loaded_weight[:, start_idx:end_idx]
            break
    assert param.shape == loaded_weight.shape, (
        f"{param_name} shape mismatch between model and checkpoint: "
...
...
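The last hunk replaces the long inline slices in load_tensor_parallel_weights with explicit start_idx / end_idx variables; the behaviour is unchanged. A short sketch of the column-parallel case with made-up shapes and rank values:

import torch

full_weight = torch.randn(12, 8)   # column-parallel weight, sharded along dim 0
tp_rank, tp_size = 1, 3
shard_size = full_weight.shape[0] // tp_size

start_idx = tp_rank * shard_size
end_idx = (tp_rank + 1) * shard_size
shard = full_weight[start_idx:end_idx]  # rows 4..7 go to rank 1
assert shard.shape == (shard_size, full_weight.shape[1])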