SIYIXNI / vllm · Commits · 518369d7
"profiler/include/profile_gemm_splitk_impl.hpp" did not exist on "e823d518cb46ad61ddb3c70eac8529e0a58af1f8"
Unverified commit 518369d7, authored Dec 12, 2023 by Woosuk Kwon; committed by GitHub on Dec 12, 2023.
Implement lazy model loader (#2044)
Parent: 30bad5c4
Showing 3 changed files with 89 additions and 101 deletions.
vllm/model_executor/model_loader.py      +5   -57
vllm/model_executor/models/__init__.py   +77  -38
vllm/model_executor/models/mixtral.py    +7   -6
vllm/model_executor/model_loader.py
@@ -7,54 +7,9 @@ import torch.nn as nn
 from transformers import PretrainedConfig

 from vllm.config import ModelConfig
-from vllm.model_executor.models import *
+from vllm.model_executor.models import ModelRegistry
 from vllm.model_executor.weight_utils import (get_quant_config,
                                               initialize_dummy_weights)
 from vllm.utils import is_hip
-from vllm.logger import init_logger
-
-logger = init_logger(__name__)
-
-# TODO(woosuk): Lazy-load the model classes.
-_MODEL_REGISTRY = {
-    "AquilaModel": AquilaForCausalLM,
-    "AquilaForCausalLM": AquilaForCausalLM,  # AquilaChat2
-    "BaiChuanForCausalLM": BaiChuanForCausalLM,  # baichuan-7b
-    "BaichuanForCausalLM": BaichuanForCausalLM,  # baichuan-13b
-    "BloomForCausalLM": BloomForCausalLM,
-    "ChatGLMModel": ChatGLMForCausalLM,
-    "ChatGLMForConditionalGeneration": ChatGLMForCausalLM,
-    "FalconForCausalLM": FalconForCausalLM,
-    "GPT2LMHeadModel": GPT2LMHeadModel,
-    "GPTBigCodeForCausalLM": GPTBigCodeForCausalLM,
-    "GPTJForCausalLM": GPTJForCausalLM,
-    "GPTNeoXForCausalLM": GPTNeoXForCausalLM,
-    "InternLMForCausalLM": InternLMForCausalLM,
-    "LlamaForCausalLM": LlamaForCausalLM,
-    "LLaMAForCausalLM": LlamaForCausalLM,  # For decapoda-research/llama-*
-    "MistralForCausalLM": MistralForCausalLM,
-    "MixtralForCausalLM": MixtralForCausalLM,
-    # transformers's mpt class has lower case
-    "MptForCausalLM": MPTForCausalLM,
-    "MPTForCausalLM": MPTForCausalLM,
-    "OPTForCausalLM": OPTForCausalLM,
-    "PhiForCausalLM": PhiForCausalLM,
-    "QWenLMHeadModel": QWenLMHeadModel,
-    "RWForCausalLM": FalconForCausalLM,
-    "YiForCausalLM": YiForCausalLM,
-}
-
-# Models to be disabled in ROCm
-_ROCM_UNSUPPORTED_MODELS = []
-if is_hip():
-    for rocm_model in _ROCM_UNSUPPORTED_MODELS:
-        del _MODEL_REGISTRY[rocm_model]
-
-# Models partially supported in ROCm
-_ROCM_PARTIALLY_SUPPORTED_MODELS = {
-    "MistralForCausalLM":
-    "Sliding window attention is not supported in ROCm's flash attention",
-}


 @contextlib.contextmanager
@@ -69,19 +24,12 @@ def _set_default_torch_dtype(dtype: torch.dtype):
 def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
     architectures = getattr(config, "architectures", [])
     for arch in architectures:
-        if arch in _MODEL_REGISTRY:
-            if is_hip() and arch in _ROCM_PARTIALLY_SUPPORTED_MODELS:
-                logger.warning(
-                    f"{arch} is not fully supported in ROCm. Reason: "
-                    f"{_ROCM_PARTIALLY_SUPPORTED_MODELS[arch]}")
-            return _MODEL_REGISTRY[arch]
-        elif arch in _ROCM_UNSUPPORTED_MODELS:
-            raise ValueError(
-                f"Model architecture {arch} is not supported by ROCm for now. \n"
-                f"Supported architectures {list(_MODEL_REGISTRY.keys())}")
+        model_cls = ModelRegistry.load_model_cls(arch)
+        if model_cls is not None:
+            return model_cls
     raise ValueError(
         f"Model architectures {architectures} are not supported for now. "
-        f"Supported architectures: {list(_MODEL_REGISTRY.keys())}")
+        f"Supported architectures: {ModelRegistry.get_supported_archs()}")


 def get_model(model_config: ModelConfig) -> nn.Module:
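For orientation, here is a minimal sketch (not part of the commit) of how the rewritten lookup in _get_model_architecture behaves from a caller's point of view: the architecture names come from the Hugging Face config, and each one is resolved through ModelRegistry.load_model_cls instead of a module-level dict. It assumes vLLM and transformers are installed; "facebook/opt-125m" is only an illustrative checkpoint.

# Sketch only: exercises the lazy lookup path introduced by this commit.
from transformers import AutoConfig

from vllm.model_executor.models import ModelRegistry

config = AutoConfig.from_pretrained("facebook/opt-125m")
for arch in getattr(config, "architectures", []):  # e.g. ["OPTForCausalLM"]
    model_cls = ModelRegistry.load_model_cls(arch)  # imports the model module lazily
    if model_cls is not None:
        print(f"Resolved {arch} -> {model_cls.__name__}")
        break
else:
    raise ValueError(
        f"Unsupported; known architectures: {ModelRegistry.get_supported_archs()}")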
vllm/model_executor/models/__init__.py
-from vllm.model_executor.models.aquila import AquilaForCausalLM
-from vllm.model_executor.models.baichuan import (BaiChuanForCausalLM,
-                                                 BaichuanForCausalLM)
-from vllm.model_executor.models.bloom import BloomForCausalLM
-from vllm.model_executor.models.falcon import FalconForCausalLM
-from vllm.model_executor.models.gpt2 import GPT2LMHeadModel
-from vllm.model_executor.models.gpt_bigcode import GPTBigCodeForCausalLM
-from vllm.model_executor.models.gpt_j import GPTJForCausalLM
-from vllm.model_executor.models.gpt_neox import GPTNeoXForCausalLM
-from vllm.model_executor.models.internlm import InternLMForCausalLM
-from vllm.model_executor.models.llama import LlamaForCausalLM
-from vllm.model_executor.models.mistral import MistralForCausalLM
-from vllm.model_executor.models.mixtral import MixtralForCausalLM
-from vllm.model_executor.models.mpt import MPTForCausalLM
-from vllm.model_executor.models.opt import OPTForCausalLM
-from vllm.model_executor.models.phi_1_5 import PhiForCausalLM
-from vllm.model_executor.models.qwen import QWenLMHeadModel
-from vllm.model_executor.models.chatglm import ChatGLMForCausalLM
-from vllm.model_executor.models.yi import YiForCausalLM
+import importlib
+from typing import List, Optional, Type
+
+import torch.nn as nn
+
+from vllm.logger import init_logger
+from vllm.utils import is_hip
+
+logger = init_logger(__name__)
+
+# Architecture -> (module, class).
+_MODELS = {
+    "AquilaModel": ("aquila", "AquilaForCausalLM"),
+    "AquilaForCausalLM": ("aquila", "AquilaForCausalLM"),  # AquilaChat2
+    "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),  # baichuan-7b
+    "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),  # baichuan-13b
+    "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
+    "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
+    "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
+    "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
+    "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
+    "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
+    "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
+    "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
+    "InternLMForCausalLM": ("internlm", "InternLMForCausalLM"),
+    "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
+    # For decapoda-research/llama-*
+    "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
+    "MistralForCausalLM": ("mistral", "MistralForCausalLM"),
+    "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
+    # transformers's mpt class has lower case
+    "MptForCausalLM": ("mpt", "MPTForCausalLM"),
+    "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
+    "OPTForCausalLM": ("opt", "OPTForCausalLM"),
+    "PhiForCausalLM": ("phi_1_5", "PhiForCausalLM"),
+    "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
+    "RWForCausalLM": ("falcon", "FalconForCausalLM"),
+    "YiForCausalLM": ("yi", "YiForCausalLM"),
+}
+
+# Models not supported by ROCm.
+_ROCM_UNSUPPORTED_MODELS = ["MixtralForCausalLM"]
+
+# Models partially supported by ROCm.
+# Architecture -> Reason.
+_ROCM_PARTIALLY_SUPPORTED_MODELS = {
+    "MistralForCausalLM":
+    "Sliding window attention is not yet supported in ROCm's flash attention",
+}
+
+
+class ModelRegistry:
+
+    @staticmethod
+    def load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]:
+        if model_arch not in _MODELS:
+            return None
+        if is_hip():
+            if model_arch in _ROCM_UNSUPPORTED_MODELS:
+                raise ValueError(
+                    f"Model architecture {model_arch} is not supported by "
+                    "ROCm for now.")
+            if model_arch in _ROCM_PARTIALLY_SUPPORTED_MODELS:
+                logger.warning(
+                    f"Model architecture {model_arch} is partially supported "
+                    "by ROCm: " + _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch])
+        module_name, model_cls_name = _MODELS[model_arch]
+        module = importlib.import_module(
+            f"vllm.model_executor.models.{module_name}")
+        return getattr(module, model_cls_name, None)
+
+    @staticmethod
+    def get_supported_archs() -> List[str]:
+        return list(_MODELS.keys())

 __all__ = [
-    "AquilaForCausalLM",
-    "BaiChuanForCausalLM",
-    "BaichuanForCausalLM",
-    "BloomForCausalLM",
-    "ChatGLMForCausalLM",
-    "FalconForCausalLM",
-    "GPT2LMHeadModel",
-    "GPTBigCodeForCausalLM",
-    "GPTJForCausalLM",
-    "GPTNeoXForCausalLM",
-    "InternLMForCausalLM",
-    "LlamaForCausalLM",
-    "MPTForCausalLM",
-    "OPTForCausalLM",
-    "PhiForCausalLM",
-    "QWenLMHeadModel",
-    "MistralForCausalLM",
-    "MixtralForCausalLM",
-    "YiForCausalLM",
+    "ModelRegistry",
 ]
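To show the lazy-loading idea of the new ModelRegistry in isolation, here is a hedged, self-contained sketch that uses only standard-library modules; the names _LAZY_CLASSES and load_cls are illustrative and not part of vLLM. Classes are recorded as (module name, class name) pairs and the owning module is imported via importlib only when a class is first requested, so importing the package no longer pulls in every implementation up front.

import importlib
from typing import Optional, Type

# Illustrative registry: maps a public name to (module, class), mirroring _MODELS.
_LAZY_CLASSES = {
    "JSONDecoder": ("json", "JSONDecoder"),
    "OrderedDict": ("collections", "OrderedDict"),
}

def load_cls(name: str) -> Optional[Type]:
    """Import the owning module only when the class is actually requested."""
    if name not in _LAZY_CLASSES:
        return None
    module_name, cls_name = _LAZY_CLASSES[name]
    module = importlib.import_module(module_name)  # deferred import
    return getattr(module, cls_name, None)

print(load_cls("JSONDecoder"))  # <class 'json.decoder.JSONDecoder'>
print(load_cls("Missing"))      # None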
vllm/model_executor/models/mixtral.py
@@ -33,14 +33,15 @@ from transformers import MixtralConfig
 try:
     import megablocks.ops as ops
-except ImportError:
-    print("MegaBlocks not found. "
-          "Please install it by `pip install megablocks`.")
+except ImportError as e:
+    raise ImportError("MegaBlocks not found. "
+                      "Please install it by `pip install megablocks`.") from e
 try:
     import stk
-except ImportError:
-    print("STK not found: please see https://github.com/stanford-futuredata/stk")
+except ImportError as e:
+    raise ImportError("STK not found. "
+                      "Please install it by `pip install stanford-stk`.") from e

 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.attention import PagedAttention
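The mixtral.py change replaces a print-and-continue with a re-raised ImportError chained via `from e`, so a missing optional dependency fails immediately at import time with an install hint instead of surfacing later as a NameError. A minimal sketch of the same pattern follows; the package name some_optional_pkg is hypothetical.

try:
    import some_optional_pkg  # hypothetical optional dependency
except ImportError as e:
    # Chaining with `from e` keeps the original traceback attached,
    # and the install hint reaches the user right away.
    raise ImportError("some_optional_pkg not found. "
                      "Please install it by `pip install some_optional_pkg`.") from e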