norm / vllm · Commits · 518369d7

Unverified commit 518369d7, authored Dec 12, 2023 by Woosuk Kwon, committed by GitHub on Dec 12, 2023.
Implement lazy model loader (#2044)
Parent: 30bad5c4
Showing 3 changed files with 89 additions and 101 deletions.
vllm/model_executor/model_loader.py      +5  -57
vllm/model_executor/models/__init__.py   +77 -38
vllm/model_executor/models/mixtral.py    +7  -6
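The commit replaces an eager registry that imported every model class at package load time with a table of (module name, class name) strings that are resolved through importlib on first lookup. A minimal, self-contained sketch of that pattern, using stdlib classes and hypothetical names (_LAZY_REGISTRY, load_cls) purely for illustration:

import importlib
from typing import Optional, Type

# Name -> (module, class) pairs; the values are strings, so nothing is
# imported until a lookup actually happens.
_LAZY_REGISTRY = {
    "OrderedDict": ("collections", "OrderedDict"),
    "JSONDecoder": ("json", "JSONDecoder"),
}

def load_cls(name: str) -> Optional[Type]:
    if name not in _LAZY_REGISTRY:
        return None
    module_name, cls_name = _LAZY_REGISTRY[name]
    module = importlib.import_module(module_name)  # deferred import
    return getattr(module, cls_name, None)

print(load_cls("JSONDecoder"))  # <class 'json.decoder.JSONDecoder'>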
vllm/model_executor/model_loader.py

@@ -7,54 +7,9 @@ import torch.nn as nn
 from transformers import PretrainedConfig
 
 from vllm.config import ModelConfig
-from vllm.model_executor.models import *
+from vllm.model_executor.models import ModelRegistry
 from vllm.model_executor.weight_utils import (get_quant_config,
                                               initialize_dummy_weights)
-from vllm.utils import is_hip
-from vllm.logger import init_logger
-
-logger = init_logger(__name__)
-
-# TODO(woosuk): Lazy-load the model classes.
-_MODEL_REGISTRY = {
-    "AquilaModel": AquilaForCausalLM,
-    "AquilaForCausalLM": AquilaForCausalLM,  # AquilaChat2
-    "BaiChuanForCausalLM": BaiChuanForCausalLM,  # baichuan-7b
-    "BaichuanForCausalLM": BaichuanForCausalLM,  # baichuan-13b
-    "BloomForCausalLM": BloomForCausalLM,
-    "ChatGLMModel": ChatGLMForCausalLM,
-    "ChatGLMForConditionalGeneration": ChatGLMForCausalLM,
-    "FalconForCausalLM": FalconForCausalLM,
-    "GPT2LMHeadModel": GPT2LMHeadModel,
-    "GPTBigCodeForCausalLM": GPTBigCodeForCausalLM,
-    "GPTJForCausalLM": GPTJForCausalLM,
-    "GPTNeoXForCausalLM": GPTNeoXForCausalLM,
-    "InternLMForCausalLM": InternLMForCausalLM,
-    "LlamaForCausalLM": LlamaForCausalLM,
-    "LLaMAForCausalLM": LlamaForCausalLM,  # For decapoda-research/llama-*
-    "MistralForCausalLM": MistralForCausalLM,
-    "MixtralForCausalLM": MixtralForCausalLM,
-    # transformers's mpt class has lower case
-    "MptForCausalLM": MPTForCausalLM,
-    "MPTForCausalLM": MPTForCausalLM,
-    "OPTForCausalLM": OPTForCausalLM,
-    "PhiForCausalLM": PhiForCausalLM,
-    "QWenLMHeadModel": QWenLMHeadModel,
-    "RWForCausalLM": FalconForCausalLM,
-    "YiForCausalLM": YiForCausalLM,
-}
-
-# Models to be disabled in ROCm
-_ROCM_UNSUPPORTED_MODELS = []
-if is_hip():
-    for rocm_model in _ROCM_UNSUPPORTED_MODELS:
-        del _MODEL_REGISTRY[rocm_model]
-
-# Models partially supported in ROCm
-_ROCM_PARTIALLY_SUPPORTED_MODELS = {
-    "MistralForCausalLM":
-    "Sliding window attention is not supported in ROCm's flash attention",
-}
 
 @contextlib.contextmanager
@@ -69,19 +24,12 @@ def _set_default_torch_dtype(dtype: torch.dtype):
 def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
     architectures = getattr(config, "architectures", [])
     for arch in architectures:
-        if arch in _MODEL_REGISTRY:
-            if is_hip() and arch in _ROCM_PARTIALLY_SUPPORTED_MODELS:
-                logger.warning(
-                    f"{arch} is not fully supported in ROCm. Reason: "
-                    f"{_ROCM_PARTIALLY_SUPPORTED_MODELS[arch]}")
-            return _MODEL_REGISTRY[arch]
-        elif arch in _ROCM_UNSUPPORTED_MODELS:
-            raise ValueError(
-                f"Model architecture {arch} is not supported by ROCm for now. \n"
-                f"Supported architectures {list(_MODEL_REGISTRY.keys())}")
+        model_cls = ModelRegistry.load_model_cls(arch)
+        if model_cls is not None:
+            return model_cls
     raise ValueError(
         f"Model architectures {architectures} are not supported for now. "
-        f"Supported architectures: {list(_MODEL_REGISTRY.keys())}")
+        f"Supported architectures: {ModelRegistry.get_supported_archs()}")
 
 
 def get_model(model_config: ModelConfig) -> nn.Module:
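With the registry in place, resolving an architecture name imports only the module that implements it. A hedged usage sketch, assuming a vLLM build containing this commit, the transformers package, and access to the Hugging Face Hub; the model name is illustrative:

from transformers import AutoConfig

from vllm.model_executor.models import ModelRegistry

# OPT's Hugging Face config advertises its architecture as "OPTForCausalLM".
config = AutoConfig.from_pretrained("facebook/opt-125m")
arch = config.architectures[0]

model_cls = ModelRegistry.load_model_cls(arch)   # triggers the lazy import
print(model_cls)                                 # OPT implementation class
print(ModelRegistry.get_supported_archs())       # all registered names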
vllm/model_executor/models/__init__.py

-from vllm.model_executor.models.aquila import AquilaForCausalLM
-from vllm.model_executor.models.baichuan import (BaiChuanForCausalLM,
-                                                 BaichuanForCausalLM)
-from vllm.model_executor.models.bloom import BloomForCausalLM
-from vllm.model_executor.models.falcon import FalconForCausalLM
-from vllm.model_executor.models.gpt2 import GPT2LMHeadModel
-from vllm.model_executor.models.gpt_bigcode import GPTBigCodeForCausalLM
-from vllm.model_executor.models.gpt_j import GPTJForCausalLM
-from vllm.model_executor.models.gpt_neox import GPTNeoXForCausalLM
-from vllm.model_executor.models.internlm import InternLMForCausalLM
-from vllm.model_executor.models.llama import LlamaForCausalLM
-from vllm.model_executor.models.mistral import MistralForCausalLM
-from vllm.model_executor.models.mixtral import MixtralForCausalLM
-from vllm.model_executor.models.mpt import MPTForCausalLM
-from vllm.model_executor.models.opt import OPTForCausalLM
-from vllm.model_executor.models.phi_1_5 import PhiForCausalLM
-from vllm.model_executor.models.qwen import QWenLMHeadModel
-from vllm.model_executor.models.chatglm import ChatGLMForCausalLM
-from vllm.model_executor.models.yi import YiForCausalLM
+import importlib
+from typing import List, Optional, Type
+
+import torch.nn as nn
+
+from vllm.logger import init_logger
+from vllm.utils import is_hip
+
+logger = init_logger(__name__)
+
+# Architecture -> (module, class).
+_MODELS = {
+    "AquilaModel": ("aquila", "AquilaForCausalLM"),
+    "AquilaForCausalLM": ("aquila", "AquilaForCausalLM"),  # AquilaChat2
+    "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),  # baichuan-7b
+    "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),  # baichuan-13b
+    "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
+    "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
+    "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
+    "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
+    "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
+    "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
+    "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
+    "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
+    "InternLMForCausalLM": ("internlm", "InternLMForCausalLM"),
+    "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
+    # For decapoda-research/llama-*
+    "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
+    "MistralForCausalLM": ("mistral", "MistralForCausalLM"),
+    "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
+    # transformers's mpt class has lower case
+    "MptForCausalLM": ("mpt", "MPTForCausalLM"),
+    "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
+    "OPTForCausalLM": ("opt", "OPTForCausalLM"),
+    "PhiForCausalLM": ("phi_1_5", "PhiForCausalLM"),
+    "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
+    "RWForCausalLM": ("falcon", "FalconForCausalLM"),
+    "YiForCausalLM": ("yi", "YiForCausalLM"),
+}
+
+# Models not supported by ROCm.
+_ROCM_UNSUPPORTED_MODELS = ["MixtralForCausalLM"]
+
+# Models partially supported by ROCm.
+# Architecture -> Reason.
+_ROCM_PARTIALLY_SUPPORTED_MODELS = {
+    "MistralForCausalLM":
+    "Sliding window attention is not yet supported in ROCm's flash attention",
+}
+
+
+class ModelRegistry:
+
+    @staticmethod
+    def load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]:
+        if model_arch not in _MODELS:
+            return None
+        if is_hip():
+            if model_arch in _ROCM_UNSUPPORTED_MODELS:
+                raise ValueError(
+                    f"Model architecture {model_arch} is not supported by "
+                    "ROCm for now.")
+            if model_arch in _ROCM_PARTIALLY_SUPPORTED_MODELS:
+                logger.warning(
+                    f"Model architecture {model_arch} is partially supported "
+                    "by ROCm: " + _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch])
+        module_name, model_cls_name = _MODELS[model_arch]
+        module = importlib.import_module(
+            f"vllm.model_executor.models.{module_name}")
+        return getattr(module, model_cls_name, None)
+
+    @staticmethod
+    def get_supported_archs() -> List[str]:
+        return list(_MODELS.keys())
 
 __all__ = [
-    "AquilaForCausalLM",
-    "BaiChuanForCausalLM",
-    "BaichuanForCausalLM",
-    "BloomForCausalLM",
-    "ChatGLMForCausalLM",
-    "FalconForCausalLM",
-    "GPT2LMHeadModel",
-    "GPTBigCodeForCausalLM",
-    "GPTJForCausalLM",
-    "GPTNeoXForCausalLM",
-    "InternLMForCausalLM",
-    "LlamaForCausalLM",
-    "MPTForCausalLM",
-    "OPTForCausalLM",
-    "PhiForCausalLM",
-    "QWenLMHeadModel",
-    "MistralForCausalLM",
-    "MixtralForCausalLM",
-    "YiForCausalLM",
+    "ModelRegistry",
 ]
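A quick way to see the effect is to check sys.modules before and after a lookup. This is a hedged illustration, not part of the commit, and it assumes nothing else in the process has already imported the OPT module:

import sys

from vllm.model_executor.models import ModelRegistry

# Importing the package no longer pulls in every model implementation.
assert "vllm.model_executor.models.opt" not in sys.modules

opt_cls = ModelRegistry.load_model_cls("OPTForCausalLM")

# The lookup imported exactly the module that implements the architecture.
assert "vllm.model_executor.models.opt" in sys.modules
print(opt_cls.__name__)  # OPTForCausalLM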
vllm/model_executor/models/mixtral.py

@@ -33,14 +33,15 @@ from transformers import MixtralConfig
 try:
     import megablocks.ops as ops
-except ImportError:
-    print("MegaBlocks not found. "
-          "Please install it by `pip install megablocks`.")
+except ImportError as e:
+    raise ImportError("MegaBlocks not found. "
+                      "Please install it by `pip install megablocks`.") from e
+
 try:
     import stk
-except ImportError:
-    print("STK not found: please see https://github.com/stanford-futuredata/stk")
+except ImportError as e:
+    raise ImportError("STK not found. "
+                      "Please install it by `pip install stanford-stk`.") from e
 
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.attention import PagedAttention
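The mixtral.py change complements the lazy loader: now that this module is only imported when Mixtral is actually requested, a missing optional dependency should fail loudly at import time rather than print a warning and limp on, and `raise ... from e` keeps the original ImportError chained as __cause__. A generic sketch of the same pattern, with a hypothetical helper name (require):

import importlib

def require(module_name: str, install_hint: str):
    """Import an optional dependency, failing loudly with install guidance."""
    try:
        return importlib.import_module(module_name)
    except ImportError as e:
        # `from e` chains the original error, so the true cause (for example
        # a broken transitive dependency) still appears in the traceback.
        raise ImportError(f"{module_name} not found. {install_hint}") from e

# Hypothetical usage mirroring the diff:
# ops = require("megablocks.ops",
#               "Please install it by `pip install megablocks`.")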