Commit db94f061 authored by zhuwenwen's avatar zhuwenwen
Browse files

fix llama and qwen layout

parent 14f46a65
......@@ -33,45 +33,52 @@ def set_default_torch_dtype(dtype: torch.dtype):
def is_transformers_impl_compatible(
arch: str,
module: Optional[transformers.PreTrainedModel] = None) -> bool:
module: Optional["transformers.PreTrainedModel"] = None) -> bool:
mod = module or getattr(transformers, arch, None)
if mod is None:
return False
if hasattr(mod, "supports_backend"):
return mod.is_backend_compatible()
else:
return mod._supports_flex_attn
return mod.is_backend_compatible()
def resolve_transformers_fallback(model_config: ModelConfig,
architectures: list[str]):
def resolve_transformers_arch(model_config: ModelConfig,
architectures: list[str]):
for i, arch in enumerate(architectures):
if arch == "TransformersModel":
if arch == "TransformersForCausalLM":
continue
custom_module = None
auto_map = getattr(model_config.hf_config, "auto_map", None)
if auto_map is not None and "AutoModel" in auto_map:
custom_module = get_class_from_dynamic_module(
model_config.hf_config.auto_map["AutoModel"],
model_config.model)
auto_map: dict[str, str] = getattr(model_config.hf_config, "auto_map",
None) or dict()
# Make sure that config class is always initialized before model class,
# otherwise the model class won't be able to access the config class,
# the expected auto_map should have correct order like:
# "auto_map": {
# "AutoConfig": "<your-repo-name>--<config-name>",
# "AutoModel": "<your-repo-name>--<config-name>",
# "AutoModelFor<Task>": "<your-repo-name>--<config-name>",
# },
auto_modules = {
name: get_class_from_dynamic_module(module, model_config.model)
for name, module in sorted(auto_map.items(), key=lambda x: x[0])
}
custom_model_module = auto_modules.get("AutoModel")
# TODO(Isotr0py): Further clean up these raises.
# perhaps handled them in _ModelRegistry._raise_for_unsupported?
if model_config.model_impl == ModelImpl.TRANSFORMERS:
if not is_transformers_impl_compatible(arch, custom_module):
if not is_transformers_impl_compatible(arch, custom_model_module):
raise ValueError(
f"The Transformers implementation of {arch} is not "
"compatible with vLLM.")
architectures[i] = "TransformersModel"
architectures[i] = "TransformersForCausalLM"
if model_config.model_impl == ModelImpl.AUTO:
if not is_transformers_impl_compatible(arch, custom_module):
if not is_transformers_impl_compatible(arch, custom_model_module):
raise ValueError(
f"{arch} has no vLLM implementation and the Transformers "
"implementation is not compatible with vLLM.")
"implementation is not compatible with vLLM. Try setting "
"VLLM_USE_V1=0.")
logger.warning(
"%s has no vLLM implementation, falling back to Transformers "
"implementation. Some features may not be supported and "
"performance may not be optimal.", arch)
architectures[i] = "TransformersModel"
architectures[i] = "TransformersForCausalLM"
return architectures
......@@ -85,7 +92,7 @@ def get_model_architecture(
'Qwen2MoeForCausalLM', 'ChatGLMModel', 'ChatGLMForConditionalGeneration',
'BaichuanForCausalLM', 'BloomForCausalLM', 'MedusaModel', 'MixtralForCausalLM',
'MLPSpeculatorPreTrainedModel', 'FalconForCausalLM', 'DeepseekV2ForCausalLM',
'DeepseekV3ForCausalLM', 'DeepSeekMTPModel']
'DeepseekV3ForCausalLM', 'DeepSeekMTPModel']
if any(arch in architectures for arch in support_nn_architectures):
if os.getenv('LLAMA_NN') != '0':
if (architectures == ['QWenLMHeadModel'] or architectures == ['ChatGLMModel'] ) and visions != []:
......@@ -112,10 +119,7 @@ def get_model_architecture(
else:
os.environ['AWQ_PAD'] = '0'
else:
if os.getenv('LLAMA_NN') == '1':
os.environ['LLAMA_NN'] = '1'
else:
os.environ['LLAMA_NN'] = '0'
os.environ['LLAMA_NN'] = '0'
os.environ['LM_NN'] = '0'
os.environ['GEMM_PAD'] = '0'
os.environ['FA_PAD'] = '0'
......@@ -137,8 +141,7 @@ def get_model_architecture(
for arch in architectures)
if (not is_vllm_supported
or model_config.model_impl == ModelImpl.TRANSFORMERS):
architectures = resolve_transformers_fallback(model_config,
architectures)
architectures = resolve_transformers_arch(model_config, architectures)
model_cls, arch = ModelRegistry.resolve_model_cls(architectures)
if model_config.task == "embed":
......
......@@ -407,6 +407,11 @@ def safetensors_weights_iterator(
hf_weights_files: List[str]
) -> Generator[Tuple[str, torch.Tensor], None, None]:
"""Iterate over the weights in the model safetensor files."""
total_count = 0
for st_file in hf_weights_files:
with safe_open(st_file, framework="pt") as f:
total_count += len(f.keys())
current_count = 0
enable_tqdm = not torch.distributed.is_initialized(
) or torch.distributed.get_rank() == 0
for st_file in tqdm(
......@@ -417,7 +422,10 @@ def safetensors_weights_iterator(
):
with safe_open(st_file, framework="pt") as f:
for name in f.keys(): # noqa: SIM118
current_count += 1
param = f.get_tensor(name)
param.current_count = current_count
param.total_count = total_count
yield name, param
......
......@@ -414,6 +414,8 @@ class LlamaModel(nn.Module):
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
current_count = loaded_weight.current_count
total_count = loaded_weight.total_count
if "rotary_emb.inv_freq" in name:
continue
if ("rotary_emb.cos_cached" in name
......@@ -466,7 +468,7 @@ class LlamaModel(nn.Module):
weight_loader(param, loaded_weight)
loaded_params.add(name)
if self.use_llama_nn and self.quant_method is None :
if self.use_llama_nn and self.quant_method is None and current_count==total_count:
lay_key_words = [
"self_attn.qkv_proj.weight",
"self_attn.o_proj.weight",
......@@ -479,7 +481,8 @@ class LlamaModel(nn.Module):
# qkv_words = "|".join(lay_qkv_words)
# for layername, weight in params_dict.items():
for layername in loaded_params:
# for layername in loaded_params:
for layername in params_dict.keys():
weight = params_dict[layername]
if "lm_head.weight" in layername and weight.shape[1] >= 4096:
lay_key_words.append("lm_head.weight")
......
......@@ -398,6 +398,8 @@ class Qwen2Model(nn.Module):
params_dict = dict(self.named_parameters(remove_duplicate=False))
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
current_count = loaded_weight.current_count
total_count = loaded_weight.total_count
if "rotary_emb.inv_freq" in name:
continue
if (self.quant_config is not None and
......@@ -440,7 +442,7 @@ class Qwen2Model(nn.Module):
weight_loader(param, loaded_weight)
loaded_params.add(name)
if self.use_llama_nn and self.quant_method is None:
if self.use_llama_nn and self.quant_method is None and current_count==total_count:
lay_key_words = [
"self_attn.qkv_proj.weight",
"self_attn.o_proj.weight",
......@@ -456,7 +458,8 @@ class Qwen2Model(nn.Module):
# qkv_bias_words = "|".join(lay_qkv_bias_words)
# for layername, weight in params_dict.items():
for layername in loaded_params:
# for layername in loaded_params:
for layername in params_dict.keys():
weight = params_dict[layername]
if "lm_head.weight" in layername and weight.shape[1] >= 3584:
lay_key_words.append("lm_head.weight")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment