Unverified commit dbe55885, authored by Robert Shaw and committed by GitHub
Browse files

[ Misc ] non-uniform quantization via `compressed-tensors` for `Llama` (#6515)

parent d4201e06
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Meta-Llama-3-8B-Instruct-nonuniform-test -b auto -l 1000 -f 5 -t 1
model_name: "nm-testing/Meta-Llama-3-8B-Instruct-nonuniform-test"
tasks:
- name: "gsm8k"
  metrics:
  - name: "exact_match,strict-match"
    value: 0.758
  - name: "exact_match,flexible-extract"
    value: 0.759
limit: 1000
num_fewshot: 5
...@@ -2,4 +2,5 @@ Meta-Llama-3-8B-Instruct.yaml ...@@ -2,4 +2,5 @@ Meta-Llama-3-8B-Instruct.yaml
Meta-Llama-3-8B-Instruct-FP8.yaml Meta-Llama-3-8B-Instruct-FP8.yaml
Meta-Llama-3-8B-Instruct-FP8-compressed-tensors.yaml Meta-Llama-3-8B-Instruct-FP8-compressed-tensors.yaml
Meta-Llama-3-8B-Instruct-INT8-compressed-tensors.yaml Meta-Llama-3-8B-Instruct-INT8-compressed-tensors.yaml
Meta-Llama-3-8B-Instruct-nonuniform-compressed-tensors.yaml
Qwen2-1.5B-Instruct-INT8-compressed-tensors.yaml Qwen2-1.5B-Instruct-INT8-compressed-tensors.yaml
...@@ -158,6 +158,7 @@ class FusedMoE(torch.nn.Module): ...@@ -158,6 +158,7 @@ class FusedMoE(torch.nn.Module):
topk_group: Optional[int] = None, topk_group: Optional[int] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
tp_size: Optional[int] = None, tp_size: Optional[int] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
......
...@@ -171,6 +171,8 @@ class ReplicatedLinear(LinearBase): ...@@ -171,6 +171,8 @@ class ReplicatedLinear(LinearBase):
skip_bias_add: If true, skip adding bias but instead return it. skip_bias_add: If true, skip adding bias but instead return it.
params_dtype: Data type for the parameters. params_dtype: Data type for the parameters.
quant_config: Quantization configure. quant_config: Quantization configure.
prefix: The name of the layer in the state dict, including all parents
(e.g. model.layers.0.qkv_proj)
""" """
def __init__(self, def __init__(self,
...@@ -179,15 +181,19 @@ class ReplicatedLinear(LinearBase): ...@@ -179,15 +181,19 @@ class ReplicatedLinear(LinearBase):
bias: bool = True, bias: bool = True,
skip_bias_add: bool = False, skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None, params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None): quant_config: Optional[QuantizationConfig] = None,
prefix: Optional[str] = None):
super().__init__(input_size, output_size, skip_bias_add, params_dtype, super().__init__(input_size, output_size, skip_bias_add, params_dtype,
quant_config) quant_config)
# All the linear layer supports quant method. # All the linear layer supports quant method.
assert self.quant_method is not None assert self.quant_method is not None
self.quant_method.create_weights(self, self.input_size, self.quant_method.create_weights(self,
[self.output_size], self.input_size, self.input_size, [self.output_size],
self.output_size, self.params_dtype) self.input_size,
self.output_size,
self.params_dtype,
prefix=prefix)
if bias: if bias:
self.bias = Parameter( self.bias = Parameter(
...@@ -239,6 +245,8 @@ class ColumnParallelLinear(LinearBase): ...@@ -239,6 +245,8 @@ class ColumnParallelLinear(LinearBase):
quant_config: Quantization configure. quant_config: Quantization configure.
output_sizes: list of output sizes packed into one output, like for QKV output_sizes: list of output sizes packed into one output, like for QKV
the list would be size 3. the list would be size 3.
prefix: The name of the layer in the state dict, including all parents
(e.g. model.layers.0.qkv_proj)
""" """
def __init__(self, def __init__(self,
...@@ -249,7 +257,8 @@ class ColumnParallelLinear(LinearBase): ...@@ -249,7 +257,8 @@ class ColumnParallelLinear(LinearBase):
skip_bias_add: bool = False, skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None, params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
output_sizes: Optional[List[int]] = None): output_sizes: Optional[List[int]] = None,
prefix: Optional[str] = None):
super().__init__(input_size, output_size, skip_bias_add, params_dtype, super().__init__(input_size, output_size, skip_bias_add, params_dtype,
quant_config) quant_config)
...@@ -276,7 +285,8 @@ class ColumnParallelLinear(LinearBase): ...@@ -276,7 +285,8 @@ class ColumnParallelLinear(LinearBase):
input_size=self.input_size, input_size=self.input_size,
output_size=self.output_size, output_size=self.output_size,
params_dtype=self.params_dtype, params_dtype=self.params_dtype,
weight_loader=self.weight_loader) weight_loader=self.weight_loader,
prefix=prefix)
if bias: if bias:
self.bias = Parameter( self.bias = Parameter(
torch.empty(self.output_size_per_partition, torch.empty(self.output_size_per_partition,
...@@ -348,6 +358,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -348,6 +358,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
skip adding bias but instead return it. skip adding bias but instead return it.
params_dtype: Data type for the parameters. params_dtype: Data type for the parameters.
quant_config: Quantization configure. quant_config: Quantization configure.
prefix: The name of the layer in the state dict, including all parents
(e.g. model.layers.0.qkv_proj)
""" """
def __init__(self, def __init__(self,
...@@ -357,7 +369,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -357,7 +369,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
gather_output: bool = False, gather_output: bool = False,
skip_bias_add: bool = False, skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None, params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None): quant_config: Optional[QuantizationConfig] = None,
prefix: Optional[str] = None):
self.output_sizes = output_sizes self.output_sizes = output_sizes
tp_size = get_tensor_model_parallel_world_size() tp_size = get_tensor_model_parallel_world_size()
assert all(output_size % tp_size == 0 for output_size in output_sizes) assert all(output_size % tp_size == 0 for output_size in output_sizes)
...@@ -367,7 +380,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -367,7 +380,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
gather_output=gather_output, gather_output=gather_output,
skip_bias_add=skip_bias_add, skip_bias_add=skip_bias_add,
params_dtype=params_dtype, params_dtype=params_dtype,
quant_config=quant_config) quant_config=quant_config,
prefix=prefix)
def weight_loader(self, def weight_loader(self,
param: Parameter, param: Parameter,
...@@ -487,6 +501,8 @@ class QKVParallelLinear(ColumnParallelLinear): ...@@ -487,6 +501,8 @@ class QKVParallelLinear(ColumnParallelLinear):
skip adding bias but instead return it. skip adding bias but instead return it.
params_dtype: Data type for the parameters. params_dtype: Data type for the parameters.
quant_config: Quantization configure. quant_config: Quantization configure.
prefix: The name of the layer in the state dict, including all parents
(e.g. model.layers.0.qkv_proj)
""" """
def __init__(self, def __init__(self,
...@@ -497,7 +513,8 @@ class QKVParallelLinear(ColumnParallelLinear): ...@@ -497,7 +513,8 @@ class QKVParallelLinear(ColumnParallelLinear):
bias: bool = True, bias: bool = True,
skip_bias_add: bool = False, skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None, params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None): quant_config: Optional[QuantizationConfig] = None,
prefix: Optional[str] = None):
self.hidden_size = hidden_size self.hidden_size = hidden_size
self.head_size = head_size self.head_size = head_size
self.total_num_heads = total_num_heads self.total_num_heads = total_num_heads
...@@ -529,7 +546,8 @@ class QKVParallelLinear(ColumnParallelLinear): ...@@ -529,7 +546,8 @@ class QKVParallelLinear(ColumnParallelLinear):
gather_output=False, gather_output=False,
skip_bias_add=skip_bias_add, skip_bias_add=skip_bias_add,
params_dtype=params_dtype, params_dtype=params_dtype,
quant_config=quant_config) quant_config=quant_config,
prefix=prefix)
def weight_loader(self, def weight_loader(self,
param: Parameter, param: Parameter,
...@@ -688,7 +706,8 @@ class RowParallelLinear(LinearBase): ...@@ -688,7 +706,8 @@ class RowParallelLinear(LinearBase):
skip_bias_add: bool = False, skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None, params_dtype: Optional[torch.dtype] = None,
reduce_results: bool = True, reduce_results: bool = True,
quant_config: Optional[QuantizationConfig] = None): quant_config: Optional[QuantizationConfig] = None,
prefix: Optional[str] = None):
super().__init__(input_size, output_size, skip_bias_add, params_dtype, super().__init__(input_size, output_size, skip_bias_add, params_dtype,
quant_config) quant_config)
...@@ -706,7 +725,8 @@ class RowParallelLinear(LinearBase): ...@@ -706,7 +725,8 @@ class RowParallelLinear(LinearBase):
input_size=self.input_size, input_size=self.input_size,
output_size=self.output_size, output_size=self.output_size,
params_dtype=self.params_dtype, params_dtype=self.params_dtype,
weight_loader=self.weight_loader) weight_loader=self.weight_loader,
prefix=prefix)
if not reduce_results and (bias and not skip_bias_add): if not reduce_results and (bias and not skip_bias_add):
raise ValueError("When not reduce the results, adding bias to the " raise ValueError("When not reduce the results, adding bias to the "
"results can lead to incorrect results") "results can lead to incorrect results")
......
...@@ -8,23 +8,25 @@ from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501 ...@@ -8,23 +8,25 @@ from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
W4A16SPARSE24_SUPPORTED_BITS, WNA16_SUPPORTED_BITS, W4A16SPARSE24_SUPPORTED_BITS, WNA16_SUPPORTED_BITS,
CompressedTensorsScheme, CompressedTensorsW4A16Sparse24, CompressedTensorsScheme, CompressedTensorsUnquantized,
CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8, CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8Fp8,
CompressedTensorsWNA16) CompressedTensorsW8A8Int8, CompressedTensorsWNA16)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
CompressionFormat, QuantizationArgs, QuantizationStrategy, CompressionFormat, QuantizationArgs, QuantizationStrategy,
QuantizationType, find_first_name_or_class_match, QuantizationType, find_matched_target, is_activation_quantization_format,
is_activation_quantization_format) should_ignore_layer)
from vllm.platforms import current_platform from vllm.platforms import current_platform
class CompressedTensorsConfig(QuantizationConfig): class CompressedTensorsConfig(QuantizationConfig):
def __init__(self, layer_quant_details: Dict[str, Any], ignore: List[str], def __init__(self, target_scheme_map: Dict[str, Any], ignore: List[str],
quant_format: str): quant_format: str):
self.ignore = ignore self.ignore = ignore
self.layer_quant_details = layer_quant_details
self.quant_format = quant_format self.quant_format = quant_format
# Map from [target -> scheme]
self.target_scheme_map = target_scheme_map
def get_linear_method(self) -> "CompressedTensorsLinearMethod": def get_linear_method(self) -> "CompressedTensorsLinearMethod":
return CompressedTensorsLinearMethod(self) return CompressedTensorsLinearMethod(self)
...@@ -51,7 +53,7 @@ class CompressedTensorsConfig(QuantizationConfig): ...@@ -51,7 +53,7 @@ class CompressedTensorsConfig(QuantizationConfig):
@classmethod @classmethod
def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig": def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig":
layer_quant_details: Dict[str, Any] = dict() target_scheme_map: Dict[str, Any] = dict()
ignore: List[str] = config.get("ignore", None) ignore: List[str] = config.get("ignore", None)
quant_format: str = config.get("format", None) quant_format: str = config.get("format", None)
...@@ -63,21 +65,21 @@ class CompressedTensorsConfig(QuantizationConfig): ...@@ -63,21 +65,21 @@ class CompressedTensorsConfig(QuantizationConfig):
# details follow the structure defined by the QuantizationArgs # details follow the structure defined by the QuantizationArgs
# pydantic model, which is used to verify the structure of the # pydantic model, which is used to verify the structure of the
# quant_config and also store the details for later use. # quant_config and also store the details for later use.
for key, quant_config in config["config_groups"].items(): for _, quant_config in config["config_groups"].items():
targets = quant_config.get("targets") targets = quant_config.get("targets")
for target in targets: for target in targets:
layer_quant_details[target] = {} target_scheme_map[target] = {}
layer_quant_details[target][ target_scheme_map[target][
"weights"] = QuantizationArgs.parse_obj( "weights"] = QuantizationArgs.parse_obj(
quant_config.get("weights")) quant_config.get("weights"))
try: try:
layer_quant_details[target][ target_scheme_map[target][
"input_activations"] = QuantizationArgs.parse_obj( "input_activations"] = QuantizationArgs.parse_obj(
quant_config.get("input_activations")) quant_config.get("input_activations"))
except Exception: except Exception:
layer_quant_details[target]["input_activations"] = None target_scheme_map[target]["input_activations"] = None
return cls(layer_quant_details=layer_quant_details, return cls(target_scheme_map=target_scheme_map,
ignore=ignore, ignore=ignore,
quant_format=quant_format) quant_format=quant_format)
...@@ -167,7 +169,8 @@ class CompressedTensorsConfig(QuantizationConfig): ...@@ -167,7 +169,8 @@ class CompressedTensorsConfig(QuantizationConfig):
return (is_channel_group and input_quant_none and is_symmetric return (is_channel_group and input_quant_none and is_symmetric
and is_static) and is_static)
def _get_schema(self, weight_quant: BaseModel, def _get_scheme_from_parts(
self, weight_quant: BaseModel,
input_quant: BaseModel) -> "CompressedTensorsScheme": input_quant: BaseModel) -> "CompressedTensorsScheme":
# Detect If Mixed Precision # Detect If Mixed Precision
...@@ -205,26 +208,47 @@ class CompressedTensorsConfig(QuantizationConfig): ...@@ -205,26 +208,47 @@ class CompressedTensorsConfig(QuantizationConfig):
raise NotImplementedError( raise NotImplementedError(
"No compressed-tensors compatible scheme was found.") "No compressed-tensors compatible scheme was found.")
def get_scheme(self, layer: torch.nn.Module) -> "CompressedTensorsScheme": def get_scheme(
self,
layer: torch.nn.Module,
layer_name: Optional[str] = None) -> "CompressedTensorsScheme":
"""
compressed-tensors supports non uniform in the following way:
layer_type_name = find_first_name_or_class_match( ignore: List of layer_names or nn.Module names to be ignored.
name="", targets of config_groups: There can be N config_groups which each
module=layer, have a quantization scheme. Each config_group has a list of targets
targets=self.layer_quant_details.keys(), which can be a full layer_name, a regex for a layer_name, or
check_contains=True) an nn.Module name.
if layer_type_name is None: We first check whether a layer is in the ignore group and use
raise ValueError(f"Could not matching target for layer {layer}") CompressedTensorsUnquantized (i.e. fp16/bf16) scheme for the layer
layer_quant_details: Dict[str, Any] = self.layer_quant_details.get( We then detect whether a layer_name is found in any target and
layer_type_name, None) use the quantization scheme corresponding to the matched target
if layer_quant_details is None: to select the CompressedTensorsScheme used for inference.
raise ValueError( """
f"Could not find quantization details for {layer}.")
scheme = self._get_schema( # Check if the layer is skipped for quantization.
weight_quant=layer_quant_details["weights"], # TODO (@robertgshaw2): support module names
input_quant=layer_quant_details["input_activations"]) if should_ignore_layer(layer_name, ignore=self.ignore):
return CompressedTensorsUnquantized()
# Find the "target" in the compressed-tensors config
# that our layer conforms to.
# TODO (@robertgshaw): add compressed-tensors as dep
# so we do not have to re-write these functions
matched_target = find_matched_target(
layer_name=layer_name,
module=layer,
targets=self.target_scheme_map.keys())
# Find the quant_scheme
scheme = self.target_scheme_map[matched_target]
return self._get_scheme_from_parts(
weight_quant=scheme["weights"],
input_quant=scheme["input_activations"])
# Raise error if device does not support the scheme # Raise error if device does not support the scheme
# (e.g. fp8 needs ada lovelace) # (e.g. fp8 needs ada lovelace)
...@@ -250,11 +274,11 @@ class CompressedTensorsLinearMethod(LinearMethodBase): ...@@ -250,11 +274,11 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
Use the CompressedTensorsScheme associated with each layer to create Use the CompressedTensorsScheme associated with each layer to create
the necessary parameters for the layer. See LinearMethodBase for param the necessary parameters for the layer. See LinearMethodBase for param
details details
""" """
weight_loader = extra_weight_attrs.get("weight_loader") weight_loader = extra_weight_attrs.get("weight_loader")
layer_name = extra_weight_attrs.get("prefix")
scheme = self.quantization_config.get_scheme(layer=layer) scheme = self.quantization_config.get_scheme(layer, layer_name)
scheme.create_weights( scheme.create_weights(
layer=layer, layer=layer,
input_size=input_size, input_size=input_size,
......
...@@ -33,7 +33,6 @@ class CompressedTensorsUnquantized(CompressedTensorsScheme): ...@@ -33,7 +33,6 @@ class CompressedTensorsUnquantized(CompressedTensorsScheme):
weight = Parameter(torch.empty(sum(output_partition_sizes), weight = Parameter(torch.empty(sum(output_partition_sizes),
input_size_per_partition, input_size_per_partition,
device="cuda",
dtype=params_dtype), dtype=params_dtype),
requires_grad=False) requires_grad=False)
......
...@@ -86,25 +86,106 @@ def is_activation_quantization_format(format: str) -> bool: ...@@ -86,25 +86,106 @@ def is_activation_quantization_format(format: str) -> bool:
return format in _ACTIVATION_QUANTIZATION_FORMATS return format in _ACTIVATION_QUANTIZATION_FORMATS
def find_first_name_or_class_match( # fused_name: List[shard_name]
name: str, _FUSED_LAYER_NAME_MAPPING = {
module: Module, "qkv_proj": ["q_proj", "k_proj", "v_proj"],
targets: Iterable[str], "gate_up_proj": ["gate_proj", "up_proj"]
check_contains: bool = False) -> Optional[str]: }
def should_ignore_layer(layer_name: Optional[str],
                        ignore: Iterable[str]) -> bool:
    """
    Decide whether a layer is excluded from quantization.

    Unfused layers (e.g. down_proj, o_proj) are looked up in ``ignore``
    under their own name. Fused layers listed in
    ``_FUSED_LAYER_NAME_MAPPING`` (e.g. qkv_proj, gate_up_proj) are looked
    up once per shard name (q_proj / k_proj / v_proj, ...), and every shard
    must give the same answer.

    :param layer_name: full layer name, e.g.
        model.layers.0.self_attn.qkv_proj. ``None`` means the name is
        unknown, in which case the layer is never ignored.
    :param ignore: targets to ignore; each one is matched through
        ``check_equal_or_regex_match``. ``None`` is accepted and treated
        as an empty collection.
    :return: True if the layer (or all of its shards) matches ``ignore``.
    :raises ValueError: if only some shards of a fused layer are ignored.
    """
    if layer_name is None:
        return False

    # ``ignore`` is scanned once per shard below, so a one-shot iterator
    # would be exhausted after the first shard. Materialize it once, and
    # treat a missing ignore list (None) as "ignore nothing".
    ignore_targets = tuple(ignore) if ignore is not None else ()

    # layer_name = model.layers.0.self_attn.qkv_proj
    # parent     = model.layers.0.self_attn
    # proj_name  = qkv_proj
    parent, sep, proj_name = layer_name.rpartition(".")

    shard_proj_names = _FUSED_LAYER_NAME_MAPPING.get(proj_name)

    # Unfused layers like down_proj and o_proj already carry the name
    # used in the ignore list.
    if shard_proj_names is None:
        return check_equal_or_regex_match(layer_name=layer_name,
                                          targets=ignore_targets)

    # Fused layers are looked up under their unfused shard names instead.
    # Only the trailing component is swapped, so a parent module whose name
    # happens to contain the projection name is left untouched.
    shard_decisions = {
        check_equal_or_regex_match(
            layer_name=f"{parent}{sep}{shard_proj_name}",
            targets=ignore_targets)
        for shard_proj_name in shard_proj_names
    }

    # More than one distinct answer means the shards disagree.
    if len(shard_decisions) > 1:
        raise ValueError(f"Found a different quantization schemes for "
                         f"{shard_proj_names} in {layer_name}. vLLM "
                         "requires all to use the same scheme.")

    # Every fused entry maps to at least one shard, so exactly one
    # decision is left here.
    return shard_decisions.pop()
def check_equal_or_regex_match(layer_name: str,
                               targets: Iterable[str]) -> bool:
    """
    Report whether layer_name matches at least one entry of targets.

    Each target is compared through _is_equal_or_regex_match: a target
    starting with 're:' is treated as a regex, any other target must be
    exactly equal to layer_name. Returns False when targets is empty.
    """
    return any(
        _is_equal_or_regex_match(layer_name, target) for target in targets)
def find_matched_target(layer_name: Optional[str], module: Module,
targets: Iterable[str]) -> str:
""" """
Helper function to map the quantization details listed in the config Helper function to look up which "target" in the compressed-tensors
for a given list of targets against each model layer. First uses the config that a layer corresponds to.
layer name to try and find a match. If no name match is found, uses
the layer class name. Returns None otherwise. Recall that a compressed-tensors config has a concept of
config_groups, where each layer can be quantized with a different
scheme.
targets in each config_group will be a list of either layer names
(or regexes corresponding to layer names) or names of torch Modules.
:param name: layer name First, we try to match the layer_name with a target
Second, we try to match the module's name with a target
:param layer_name: layer name
:param module: torch.nn.Module :param module: torch.nn.Module
:param targets: list of targets to match the layer against :param targets: list of targets to match the layer against
:param check_contains: whether or not to do a substring match
""" """
return _find_first_match(name, targets) or _find_first_match( if layer_name is None:
module.__class__.__name__, targets, check_contains) layer_name = ""
matched_target = (_find_first_match(layer_name, targets)
or _find_first_match(module.__class__.__name__, targets,
True))
if matched_target is None:
raise ValueError(f"Unable to find matching target for {module} in the "
"compressed-tensors config.")
return matched_target
def _find_first_match(value: str, def _find_first_match(value: str,
...@@ -121,13 +202,29 @@ def _find_first_match(value: str, ...@@ -121,13 +202,29 @@ def _find_first_match(value: str,
""" """
for target in targets: for target in targets:
if _is_equal_or_regex_match(value,
target,
check_contains=check_contains):
return target
return None
def _is_equal_or_regex_match(value: str,
target: str,
check_contains: bool = False) -> bool:
"""
Checks whether a value is exactly equal or a regex match for target
if target starts with 're:'. If check_contains is set to True,
additionally checks if the target string is contained within the value.
"""
if target.startswith("re:"): if target.startswith("re:"):
pattern = target[3:] pattern = target[3:]
if re.match(pattern, value): if re.match(pattern, value):
return target return True
elif check_contains: elif check_contains:
if target.lower() in value.lower(): if target.lower() in value.lower():
return target return True
elif target == value: elif target == value:
return target return True
return None return False
...@@ -51,6 +51,7 @@ class GPT2Attention(nn.Module): ...@@ -51,6 +51,7 @@ class GPT2Attention(nn.Module):
config: GPT2Config, config: GPT2Config,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
...@@ -68,12 +69,14 @@ class GPT2Attention(nn.Module): ...@@ -68,12 +69,14 @@ class GPT2Attention(nn.Module):
total_num_heads, total_num_heads,
bias=True, bias=True,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.c_attn",
) )
self.c_proj = RowParallelLinear( self.c_proj = RowParallelLinear(
self.hidden_size, self.hidden_size,
self.hidden_size, self.hidden_size,
bias=True, bias=True,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.c_proj",
) )
self.attn = Attention(self.num_heads, self.attn = Attention(self.num_heads,
self.head_dim, self.head_dim,
...@@ -101,6 +104,7 @@ class GPT2MLP(nn.Module): ...@@ -101,6 +104,7 @@ class GPT2MLP(nn.Module):
intermediate_size: int, intermediate_size: int,
config: GPT2Config, config: GPT2Config,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
hidden_size = config.hidden_size hidden_size = config.hidden_size
...@@ -109,12 +113,14 @@ class GPT2MLP(nn.Module): ...@@ -109,12 +113,14 @@ class GPT2MLP(nn.Module):
intermediate_size, intermediate_size,
bias=True, bias=True,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.c_fc",
) )
self.c_proj = RowParallelLinear( self.c_proj = RowParallelLinear(
intermediate_size, intermediate_size,
hidden_size, hidden_size,
bias=True, bias=True,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.c_proj",
) )
self.act = get_act_fn(config.activation_function, quant_config, self.act = get_act_fn(config.activation_function, quant_config,
intermediate_size) intermediate_size)
...@@ -133,6 +139,7 @@ class GPT2Block(nn.Module): ...@@ -133,6 +139,7 @@ class GPT2Block(nn.Module):
config: GPT2Config, config: GPT2Config,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
hidden_size = config.hidden_size hidden_size = config.hidden_size
...@@ -140,9 +147,15 @@ class GPT2Block(nn.Module): ...@@ -140,9 +147,15 @@ class GPT2Block(nn.Module):
hidden_size) hidden_size)
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.attn = GPT2Attention(config, cache_config, quant_config) self.attn = GPT2Attention(config,
cache_config,
quant_config,
prefix=f"{prefix}.attn")
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.mlp = GPT2MLP(inner_dim, config, quant_config) self.mlp = GPT2MLP(inner_dim,
config,
quant_config,
prefix=f"{prefix}.mlp")
def forward( def forward(
self, self,
...@@ -175,6 +188,7 @@ class GPT2Model(nn.Module): ...@@ -175,6 +188,7 @@ class GPT2Model(nn.Module):
config: GPT2Config, config: GPT2Config,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -186,7 +200,9 @@ class GPT2Model(nn.Module): ...@@ -186,7 +200,9 @@ class GPT2Model(nn.Module):
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
self.start_layer, self.end_layer, self.h = make_layers( self.start_layer, self.end_layer, self.h = make_layers(
config.num_hidden_layers, config.num_hidden_layers,
lambda: GPT2Block(config, cache_config, quant_config)) lambda prefix: GPT2Block(
config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.h")
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
def forward( def forward(
...@@ -229,7 +245,10 @@ class GPT2LMHeadModel(nn.Module): ...@@ -229,7 +245,10 @@ class GPT2LMHeadModel(nn.Module):
super().__init__() super().__init__()
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.transformer = GPT2Model(config, cache_config, quant_config) self.transformer = GPT2Model(config,
cache_config,
quant_config,
prefix="transformer")
self.lm_head = self.transformer.wte self.lm_head = self.transformer.wte
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
......
...@@ -62,17 +62,20 @@ class LlamaMLP(nn.Module): ...@@ -62,17 +62,20 @@ class LlamaMLP(nn.Module):
hidden_act: str, hidden_act: str,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
bias: bool = False, bias: bool = False,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.gate_up_proj = MergedColumnParallelLinear( self.gate_up_proj = MergedColumnParallelLinear(
input_size=hidden_size, input_size=hidden_size,
output_sizes=[intermediate_size] * 2, output_sizes=[intermediate_size] * 2,
bias=bias, bias=bias,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj")
self.down_proj = RowParallelLinear(input_size=intermediate_size, self.down_proj = RowParallelLinear(input_size=intermediate_size,
output_size=hidden_size, output_size=hidden_size,
bias=bias, bias=bias,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.down_proj")
if hidden_act != "silu": if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. " raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.") "Only silu is supported for now.")
...@@ -99,6 +102,7 @@ class LlamaAttention(nn.Module): ...@@ -99,6 +102,7 @@ class LlamaAttention(nn.Module):
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
bias: bool = False, bias: bool = False,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.hidden_size = hidden_size self.hidden_size = hidden_size
...@@ -132,12 +136,14 @@ class LlamaAttention(nn.Module): ...@@ -132,12 +136,14 @@ class LlamaAttention(nn.Module):
total_num_kv_heads=self.total_num_kv_heads, total_num_kv_heads=self.total_num_kv_heads,
bias=bias, bias=bias,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
) )
self.o_proj = RowParallelLinear( self.o_proj = RowParallelLinear(
input_size=self.total_num_heads * self.head_dim, input_size=self.total_num_heads * self.head_dim,
output_size=hidden_size, output_size=hidden_size,
bias=bias, bias=bias,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.o_proj",
) )
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
...@@ -176,6 +182,7 @@ class LlamaDecoderLayer(nn.Module): ...@@ -176,6 +182,7 @@ class LlamaDecoderLayer(nn.Module):
config: LlamaConfig, config: LlamaConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
...@@ -203,6 +210,7 @@ class LlamaDecoderLayer(nn.Module): ...@@ -203,6 +210,7 @@ class LlamaDecoderLayer(nn.Module):
quant_config=quant_config, quant_config=quant_config,
bias=attention_bias, bias=attention_bias,
cache_config=cache_config, cache_config=cache_config,
prefix=f"{prefix}.self_attn",
) )
self.mlp = LlamaMLP( self.mlp = LlamaMLP(
hidden_size=self.hidden_size, hidden_size=self.hidden_size,
...@@ -210,6 +218,7 @@ class LlamaDecoderLayer(nn.Module): ...@@ -210,6 +218,7 @@ class LlamaDecoderLayer(nn.Module):
hidden_act=config.hidden_act, hidden_act=config.hidden_act,
quant_config=quant_config, quant_config=quant_config,
bias=getattr(config, "mlp_bias", False), bias=getattr(config, "mlp_bias", False),
prefix=f"{prefix}.mlp",
) )
self.input_layernorm = RMSNorm(config.hidden_size, self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps) eps=config.rms_norm_eps)
...@@ -253,6 +262,7 @@ class LlamaModel(nn.Module): ...@@ -253,6 +262,7 @@ class LlamaModel(nn.Module):
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None, lora_config: Optional[LoRAConfig] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -272,9 +282,11 @@ class LlamaModel(nn.Module): ...@@ -272,9 +282,11 @@ class LlamaModel(nn.Module):
self.embed_tokens = PPMissingLayer() self.embed_tokens = PPMissingLayer()
self.start_layer, self.end_layer, self.layers = make_layers( self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers, config.num_hidden_layers,
lambda: LlamaDecoderLayer(config=config, lambda prefix: LlamaDecoderLayer(config=config,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config)) quant_config=quant_config,
prefix=prefix),
prefix=f"{prefix}.layers")
if get_pp_group().is_last_rank: if get_pp_group().is_last_rank:
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
else: else:
...@@ -370,7 +382,8 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA): ...@@ -370,7 +382,8 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
self.model = LlamaModel(config, self.model = LlamaModel(config,
cache_config, cache_config,
quant_config, quant_config,
lora_config=lora_config) lora_config=lora_config,
prefix="model")
if get_pp_group().is_last_rank: if get_pp_group().is_last_rank:
self.unpadded_vocab_size = config.vocab_size self.unpadded_vocab_size = config.vocab_size
if lora_config: if lora_config:
......
...@@ -67,7 +67,8 @@ class MixtralMoE(nn.Module): ...@@ -67,7 +67,8 @@ class MixtralMoE(nn.Module):
intermediate_size: int, intermediate_size: int,
params_dtype: Optional[torch.dtype] = None, params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
tp_size: Optional[int] = None): tp_size: Optional[int] = None,
prefix: str = ""):
super().__init__() super().__init__()
self.hidden_size = hidden_size self.hidden_size = hidden_size
...@@ -76,7 +77,8 @@ class MixtralMoE(nn.Module): ...@@ -76,7 +77,8 @@ class MixtralMoE(nn.Module):
num_experts, num_experts,
bias=False, bias=False,
params_dtype=params_dtype, params_dtype=params_dtype,
quant_config=None) quant_config=None,
prefix=f"{prefix}.gate")
self.experts = FusedMoE(num_experts=num_experts, self.experts = FusedMoE(num_experts=num_experts,
top_k=top_k, top_k=top_k,
...@@ -86,7 +88,8 @@ class MixtralMoE(nn.Module): ...@@ -86,7 +88,8 @@ class MixtralMoE(nn.Module):
reduce_results=True, reduce_results=True,
renormalize=True, renormalize=True,
quant_config=quant_config, quant_config=quant_config,
tp_size=tp_size) tp_size=tp_size,
prefix=f"{prefix}.experts")
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# NOTE: hidden_states can have either 1D or 2D shape. # NOTE: hidden_states can have either 1D or 2D shape.
...@@ -109,6 +112,7 @@ class MixtralAttention(nn.Module): ...@@ -109,6 +112,7 @@ class MixtralAttention(nn.Module):
rope_theta: float = 10000, rope_theta: float = 10000,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.hidden_size = hidden_size self.hidden_size = hidden_size
...@@ -139,12 +143,14 @@ class MixtralAttention(nn.Module): ...@@ -139,12 +143,14 @@ class MixtralAttention(nn.Module):
self.total_num_kv_heads, self.total_num_kv_heads,
bias=False, bias=False,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
) )
self.o_proj = RowParallelLinear( self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim, self.total_num_heads * self.head_dim,
hidden_size, hidden_size,
bias=False, bias=False,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.o_proj",
) )
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
...@@ -182,6 +188,7 @@ class MixtralDecoderLayer(nn.Module): ...@@ -182,6 +188,7 @@ class MixtralDecoderLayer(nn.Module):
config: MixtralConfig, config: MixtralConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
...@@ -194,13 +201,15 @@ class MixtralDecoderLayer(nn.Module): ...@@ -194,13 +201,15 @@ class MixtralDecoderLayer(nn.Module):
num_kv_heads=config.num_key_value_heads, num_kv_heads=config.num_key_value_heads,
rope_theta=rope_theta, rope_theta=rope_theta,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.self_attn")
self.block_sparse_moe = MixtralMoE( self.block_sparse_moe = MixtralMoE(
num_experts=config.num_local_experts, num_experts=config.num_local_experts,
top_k=config.num_experts_per_tok, top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size, hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size, intermediate_size=config.intermediate_size,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.block_sparse_moe")
self.input_layernorm = RMSNorm(config.hidden_size, self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps) eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size, self.post_attention_layernorm = RMSNorm(config.hidden_size,
...@@ -243,6 +252,7 @@ class MixtralModel(nn.Module): ...@@ -243,6 +252,7 @@ class MixtralModel(nn.Module):
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None, lora_config: Optional[LoRAConfig] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.padding_idx = config.pad_token_id self.padding_idx = config.pad_token_id
...@@ -258,8 +268,11 @@ class MixtralModel(nn.Module): ...@@ -258,8 +268,11 @@ class MixtralModel(nn.Module):
) )
self.start_layer, self.end_layer, self.layers = make_layers( self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers, lambda: MixtralDecoderLayer( config.num_hidden_layers,
config, cache_config, quant_config=quant_config)) lambda prefix: MixtralDecoderLayer(
config, cache_config, quant_config=quant_config, prefix=prefix
),
prefix=f"{prefix}.layers")
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
...@@ -331,7 +344,8 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA): ...@@ -331,7 +344,8 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
self.model = MixtralModel(config, self.model = MixtralModel(config,
cache_config, cache_config,
quant_config, quant_config,
lora_config=lora_config) lora_config=lora_config,
prefix="model")
self.unpadded_vocab_size = config.vocab_size self.unpadded_vocab_size = config.vocab_size
if lora_config: if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
......
from typing import Callable, Dict, List, Tuple from typing import Dict, List, Protocol, Tuple
import torch import torch
from torch.func import functional_call from torch.func import functional_call
...@@ -45,6 +45,15 @@ def merge_vision_embeddings(input_ids: torch.Tensor, ...@@ -45,6 +45,15 @@ def merge_vision_embeddings(input_ids: torch.Tensor,
return inputs_embeds return inputs_embeds
class LayerFn(Protocol):
    """Structural type for the layer factory handed to ``make_layers``.

    Any callable that accepts a ``prefix`` keyword and returns a
    ``torch.nn.Module`` satisfies this protocol. ``make_layers`` invokes it
    as ``layer_fn(prefix=f"{prefix}.{idx}")``, so ``prefix`` is the dotted
    name of the layer being built (its parent name plus its index).
    """

    def __call__(
        self,
        prefix: str = "",
    ) -> torch.nn.Module:
        """Build and return the layer named ``prefix``.

        Args:
            prefix: Dotted name of the layer, including all parents.
                Defaults to the empty string.

        Returns:
            The constructed layer module.
        """
        ...
class PPMissingLayer(torch.nn.Identity): class PPMissingLayer(torch.nn.Identity):
""" """
A placeholder layer for missing layers in a pipeline parallel model. A placeholder layer for missing layers in a pipeline parallel model.
...@@ -119,7 +128,9 @@ def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module: ...@@ -119,7 +128,9 @@ def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
def make_layers( def make_layers(
num_hidden_layers: int, layer_fn: Callable[[], torch.nn.Module] num_hidden_layers: int,
layer_fn: LayerFn,
prefix: str,
) -> Tuple[int, int, torch.nn.ModuleList]: ) -> Tuple[int, int, torch.nn.ModuleList]:
"""Make a list of layers with the given layer function, taking """Make a list of layers with the given layer function, taking
pipeline parallelism into account. pipeline parallelism into account.
...@@ -131,8 +142,8 @@ def make_layers( ...@@ -131,8 +142,8 @@ def make_layers(
get_pp_group().world_size) get_pp_group().world_size)
modules = torch.nn.ModuleList( modules = torch.nn.ModuleList(
[PPMissingLayer() for _ in range(start_layer)] + [ [PPMissingLayer() for _ in range(start_layer)] + [
maybe_offload_to_cpu(layer_fn()) maybe_offload_to_cpu(layer_fn(prefix=f"{prefix}.{idx}"))
for _ in range(start_layer, end_layer) for idx in range(start_layer, end_layer)
] + [PPMissingLayer() for _ in range(end_layer, num_hidden_layers)]) ] + [PPMissingLayer() for _ in range(end_layer, num_hidden_layers)])
return start_layer, end_layer, modules return start_layer, end_layer, modules
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment