Commit f51086de authored by zhuwenwen
Browse files

update utils.py

parent db94f061
...@@ -33,52 +33,45 @@ def set_default_torch_dtype(dtype: torch.dtype): ...@@ -33,52 +33,45 @@ def set_default_torch_dtype(dtype: torch.dtype):
def is_transformers_impl_compatible( def is_transformers_impl_compatible(
arch: str, arch: str,
module: Optional["transformers.PreTrainedModel"] = None) -> bool: module: Optional[transformers.PreTrainedModel] = None) -> bool:
mod = module or getattr(transformers, arch, None) mod = module or getattr(transformers, arch, None)
if mod is None: if mod is None:
return False return False
if hasattr(mod, "supports_backend"):
return mod.is_backend_compatible() return mod.is_backend_compatible()
else:
return mod._supports_flex_attn
def resolve_transformers_arch(model_config: ModelConfig, def resolve_transformers_fallback(model_config: ModelConfig,
architectures: list[str]): architectures: list[str]):
for i, arch in enumerate(architectures): for i, arch in enumerate(architectures):
if arch == "TransformersForCausalLM": if arch == "TransformersModel":
continue continue
auto_map: dict[str, str] = getattr(model_config.hf_config, "auto_map", custom_module = None
None) or dict() auto_map = getattr(model_config.hf_config, "auto_map", None)
# Make sure that config class is always initialized before model class, if auto_map is not None and "AutoModel" in auto_map:
# otherwise the model class won't be able to access the config class, custom_module = get_class_from_dynamic_module(
# the expected auto_map should have correct order like: model_config.hf_config.auto_map["AutoModel"],
# "auto_map": { model_config.model)
# "AutoConfig": "<your-repo-name>--<config-name>",
# "AutoModel": "<your-repo-name>--<config-name>",
# "AutoModelFor<Task>": "<your-repo-name>--<config-name>",
# },
auto_modules = {
name: get_class_from_dynamic_module(module, model_config.model)
for name, module in sorted(auto_map.items(), key=lambda x: x[0])
}
custom_model_module = auto_modules.get("AutoModel")
# TODO(Isotr0py): Further clean up these raises. # TODO(Isotr0py): Further clean up these raises.
# perhaps handled them in _ModelRegistry._raise_for_unsupported? # perhaps handled them in _ModelRegistry._raise_for_unsupported?
if model_config.model_impl == ModelImpl.TRANSFORMERS: if model_config.model_impl == ModelImpl.TRANSFORMERS:
if not is_transformers_impl_compatible(arch, custom_model_module): if not is_transformers_impl_compatible(arch, custom_module):
raise ValueError( raise ValueError(
f"The Transformers implementation of {arch} is not " f"The Transformers implementation of {arch} is not "
"compatible with vLLM.") "compatible with vLLM.")
architectures[i] = "TransformersForCausalLM" architectures[i] = "TransformersModel"
if model_config.model_impl == ModelImpl.AUTO: if model_config.model_impl == ModelImpl.AUTO:
if not is_transformers_impl_compatible(arch, custom_model_module): if not is_transformers_impl_compatible(arch, custom_module):
raise ValueError( raise ValueError(
f"{arch} has no vLLM implementation and the Transformers " f"{arch} has no vLLM implementation and the Transformers "
"implementation is not compatible with vLLM. Try setting " "implementation is not compatible with vLLM.")
"VLLM_USE_V1=0.")
logger.warning( logger.warning(
"%s has no vLLM implementation, falling back to Transformers " "%s has no vLLM implementation, falling back to Transformers "
"implementation. Some features may not be supported and " "implementation. Some features may not be supported and "
"performance may not be optimal.", arch) "performance may not be optimal.", arch)
architectures[i] = "TransformersForCausalLM" architectures[i] = "TransformersModel"
return architectures return architectures
...@@ -118,6 +111,9 @@ def get_model_architecture( ...@@ -118,6 +111,9 @@ def get_model_architecture(
os.environ['AWQ_PAD'] = '1' os.environ['AWQ_PAD'] = '1'
else: else:
os.environ['AWQ_PAD'] = '0' os.environ['AWQ_PAD'] = '0'
else:
if os.getenv('LLAMA_NN') == '1':
os.environ['LLAMA_NN'] = '1'
else: else:
os.environ['LLAMA_NN'] = '0' os.environ['LLAMA_NN'] = '0'
os.environ['LM_NN'] = '0' os.environ['LM_NN'] = '0'
...@@ -141,7 +137,8 @@ def get_model_architecture( ...@@ -141,7 +137,8 @@ def get_model_architecture(
for arch in architectures) for arch in architectures)
if (not is_vllm_supported if (not is_vllm_supported
or model_config.model_impl == ModelImpl.TRANSFORMERS): or model_config.model_impl == ModelImpl.TRANSFORMERS):
architectures = resolve_transformers_arch(model_config, architectures) architectures = resolve_transformers_fallback(model_config,
architectures)
model_cls, arch = ModelRegistry.resolve_model_cls(architectures) model_cls, arch = ModelRegistry.resolve_model_cls(architectures)
if model_config.task == "embed": if model_config.task == "embed":
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment