Commit 66b809cc authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.7.2' into v0.7.2-dev

parents 37b63c24 0408efc6
# SPDX-License-Identifier: Apache-2.0
from typing import Callable, List, Optional, Set
import torch
......
# SPDX-License-Identifier: Apache-2.0
from typing import Callable, List, Optional, Set
import torch
......
# SPDX-License-Identifier: Apache-2.0
from typing import Optional, Type
import torch
......
# SPDX-License-Identifier: Apache-2.0
import re
from typing import Iterable, Optional
from types import MappingProxyType
from typing import Iterable, List, Mapping, Optional
from compressed_tensors import CompressionFormat
from torch.nn import Module
from vllm.model_executor.layers.quantization.utils.quant_utils import (
FUSED_LAYER_NAME_MAPPING)
def is_activation_quantization_format(format: str) -> bool:
_ACTIVATION_QUANTIZATION_FORMATS = [
......@@ -17,8 +17,11 @@ def is_activation_quantization_format(format: str) -> bool:
return format in _ACTIVATION_QUANTIZATION_FORMATS
def should_ignore_layer(layer_name: Optional[str],
ignore: Iterable[str]) -> bool:
def should_ignore_layer(
layer_name: Optional[str],
ignore: Iterable[str] = tuple(),
fused_mapping: Mapping[str, List[str]] = MappingProxyType({})
) -> bool:
if layer_name is None:
return False
......@@ -30,8 +33,8 @@ def should_ignore_layer(layer_name: Optional[str],
# in the safetensors checkpoint. So, we convert the name
# from the fused version to unfused + check to make sure that
# each shard of the fused layer has the same scheme.
if proj_name in FUSED_LAYER_NAME_MAPPING and layer_name not in ignore:
shard_proj_names = FUSED_LAYER_NAME_MAPPING[proj_name]
if proj_name in fused_mapping and layer_name not in ignore:
shard_proj_names = fused_mapping[proj_name]
# Convert fused_name --> [shard_names]
shard_names = [
......@@ -77,55 +80,12 @@ def check_equal_or_regex_match(layer_name: str,
return False
def _handle_fused_layers(func):
"""
Decorator to handle fused layers by mapping vllm fused layer names
to their corresponding unfused layer names for quantization/pruning schemes.
"""
# fused_layer_name -> unfused_layer_name
fused_layer_map = {
"qkv_proj": "q_proj",
"gate_up_proj": "up_proj",
}
def fused_layer_handler(layer_name: Optional[str], module: Module,
targets: Iterable[str]) -> Optional[str]:
"""
Wrapper function specifically designed to support the
find_matched_target function.
It handles cases where the provided layer name corresponds to a
fused layer in vllm, mapping it to its equivalent unfused layer name
based on the predefined fused_layer_map. If the original layer name
raises a ValueError in the wrapped function, this handler
will attempt to resolve the issue by substituting with unfused
layer name.
:param layer_name: Name of the layer, which may be fused.
:param module: An instance of torch.nn.Module.
:param targets: A list of target names or patterns to match.
:return: The result of the wrapped find_matched_target function with
the resolved layer name.
:raises ValueError: If the layer name cannot be resolved to a
valid target.
"""
try:
return func(layer_name, module, targets)
except ValueError:
if layer_name is None:
layer_name = ""
parent_name, fused_proj_name = layer_name.rsplit(".", 1)
unfused_proj_name = fused_layer_map.get(fused_proj_name,
fused_proj_name)
new_layer_name = f"{parent_name}.{unfused_proj_name}"
return func(new_layer_name, module, targets)
return fused_layer_handler
@_handle_fused_layers
def find_matched_target(layer_name: Optional[str], module: Module,
targets: Iterable[str]) -> str:
def find_matched_target(
layer_name: Optional[str],
module: Module,
targets: Iterable[str],
fused_mapping: Mapping[str, List[str]] = MappingProxyType({})
) -> str:
"""
Helper function to look up which "target" in the compressed-tensors
config that a layer corresponds to.
......@@ -139,19 +99,25 @@ def find_matched_target(layer_name: Optional[str], module: Module,
First, we try to match the layer_name with a target
Second, we try to match the module's name with a target
Third, we try to map the layer_name to a list of fused module names.
*All* component module names must match in order for a match to be
successful. A successful match returns the first component target
:param layer_name: layer name
:param module: torch.nn.Module
:param targets: list of targets to match the layer against
:param fused_mapping: map from fused layer names to its components
:param fused_strategy: either "all" or "any". If using "all", fused
layers match if "all" of its components match
"""
if layer_name is None:
layer_name = ""
matched_target = (_find_first_match(layer_name, targets)
or _find_first_match(module.__class__.__name__, targets,
True)
or _match_fused_layer(layer_name, targets))
matched_target = (
_find_first_match(layer_name, targets)
or _find_first_match(module.__class__.__name__, targets, True)
or _match_fused_layer(layer_name, targets, fused_mapping))
if matched_target is None:
raise ValueError(
......@@ -203,11 +169,19 @@ def _is_equal_or_regex_match(value: str,
return False
def _match_fused_layer(layer_name: str,
target_layers: Iterable[str]) -> Optional[str]:
def _match_fused_layer(
layer_name: str, target_layers: Iterable[str],
fused_mapping: Mapping[str, List[str]]) -> Optional[str]:
"""
Match a fused layer name to its corresponding individual layer in
target_layers.
target_layers. Returns first value in fused_mapping which matches targets
Implements an "all" matching strategy where a fused layer matches iff
"all" of its components match
:param layer_name: layer name
:param target_layers: list of targets to match the layer against
:param fused_mapping: map from fused layer names to its components
Examples:
layer_name = "model.layers.0.self_attn.qkv_proj"
......@@ -215,27 +189,25 @@ def _match_fused_layer(layer_name: str,
"model.layers.0.self_attn.k_proj",
"model.layers.0.self_attn.v_proj"]
"""
# Split into parent path and layer type
# e.g., "model.layers.0.self_attn" and "qkv_proj"
parent_path = ".".join(layer_name.split(".")[:-1])
layer_type = layer_name.split(".")[-1]
if layer_type not in FUSED_LAYER_NAME_MAPPING:
# find layer_name in mapping
fused = next((key for key in fused_mapping if layer_name.endswith(key)),
None)
if fused is None:
return None
possible_layer_types = FUSED_LAYER_NAME_MAPPING[layer_type]
# Look for a target layer that:
# 1. Has the same parent path
# 2. Ends with one of the possible individual layer types
for target in target_layers:
is_same_parent = parent_path in target
is_matching_type = any(type_suffix in target
for type_suffix in possible_layer_types)
if is_same_parent and is_matching_type and all(
'.'.join([parent_path, type_suffix])
for type_suffix in possible_layer_types):
return target
# expand path of unfused components
unfused_paths = [
layer_name.replace(fused, unfused) for unfused in fused_mapping[fused]
]
return None
# for each unfused component, find a match in targets
unfused_matches: List[Optional[str]] = []
for unfused in unfused_paths:
for target in target_layers:
if _is_equal_or_regex_match(unfused, target):
unfused_matches.append(target)
break
else:
unfused_matches.append(None)
return unfused_matches[0] if all(unfused_matches) else None
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, List, Optional
import torch
......
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Callable, Dict, List, Optional
import torch
......
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, List, Optional
import torch
......
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Callable, Dict, List, Optional
import torch
......
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, List, Optional
import gguf
......
# SPDX-License-Identifier: Apache-2.0
import enum
from enum import Enum
from fractions import Fraction
......
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Callable, Dict, List, Optional, Set, Union
import torch
......
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, List, Optional
import torch
......
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, List, Optional
import torch
......
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, List, Optional
import torch
......
# SPDX-License-Identifier: Apache-2.0
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Callable, Optional, Tuple
......
# SPDX-License-Identifier: Apache-2.0
from typing import List, Optional, Type
import vllm.envs as envs
......
# SPDX-License-Identifier: Apache-2.0
from typing import Optional, Tuple
import torch
......
# SPDX-License-Identifier: Apache-2.0
from functools import partial
from typing import Optional, Tuple
......
# SPDX-License-Identifier: Apache-2.0
from typing import Optional, Tuple
import torch
......
# SPDX-License-Identifier: Apache-2.0
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Optional, Tuple
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment