norm / vllm · Commit add055e1 (unverified)
Enhance model loader (#83)
Authored May 09, 2023 by Woosuk Kwon; committed by GitHub on May 09, 2023.
Parent: 7c041ab5
Showing 2 changed files with 56 additions and 42 deletions:
    cacheflow/core/server.py                    +1  -1
    cacheflow/model_executor/model_loader.py   +55 -41
cacheflow/core/server.py

@@ -12,8 +12,8 @@ from cacheflow.core.scheduler import Scheduler
 from cacheflow.frontend.simple_frontend import SimpleFrontend
 from cacheflow.logger import init_logger
 from cacheflow.model_executor import get_memory_analyzer
-from cacheflow.sequence import SequenceGroup
 from cacheflow.sampling_params import SamplingParams
+from cacheflow.sequence import SequenceGroup
 from cacheflow.utils import get_gpu_memory, get_cpu_memory
 from cacheflow.worker.controller import Controller, DeviceID
cacheflow/model_executor/model_loader.py

@@ -14,32 +14,51 @@ from cacheflow.model_executor.utils import get_torch_dtype
 from cacheflow.model_executor.weight_utils import initialize_dummy_weights
 
 
-_MODELS = {
-    'gpt2': GPT2LMHeadModel,
-    'llama': LlamaForCausalLM,
-    'opt': OPTForCausalLM,
-    'stablelm': GPTNeoXForCausalLM,
-    'pythia': GPTNeoXForCausalLM,
-    'dolly-v2': GPTNeoXForCausalLM,
+# TODO(woosuk): Lazy-load the model classes.
+_MODEL_REGISTRY = {
+    "GPT2LMHeadModel": GPT2LMHeadModel,
+    "GPTNeoXForCausalLM": GPTNeoXForCausalLM,
+    "LlamaForCausalLM": LlamaForCausalLM,
+    "OPTForCausalLM": OPTForCausalLM,
 }
 
 _MEMORY_ANALYZERS = {
-    'gpt2': GPT2MemoryAnalyzer,
-    'llama': LlamaMemoryAnalyzer,
-    'opt': OPTMemoryAnalyzer,
-    'stablelm': GPTNeoXMemoryAnalyzer,
-    'pythia': GPTNeoXMemoryAnalyzer,
-    'dolly-v2': GPTNeoXMemoryAnalyzer,
+    "GPT2LMHeadModel": GPT2MemoryAnalyzer,
+    "GPTNeoXForCausalLM": GPTNeoXMemoryAnalyzer,
+    "LlamaForCausalLM": LlamaMemoryAnalyzer,
+    "OPTForCausalLM": OPTMemoryAnalyzer,
 }
 
 
+def _get_model_architecture(config: PretrainedConfig) -> nn.Module:
+    architectures = getattr(config, "architectures", [])
+    for arch in architectures:
+        if arch in _MODEL_REGISTRY:
+            return _MODEL_REGISTRY[arch]
+    raise ValueError(
+        f"Model architectures {architectures} are not supported for now. "
+        f"Supported architectures: {list(_MODEL_REGISTRY.keys())}")
+
+
+def _get_memory_analyzer(config: PretrainedConfig) -> CacheFlowMemoryAnalyzer:
+    architectures = getattr(config, "architectures", [])
+    for arch in architectures:
+        if arch in _MEMORY_ANALYZERS:
+            return _MEMORY_ANALYZERS[arch]
+    raise ValueError(
+        f"Model architectures {architectures} are not supported for now. "
+        f"Supported architectures: {list(_MEMORY_ANALYZERS.keys())}")
+
+
 def _get_dtype(config: PretrainedConfig, dtype: str) -> torch.dtype:
-    # NOTE: getattr(config, 'torch_dtype', torch.float32) is not correct
+    # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
     # because config.torch_dtype can be None.
-    config_dtype = getattr(config, 'torch_dtype', None)
+    config_dtype = getattr(config, "torch_dtype", None)
     if config_dtype is None:
         config_dtype = torch.float32
-    if dtype == 'default':
+    if dtype == "default":
         if config_dtype == torch.float32:
             # Following the common practice, we use float16 for float32 models.
             torch_dtype = torch.float16
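
Note on the new lookup: the registry is keyed by the `architectures` field that Hugging Face configs carry, instead of matching substrings of the model name. A minimal standalone sketch of that pattern (the stub classes, registry, and helper name below are illustrative, not from this commit; only `transformers` is assumed to be installed):

    from transformers import AutoConfig

    # Stub stand-ins for the real model classes; illustrative only.
    class GPT2LMHeadModel: ...
    class OPTForCausalLM: ...

    _REGISTRY = {
        "GPT2LMHeadModel": GPT2LMHeadModel,
        "OPTForCausalLM": OPTForCausalLM,
    }

    def resolve_model_class(model_name: str):
        config = AutoConfig.from_pretrained(model_name)
        # e.g. "facebook/opt-125m" reports architectures == ["OPTForCausalLM"]
        for arch in getattr(config, "architectures", []) or []:
            if arch in _REGISTRY:
                return _REGISTRY[arch]
        raise ValueError(f"Unsupported architectures: {config.architectures}")

    # resolve_model_class("facebook/opt-125m")  # -> the OPTForCausalLM stub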
@@ -51,7 +70,7 @@ def _get_dtype(config: PretrainedConfig, dtype: str) -> torch.dtype:
             # TODO(woosuk): Allow using float16 for bfloat16 models and
             # vice versa. Print a warning message and continue.
             raise ValueError(
-                f'Cannot use {torch_dtype} for {config_dtype} model.')
+                f"Cannot use {torch_dtype} for {config_dtype} model.")
     return torch_dtype
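
For reference, the `_get_dtype` logic above amounts to: use the config's `torch_dtype` when no explicit dtype is requested, downcast float32 checkpoints to float16, and reject an explicit dtype that conflicts with the config. A simplified standalone sketch of that rule (not the exact function from the diff, which also goes through `get_torch_dtype`):

    import torch

    def resolve_dtype(config_dtype, requested: str) -> torch.dtype:
        # config.torch_dtype can be None, so fall back to float32 explicitly.
        config_dtype = config_dtype or torch.float32
        if requested == "default":
            # Common practice: serve float32 checkpoints in float16.
            return torch.float16 if config_dtype == torch.float32 else config_dtype
        requested_dtype = getattr(torch, requested)  # e.g. "float16" -> torch.float16
        if requested_dtype != config_dtype:
            raise ValueError(f"Cannot use {requested_dtype} for {config_dtype} model.")
        return requested_dtype

    # resolve_dtype(None, "default")            # -> torch.float16
    # resolve_dtype(torch.bfloat16, "default")  # -> torch.bfloat16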
@@ -65,24 +84,21 @@ def get_model(
     config = AutoConfig.from_pretrained(model_name)
     torch_dtype = _get_dtype(config, dtype)
     torch.set_default_dtype(torch_dtype)
-    for model_class_name, model_class in _MODELS.items():
-        if model_class_name in model_name:
-            if use_dummy_weights:
-                # Create a model instance.
-                # The weights will be initialized as empty tensors.
-                model = model_class(config)
-                model = model.cuda()
-                # NOTE(woosuk): For precise performance evaluation, we assign
-                # random values to the weights.
-                initialize_dummy_weights(model)
-            else:
-                # Create a model instance.
-                model = model_class(config)
-                # Load the weights from the cached or downloaded files.
-                model.load_weights(model_name, cache_dir, use_np_cache)
-                model = model.cuda()
-            return model.eval(), torch_dtype
-    raise ValueError(f'Unsupported model name: {model_name}')
+    model_class = _get_model_architecture(config)
+
+    # Create a model instance.
+    # The weights will be initialized as empty tensors.
+    model = model_class(config)
+    if use_dummy_weights:
+        model = model.cuda()
+        # NOTE(woosuk): For accurate performance evaluation, we assign
+        # random values to the weights.
+        initialize_dummy_weights(model)
+    else:
+        # Load the weights from the cached or downloaded files.
+        model.load_weights(model_name, cache_dir, use_np_cache)
+        model = model.cuda()
+    return model.eval(), torch_dtype
 
 
 def get_memory_analyzer(
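
Behavioral note: model names that were dropped from the old table ('stablelm', 'pythia', 'dolly-v2') should still load, since their Hugging Face configs report `GPTNeoXForCausalLM` as the architecture. A quick hedged check (the hub IDs below are examples chosen for illustration, not taken from this commit):

    from transformers import AutoConfig

    # Example hub IDs; any GPT-NeoX-family checkpoint should behave the same.
    for name in ("EleutherAI/pythia-160m", "databricks/dolly-v2-3b"):
        config = AutoConfig.from_pretrained(name)
        print(name, config.architectures)  # expected: ['GPTNeoXForCausalLM']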
@@ -95,9 +111,7 @@ def get_memory_analyzer(
 ) -> CacheFlowMemoryAnalyzer:
     config = AutoConfig.from_pretrained(model_name)
     torch_dtype = _get_dtype(config, dtype)
-    for model_class, memory_analyzer in _MEMORY_ANALYZERS.items():
-        if model_class in model_name:
-            return memory_analyzer(
-                model_name, block_size, torch_dtype, gpu_memory, cpu_memory,
-                tensor_parallel_size)
-    raise ValueError(f'Unsupported model name: {model_name}')
+    memory_analyzer = _get_memory_analyzer(config)
+    return memory_analyzer(
+        model_name, block_size, torch_dtype, gpu_memory, cpu_memory,
+        tensor_parallel_size)
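
Finally, the failure mode changes with this refactor: an unsupported checkpoint now fails by architecture rather than by name. A hypothetical example (the BLOOM model ID is only an illustration of an architecture outside the registry):

    from transformers import AutoConfig

    config = AutoConfig.from_pretrained("bigscience/bloom-560m")
    print(config.architectures)  # ['BloomForCausalLM'], not in _MODEL_REGISTRY
    # With the new code, get_model()/get_memory_analyzer() would raise, e.g.:
    # ValueError: Model architectures ['BloomForCausalLM'] are not supported for now.
    #   Supported architectures: ['GPT2LMHeadModel', 'GPTNeoXForCausalLM',
    #   'LlamaForCausalLM', 'OPTForCausalLM']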