Commit f51086de authored by zhuwenwen
Browse files

update utils.py

parent db94f061
......@@ -33,52 +33,45 @@ def set_default_torch_dtype(dtype: torch.dtype):
def is_transformers_impl_compatible(
arch: str,
module: Optional["transformers.PreTrainedModel"] = None) -> bool:
module: Optional[transformers.PreTrainedModel] = None) -> bool:
mod = module or getattr(transformers, arch, None)
if mod is None:
return False
return mod.is_backend_compatible()
if hasattr(mod, "supports_backend"):
return mod.is_backend_compatible()
else:
return mod._supports_flex_attn
def resolve_transformers_arch(model_config: ModelConfig,
architectures: list[str]):
def resolve_transformers_fallback(model_config: ModelConfig,
architectures: list[str]):
for i, arch in enumerate(architectures):
if arch == "TransformersForCausalLM":
if arch == "TransformersModel":
continue
auto_map: dict[str, str] = getattr(model_config.hf_config, "auto_map",
None) or dict()
# Make sure that config class is always initialized before model class,
# otherwise the model class won't be able to access the config class,
# the expected auto_map should have correct order like:
# "auto_map": {
# "AutoConfig": "<your-repo-name>--<config-name>",
# "AutoModel": "<your-repo-name>--<config-name>",
# "AutoModelFor<Task>": "<your-repo-name>--<config-name>",
# },
auto_modules = {
name: get_class_from_dynamic_module(module, model_config.model)
for name, module in sorted(auto_map.items(), key=lambda x: x[0])
}
custom_model_module = auto_modules.get("AutoModel")
custom_module = None
auto_map = getattr(model_config.hf_config, "auto_map", None)
if auto_map is not None and "AutoModel" in auto_map:
custom_module = get_class_from_dynamic_module(
model_config.hf_config.auto_map["AutoModel"],
model_config.model)
# TODO(Isotr0py): Further clean up these raises.
# perhaps handled them in _ModelRegistry._raise_for_unsupported?
if model_config.model_impl == ModelImpl.TRANSFORMERS:
if not is_transformers_impl_compatible(arch, custom_model_module):
if not is_transformers_impl_compatible(arch, custom_module):
raise ValueError(
f"The Transformers implementation of {arch} is not "
"compatible with vLLM.")
architectures[i] = "TransformersForCausalLM"
architectures[i] = "TransformersModel"
if model_config.model_impl == ModelImpl.AUTO:
if not is_transformers_impl_compatible(arch, custom_model_module):
if not is_transformers_impl_compatible(arch, custom_module):
raise ValueError(
f"{arch} has no vLLM implementation and the Transformers "
"implementation is not compatible with vLLM. Try setting "
"VLLM_USE_V1=0.")
"implementation is not compatible with vLLM.")
logger.warning(
"%s has no vLLM implementation, falling back to Transformers "
"implementation. Some features may not be supported and "
"performance may not be optimal.", arch)
architectures[i] = "TransformersForCausalLM"
architectures[i] = "TransformersModel"
return architectures
......@@ -92,7 +85,7 @@ def get_model_architecture(
'Qwen2MoeForCausalLM', 'ChatGLMModel', 'ChatGLMForConditionalGeneration',
'BaichuanForCausalLM', 'BloomForCausalLM', 'MedusaModel', 'MixtralForCausalLM',
'MLPSpeculatorPreTrainedModel', 'FalconForCausalLM', 'DeepseekV2ForCausalLM',
'DeepseekV3ForCausalLM', 'DeepSeekMTPModel']
'DeepseekV3ForCausalLM', 'DeepSeekMTPModel']
if any(arch in architectures for arch in support_nn_architectures):
if os.getenv('LLAMA_NN') != '0':
if (architectures == ['QWenLMHeadModel'] or architectures == ['ChatGLMModel'] ) and visions != []:
......@@ -119,7 +112,10 @@ def get_model_architecture(
else:
os.environ['AWQ_PAD'] = '0'
else:
os.environ['LLAMA_NN'] = '0'
if os.getenv('LLAMA_NN') == '1':
os.environ['LLAMA_NN'] = '1'
else:
os.environ['LLAMA_NN'] = '0'
os.environ['LM_NN'] = '0'
os.environ['GEMM_PAD'] = '0'
os.environ['FA_PAD'] = '0'
......@@ -141,7 +137,8 @@ def get_model_architecture(
for arch in architectures)
if (not is_vllm_supported
or model_config.model_impl == ModelImpl.TRANSFORMERS):
architectures = resolve_transformers_arch(model_config, architectures)
architectures = resolve_transformers_fallback(model_config,
architectures)
model_cls, arch = ModelRegistry.resolve_model_cls(architectures)
if model_config.task == "embed":
......@@ -205,4 +202,4 @@ def configure_quant_config(quant_config: QuantizationConfig,
logger.warning(
"The model class %s has not defined `packed_modules_mapping`, "
"this may lead to incorrect mapping of quantized or ignored "
"modules", model_class.__name__)
"modules", model_class.__name__)
\ No newline at end of file
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment