norm / vllm · Commit add055e1 (Unverified)

Enhance model loader (#83)

Authored May 09, 2023 by Woosuk Kwon; committed by GitHub on May 09, 2023
Parent: 7c041ab5
Showing 2 changed files with 56 additions and 42 deletions (+56 -42)

cacheflow/core/server.py (+1 -1)
cacheflow/model_executor/model_loader.py (+55 -41)
cacheflow/core/server.py

@@ -12,8 +12,8 @@ from cacheflow.core.scheduler import Scheduler
 from cacheflow.frontend.simple_frontend import SimpleFrontend
 from cacheflow.logger import init_logger
 from cacheflow.model_executor import get_memory_analyzer
-from cacheflow.sequence import SequenceGroup
 from cacheflow.sampling_params import SamplingParams
+from cacheflow.sequence import SequenceGroup
 from cacheflow.utils import get_gpu_memory, get_cpu_memory
 from cacheflow.worker.controller import Controller, DeviceID
cacheflow/model_executor/model_loader.py

@@ -14,32 +14,51 @@ from cacheflow.model_executor.utils import get_torch_dtype
 from cacheflow.model_executor.weight_utils import initialize_dummy_weights
 
-_MODELS = {
-    'gpt2': GPT2LMHeadModel,
-    'llama': LlamaForCausalLM,
-    'opt': OPTForCausalLM,
-    'stablelm': GPTNeoXForCausalLM,
-    'pythia': GPTNeoXForCausalLM,
-    'dolly-v2': GPTNeoXForCausalLM,
+# TODO(woosuk): Lazy-load the model classes.
+_MODEL_REGISTRY = {
+    "GPT2LMHeadModel": GPT2LMHeadModel,
+    "GPTNeoXForCausalLM": GPTNeoXForCausalLM,
+    "LlamaForCausalLM": LlamaForCausalLM,
+    "OPTForCausalLM": OPTForCausalLM,
 }
 
 _MEMORY_ANALYZERS = {
-    'gpt2': GPT2MemoryAnalyzer,
-    'llama': LlamaMemoryAnalyzer,
-    'opt': OPTMemoryAnalyzer,
-    'stablelm': GPTNeoXMemoryAnalyzer,
-    'pythia': GPTNeoXMemoryAnalyzer,
-    'dolly-v2': GPTNeoXMemoryAnalyzer,
+    "GPT2LMHeadModel": GPT2MemoryAnalyzer,
+    "GPTNeoXForCausalLM": GPTNeoXMemoryAnalyzer,
+    "LlamaForCausalLM": LlamaMemoryAnalyzer,
+    "OPTForCausalLM": OPTMemoryAnalyzer,
 }
 
 
+def _get_model_architecture(config: PretrainedConfig) -> nn.Module:
+    architectures = getattr(config, "architectures", [])
+    for arch in architectures:
+        if arch in _MODEL_REGISTRY:
+            return _MODEL_REGISTRY[arch]
+    raise ValueError(
+        f"Model architectures {architectures} are not supported for now. "
+        f"Supported architectures: {list(_MODEL_REGISTRY.keys())}")
+
+
+def _get_memory_analyzer(config: PretrainedConfig) -> CacheFlowMemoryAnalyzer:
+    architectures = getattr(config, "architectures", [])
+    for arch in architectures:
+        if arch in _MEMORY_ANALYZERS:
+            return _MEMORY_ANALYZERS[arch]
+    raise ValueError(
+        f"Model architectures {architectures} are not supported for now. "
+        f"Supported architectures: {list(_MEMORY_ANALYZERS.keys())}")
+
+
 def _get_dtype(config: PretrainedConfig, dtype: str) -> torch.dtype:
-    # NOTE: getattr(config, 'torch_dtype', torch.float32) is not correct
+    # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
     # because config.torch_dtype can be None.
-    config_dtype = getattr(config, 'torch_dtype', None)
+    config_dtype = getattr(config, "torch_dtype", None)
     if config_dtype is None:
         config_dtype = torch.float32
-    if dtype == 'default':
+    if dtype == "default":
         if config_dtype == torch.float32:
             # Following the common practice, we use float16 for float32 models.
             torch_dtype = torch.float16
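The substance of this hunk: model classes are now resolved from config.architectures (an exact dictionary lookup) instead of substring-matching the checkpoint name, so a fine-tuned checkpoint with an arbitrary name still resolves as long as its config declares a supported architecture. A minimal sketch of the new path; the checkpoint name is illustrative and a standard transformers install is assumed:

from transformers import AutoConfig

# Hugging Face configs carry the architecture class name(s) of the model,
# independent of what the checkpoint happens to be called.
config = AutoConfig.from_pretrained("facebook/opt-125m")
print(config.architectures)  # expected: ['OPTForCausalLM']

# _get_model_architecture(config) then returns OPTForCausalLM via an exact
# lookup in _MODEL_REGISTRY; an unknown architecture raises the ValueError
# listing the supported ones.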
@@ -51,7 +70,7 @@ def _get_dtype(config: PretrainedConfig, dtype: str) -> torch.dtype:
         # TODO(woosuk): Allow using float16 for bfloat16 models and
         # vice versa. Print a warning message and continue.
         raise ValueError(
-            f'Cannot use {torch_dtype} for {config_dtype} model.')
+            f"Cannot use {torch_dtype} for {config_dtype} model.")
     return torch_dtype
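As the NOTE inside _get_dtype points out, config.torch_dtype can exist and still be None, which is why the code defaults with getattr(config, "torch_dtype", None) and then patches None to float32 explicitly. A small sketch of the observable behavior, assuming _get_dtype as reconstructed above:

import torch
from transformers import AutoConfig

config = AutoConfig.from_pretrained("gpt2")
print(config.torch_dtype)  # None: gpt2's config leaves the field unset

# _get_dtype(config, "default") therefore treats the model as float32 and,
# per the "common practice" comment, returns torch.float16; a requested
# dtype that conflicts with the config dtype hits the ValueError in the
# next hunk.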
@@ -65,24 +84,21 @@ def get_model(
     config = AutoConfig.from_pretrained(model_name)
     torch_dtype = _get_dtype(config, dtype)
     torch.set_default_dtype(torch_dtype)
-    for model_class_name, model_class in _MODELS.items():
-        if model_class_name in model_name:
-            if use_dummy_weights:
-                # Create a model instance.
-                # The weights will be initialized as empty tensors.
-                model = model_class(config)
-                model = model.cuda()
-                # NOTE(woosuk): For precise performance evaluation, we assign
-                # random values to the weights.
-                initialize_dummy_weights(model)
-            else:
-                # Create a model instance.
-                model = model_class(config)
-                # Load the weights from the cached or downloaded files.
-                model.load_weights(model_name, cache_dir, use_np_cache)
-                model = model.cuda()
-            return model.eval(), torch_dtype
-    raise ValueError(f'Unsupported model name: {model_name}')
+    model_class = _get_model_architecture(config)
+
+    # Create a model instance.
+    # The weights will be initialized as empty tensors.
+    model = model_class(config)
+    if use_dummy_weights:
+        model = model.cuda()
+        # NOTE(woosuk): For accurate performance evaluation, we assign
+        # random values to the weights.
+        initialize_dummy_weights(model)
+    else:
+        # Load the weights from the cached or downloaded files.
+        model.load_weights(model_name, cache_dir, use_np_cache)
+        model = model.cuda()
+    return model.eval(), torch_dtype
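get_model now builds the model instance exactly once, then branches only on how the weights are filled in: random values for benchmarking, or load_weights for a real checkpoint. A hypothetical call for orientation; the full get_model signature is outside this diff, so the parameter names below are simply read off the hunk:

model, torch_dtype = get_model(
    model_name="facebook/opt-125m",  # illustrative checkpoint
    dtype="default",
    cache_dir=None,
    use_dummy_weights=True,  # random but correctly-shaped weights
    use_np_cache=False,
)
# With use_dummy_weights=True the model is moved to the GPU and
# initialize_dummy_weights assigns random values, which is enough for
# accurate performance evaluation without downloading a checkpoint; with
# False, load_weights loads the cached or downloaded files first and the
# model is moved to the GPU afterwards.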
@@ -95,9 +111,7 @@ def get_memory_analyzer(
 ) -> CacheFlowMemoryAnalyzer:
     config = AutoConfig.from_pretrained(model_name)
     torch_dtype = _get_dtype(config, dtype)
-    for model_class, memory_analyzer in _MEMORY_ANALYZERS.items():
-        if model_class in model_name:
-            return memory_analyzer(
-                model_name, block_size, torch_dtype, gpu_memory, cpu_memory,
-                tensor_parallel_size)
-    raise ValueError(f'Unsupported model name: {model_name}')
+    memory_analyzer = _get_memory_analyzer(config)
+    return memory_analyzer(
+        model_name, block_size, torch_dtype, gpu_memory, cpu_memory,
+        tensor_parallel_size)
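Keying both tables on architecture names also makes adding a model family a pure registry change. A hypothetical sketch (neither Bloom class exists in this commit; they stand in for a real implementation):

# Hypothetical: registering support for a new architecture.
_MODEL_REGISTRY["BloomForCausalLM"] = BloomForCausalLM
_MEMORY_ANALYZERS["BloomForCausalLM"] = BloomMemoryAnalyzer

# Any checkpoint whose config declares architectures=["BloomForCausalLM"]
# then resolves through _get_model_architecture and _get_memory_analyzer,
# while everything else fails fast with the "Supported architectures: ..."
# error.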