Commit 66b809cc authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.7.2' into v0.7.2-dev

parents 37b63c24 0408efc6
# SPDX-License-Identifier: Apache-2.0
"""A layer that compute logits from hidden_stats."""
import inspect
from concurrent.futures import ThreadPoolExecutor
from typing import Optional
import torch
......@@ -14,6 +16,11 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.platforms import current_platform
_logits_processor_threadpool: Optional[ThreadPoolExecutor] = None
if envs.VLLM_LOGITS_PROCESSOR_THREADS is not None:
_logits_processor_threadpool = ThreadPoolExecutor(
envs.VLLM_LOGITS_PROCESSOR_THREADS)
class LogitsProcessor(nn.Module):
"""Process logits and apply logits processors from sampling metadata.
......@@ -134,6 +141,7 @@ def _apply_logits_processors(
) -> torch.Tensor:
found_logits_processors = False
logits_processed = 0
logits_row_ids_and_logits_row_futures = []
for seq_group in sampling_metadata.seq_groups:
seq_ids = seq_group.seq_ids
sampling_params = seq_group.sampling_params
......@@ -147,22 +155,39 @@ def _apply_logits_processors(
past_tokens_ids = seq_group.seq_data[seq_id].output_token_ids
prompt_tokens_ids = seq_group.seq_data[seq_id].prompt_token_ids
for logits_processor in logits_processors:
parameters = inspect.signature(logits_processor).parameters
if len(parameters) == 3:
logits_row = logits_processor(prompt_tokens_ids,
past_tokens_ids,
logits_row)
else:
logits_row = logits_processor(past_tokens_ids,
logits_row)
logits[logits_row_idx] = logits_row
if _logits_processor_threadpool is not None:
logits_row_ids_and_logits_row_futures.append(
(logits_row_idx,
_logits_processor_threadpool.submit(
_apply_logits_processors_single_seq, logits_row,
logits_processors, past_tokens_ids,
prompt_tokens_ids)))
else:
logits[logits_row_idx] = \
_apply_logits_processors_single_seq(
logits_row, logits_processors, past_tokens_ids,
prompt_tokens_ids)
logits_processed += len(seq_group.sample_indices) + len(
seq_group.prompt_logprob_indices)
for logits_row_idx, future in logits_row_ids_and_logits_row_futures:
logits[logits_row_idx] = future.result()
if found_logits_processors:
# verifies that no rows in logits were missed unexpectedly
assert logits_processed == logits.shape[0]
return logits
def _apply_logits_processors_single_seq(logits_row, logits_processors,
past_tokens_ids,
prompt_tokens_ids) -> torch.Tensor:
for logits_processor in logits_processors:
parameters = inspect.signature(logits_processor).parameters
if len(parameters) == 3:
logits_row = logits_processor(prompt_tokens_ids, past_tokens_ids,
logits_row)
else:
logits_row = logits_processor(past_tokens_ids, logits_row)
return logits_row
# SPDX-License-Identifier: Apache-2.0
import torch
from torch import nn
from torch.nn.parameter import Parameter
......
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) 2024, Tri Dao.
# Adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/causal_conv1d/causal_conv1d_interface.py
......
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) 2024, Tri Dao, Albert Gu.
# Adapted from https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/triton/selective_state_update.py
......
# SPDX-License-Identifier: Apache-2.0
from enum import IntEnum
from typing import List, Optional, Union
......
# SPDX-License-Identifier: Apache-2.0
from typing import Dict, List, Type
from vllm.model_executor.layers.quantization.base_config import (
......
# SPDX-License-Identifier: Apache-2.0
# Supports AQLM compression, see https://github.com/Vahe1994/AQLM
# and https://arxiv.org/pdf/2401.06118.pdf
......
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, List, Optional
import torch
......
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Callable, Dict, List, Optional
import torch
......
# SPDX-License-Identifier: Apache-2.0
import torch
import triton
import triton.language as tl
......
# SPDX-License-Identifier: Apache-2.0
import inspect
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Type
from typing import Any, Dict, List, Mapping, Optional, Type
import torch
from torch import nn
......@@ -57,6 +59,7 @@ def method_has_implemented_embedding(
class QuantizationConfig(ABC):
"""Base class for quantization configs."""
packed_modules_mapping: Mapping[str, List[str]] = dict()
@abstractmethod
def get_name(self) -> str:
......
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, List, Optional
import torch
......
# SPDX-License-Identifier: Apache-2.0
from contextlib import suppress
from typing import Any, Dict, List, Literal, Optional, Tuple, cast
......@@ -81,7 +83,9 @@ class CompressedTensorsConfig(QuantizationConfig):
# Check if the layer is skipped for quantization.
# TODO (@robertgshaw2): support module names
if should_ignore_layer(prefix, ignore=self.ignore):
if should_ignore_layer(prefix,
ignore=self.ignore,
fused_mapping=self.packed_modules_mapping):
return UnquantizedLinearMethod()
if isinstance(layer, LinearBase):
scheme = self.get_scheme(layer=layer, layer_name=prefix)
......@@ -377,34 +381,29 @@ class CompressedTensorsConfig(QuantizationConfig):
# Will be empty for models with only sparsity
weight_quant = input_quant = None
sparsity_scheme: Optional[SparsityCompressionConfig] = None
if self.target_scheme_map:
matched_target = find_matched_target(
layer_name=layer_name,
module=layer,
targets=self.target_scheme_map.keys())
targets=self.target_scheme_map.keys(),
fused_mapping=self.packed_modules_mapping)
scheme_dict = self.target_scheme_map[matched_target]
weight_quant = scheme_dict.get("weights")
input_quant = scheme_dict.get("input_activations")
if self.sparsity_scheme_map:
is_ignored = False
with suppress(ValueError):
is_ignored = find_matched_target(
layer_name=layer_name,
module=layer,
targets=self.sparsity_ignore_list)
# if the layer is in the sparsity ignore list,
# we should not apply any sparsity scheme
if not is_ignored:
matched_target = find_matched_target(
layer_name=layer_name,
module=layer,
targets=self.sparsity_scheme_map.keys())
sparsity_scheme = self.sparsity_scheme_map.get(matched_target)
# Find the sparsity scheme of the layer
# assume that fused layers inerhit first component's sparsity scheme
sparsity_targets = (self.sparsity_scheme_map.keys() -
set(self.sparsity_ignore_list))
sparsity_scheme: Optional[SparsityCompressionConfig] = None
with suppress(ValueError):
matched_target = find_matched_target(
layer_name=layer_name,
module=layer,
targets=sparsity_targets,
fused_mapping=self.packed_modules_mapping)
sparsity_scheme = self.sparsity_scheme_map[matched_target]
if self.supports_cutlass_24(weight_quant=weight_quant,
input_quant=input_quant,
......@@ -418,10 +417,22 @@ class CompressedTensorsConfig(QuantizationConfig):
return None
# Have a valid sparsity scheme
# Validate layer is supported by Cutlass 2:4 Kernel
scheme = CompressedTensors24(quantized=weight_quant is not None
or input_quant is not None,
weight_quant=weight_quant,
input_quant=input_quant)
model_compression_config = (None if sparsity_scheme is None
or sparsity_scheme.format == "dense"
else self.config)
scheme = CompressedTensors24(
quantized=weight_quant is not None or input_quant is not None,
weight_quant=weight_quant,
input_quant=input_quant,
model_compression_config=model_compression_config,
)
elif weight_quant is None:
logger.warning_once("Acceleration for non-quantized schemes is "
"not supported by Compressed Tensors. "
"Falling back to UnquantizedLinearMethod")
return None
else:
# Find the quant_scheme
scheme = self._get_scheme_from_parts( # type: ignore
......@@ -471,10 +482,21 @@ class CompressedTensorsConfig(QuantizationConfig):
:return: True if the layer is supported by the Cutlass 2:4 Kernel
False otherwise
"""
is_valid_sparsity = (sparsity_scheme is not None
and sparsity_scheme.sparsity_structure
== SparsityStructure.TWO_FOUR.value
and sparsity_scheme.format == "dense")
if sparsity_scheme is None:
return False
is_valid_sparsity_structure: bool = (
sparsity_scheme.sparsity_structure ==
SparsityStructure.TWO_FOUR.value)
valid_compressors = {
CompressionFormat.dense.value,
CompressionFormat.sparse_24_bitmask.value
}
is_valid_sparsity = (is_valid_sparsity_structure
and sparsity_scheme.format in valid_compressors)
if not is_valid_sparsity:
return False
......
# SPDX-License-Identifier: Apache-2.0
import enum
from enum import Enum
from typing import Callable, List, Optional
......
# SPDX-License-Identifier: Apache-2.0
from .compressed_tensors_scheme import CompressedTensorsScheme
from .compressed_tensors_w4a16_24 import (W4A16SPARSE24_SUPPORTED_BITS,
CompressedTensorsW4A16Sparse24)
......
from typing import Callable, List, Optional
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Callable, Dict, List, Optional, Tuple
import torch
from compressed_tensors import CompressionFormat, ModelCompressor
from compressed_tensors.quantization import (QuantizationArgs,
QuantizationStrategy,
QuantizationType)
from compressed_tensors.utils import combine_shards
from vllm import _custom_ops as ops
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear)
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
......@@ -20,26 +26,39 @@ __all__ = ["CompressedTensors24"]
class CompressedTensors24(CompressedTensorsScheme):
def __init__(self,
quantized: bool = False,
weight_quant: Optional[QuantizationArgs] = None,
input_quant: Optional[QuantizationArgs] = None):
def __init__(
self,
quantized: bool = False,
weight_quant: Optional[QuantizationArgs] = None,
input_quant: Optional[QuantizationArgs] = None,
model_compression_config: Optional[Dict[str, Any]] = None,
):
self.quantized = quantized
self.weight_quant = weight_quant
self.input_quant = input_quant
self.model_compressor = (
ModelCompressor.from_compression_config(model_compression_config)
if model_compression_config is not None else None)
self.do_sparse_decompress = (
self.model_compressor is not None
and self.model_compressor.sparsity_config.format
== CompressionFormat.sparse_24_bitmask.value)
@classmethod
def get_min_capability(cls) -> int:
# Only cutlass 3.x kernels are implemented so far
return 90
def create_weights(self, layer: torch.nn.Module, input_size: int,
output_partition_sizes: List[int],
input_size_per_partition: int,
params_dtype: torch.dtype, weight_loader: Callable,
**kwargs):
def create_weights(
self,
layer: torch.nn.Module,
input_size: int,
output_partition_sizes: List[int],
input_size_per_partition: int,
params_dtype: torch.dtype,
weight_loader: Callable,
**kwargs,
):
if not sparse_cutlass_supported():
raise ValueError(
"Sparse CUTLASS not supported. vLLM must be built with "
......@@ -47,16 +66,56 @@ class CompressedTensors24(CompressedTensorsScheme):
self.output_dtype = params_dtype
layer.logical_widths = output_partition_sizes
layer.input_size = input_size
layer.input_size_per_partition = input_size_per_partition
self.weights_dtype: torch.dtype = self._get_params_dtype(params_dtype)
# parameter to store uncompressed weight
weight = ModelWeightParameter(data=torch.empty(
sum(output_partition_sizes),
input_size_per_partition,
dtype=self.weights_dtype),
input_dim=1,
output_dim=0,
weight_loader=weight_loader)
weight = ModelWeightParameter(
data=torch.empty(
sum(output_partition_sizes),
input_size_per_partition,
dtype=self.weights_dtype,
),
input_dim=1,
output_dim=0,
weight_loader=weight_loader,
)
if self.do_sparse_decompress:
assert all(partition_size % 8 == 0
for partition_size in output_partition_sizes
), "All partitions must be divisible by 8 for "
"2:4 sparse compressed models"
shape = BasevLLMParameter(
data=torch.empty(2, 1, dtype=torch.int64),
weight_loader=weight_loader,
)
compressed_weight = ModelWeightParameter(
data=torch.empty(
sum(output_partition_sizes),
input_size_per_partition // 2,
dtype=self.weights_dtype,
),
input_dim=1,
output_dim=0,
weight_loader=weight_loader,
)
bitmask = ModelWeightParameter(
data=torch.empty(
sum(output_partition_sizes),
input_size_per_partition // 8,
dtype=torch.uint8,
),
input_dim=1,
output_dim=0,
weight_loader=weight_loader,
)
layer.register_parameter("shape", shape)
layer.register_parameter("compressed", compressed_weight)
layer.register_parameter("bitmask", bitmask)
# Check if quantized, not just 2:4 Sparse
if self.quantized:
......@@ -66,14 +125,16 @@ class CompressedTensors24(CompressedTensorsScheme):
data=torch.empty((sum(output_partition_sizes), 1),
dtype=torch.float32),
output_dim=0,
weight_loader=weight_loader)
weight_loader=weight_loader,
)
else:
assert (self.weight_quant and self.weight_quant.strategy
== QuantizationStrategy.TENSOR.value)
weight_scale = PerTensorScaleParameter(
data=torch.empty(len(output_partition_sizes),
dtype=torch.float32),
weight_loader=weight_loader)
weight_loader=weight_loader,
)
layer.register_parameter("weight_scale", weight_scale)
......@@ -82,9 +143,10 @@ class CompressedTensors24(CompressedTensorsScheme):
# register input quant scale
assert (self.input_quant.strategy ==
QuantizationStrategy.TENSOR.value)
input_scale = BasevLLMParameter(data=torch.empty(
1, dtype=torch.float32),
weight_loader=weight_loader)
input_scale = BasevLLMParameter(
data=torch.empty(1, dtype=torch.float32),
weight_loader=weight_loader,
)
layer.register_parameter("input_scale", input_scale)
......@@ -105,13 +167,25 @@ class CompressedTensors24(CompressedTensorsScheme):
"""
Compress weights after loading. Store compressed weight and meta
tensor
:post-condition: layer.w_compressed and layer.meta are
set to the compressed weight and meta tensor in the
format expected by the Cutlass kernels
:param layer: The layer with the weights to be processed
"""
if self.do_sparse_decompress:
layer.weight.data = self._decompress_bitmask_compressed_weight(
compressed=layer.compressed,
bitmask=layer.bitmask,
layer=layer,
)
# compressed and bitmask tensors
# are no longer needed after decompression
del layer.compressed
del layer.bitmask
# torch.compile workaround
if hasattr(layer, "input_scale"):
layer.input_scale = torch.nn.Parameter(layer.input_scale.data,
......@@ -119,10 +193,13 @@ class CompressedTensors24(CompressedTensorsScheme):
if self.weight_quant:
if self.weight_quant.strategy == QuantizationStrategy.TENSOR.value:
layer.weight_scale = torch.nn.Parameter(convert_to_channelwise(
weight_scale=layer.weight_scale,
logical_widths=layer.logical_widths),
requires_grad=False)
layer.weight_scale = torch.nn.Parameter(
convert_to_channelwise(
weight_scale=layer.weight_scale,
logical_widths=layer.logical_widths,
),
requires_grad=False,
)
else:
# torch.compile workaround
layer.weight_scale = torch.nn.Parameter(
......@@ -132,20 +209,22 @@ class CompressedTensors24(CompressedTensorsScheme):
layer.weight = torch.nn.Parameter(w_compressed, requires_grad=False)
layer.meta = torch.nn.Parameter(meta, requires_grad=False)
def apply_weights(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
def apply_weights(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Returns the output tensor for the layer with 2:4
Returns the output tensor for the layer with 2:4
sparse compressed weights, given the input tensor
and bias
:param layer: The layer with 2:4 sparse compressed
:param layer: The layer with 2:4 sparse compressed
weights to be used for the computation
:param x: The input tensor to the layer
:param bias: The bias to be added to the output tensor
:return: The output tensor of the layer
:return: The output tensor of the layer
"""
if self.quantized:
scale = None
......@@ -169,13 +248,15 @@ class CompressedTensors24(CompressedTensorsScheme):
input_scale = layer.input_scale
q_input = x
out = ops.cutlass_scaled_sparse_mm(a=q_input,
bt_nzs=layer.weight,
bt_meta=layer.meta,
scale_a=input_scale,
scale_b=layer.weight_scale,
out_dtype=self.output_dtype,
bias=bias)
out = ops.cutlass_scaled_sparse_mm(
a=q_input,
bt_nzs=layer.weight,
bt_meta=layer.meta,
scale_a=input_scale,
scale_b=layer.weight_scale,
out_dtype=self.output_dtype,
bias=bias,
)
assert out.is_contiguous()
return out
......@@ -201,8 +282,71 @@ class CompressedTensors24(CompressedTensorsScheme):
raise ValueError("Quantization type not supported by Cutlass")
def _decompress_bitmask_compressed_weight(
self,
compressed: torch.Tensor,
bitmask: torch.Tensor,
layer: torch.nn.Module,
) -> torch.Tensor:
"""
Decompress a compressed 2:4 sparse weight tensor using the bitmask and
return the result.
This function also supports sharded decompression.
:param compressed: The 2:4 sparse weight tensor compressed using the
sparse-24-bitmask compressor. This is different from
`cutlass_sparse_compress` which uses a different scheme (2 bits for
every nonzero element that represent the coordinate within the block
of 4). The bitmask compression here uses a bitmask to indicate the
positions of non-zero elements.
:param bitmask: The 2:4 bitmask associated with the compressed weights,
representing the positions of non-zero elements in the compressed
tensor.
:param layer: The layer whose weights need to be processed after
loading.
:return: The decompressed 2:4 sparse weight tensor.
"""
def check_24(tensor):
new_tensor = tensor.view(-1, 4)
zero_counts = (new_tensor == 0).sum(dim=1)
return (zero_counts >= 2).all().item()
sparsity_compressor = self.model_compressor.sparsity_compressor
def _process_split(
bitmask_compressed_weight: torch.Tensor,
shape,
bitmask: torch.Tensor,
) -> torch.Tensor:
weight_data = dict(
compressed=bitmask_compressed_weight,
shape=shape,
bitmask=bitmask,
)
return sparsity_compressor.decompress_weight(weight_data)
split_weights: List[torch.Tensor] = []
split_bitmask: List[torch.Tensor] = []
split_shape: List[Tuple[int, int]] = []
if isinstance(layer, (QKVParallelLinear, MergedColumnParallelLinear)):
split_weights = torch.split(compressed, layer.logical_widths)
split_bitmask = torch.split(bitmask, layer.logical_widths)
split_shape = [(out, layer.input_size_per_partition)
for out in layer.logical_widths]
if split_weights:
decompressed_shards = [
_process_split(compressed_weight, shape, bitmask)
for compressed_weight, shape, bitmask in zip(
split_weights, split_shape, split_bitmask)
]
decompressed = combine_shards(decompressed_shards)
else:
decompressed = sparsity_compressor.decompress_weight(
dict(
compressed=compressed,
shape=(
layer.logical_widths[0],
layer.input_size_per_partition,
),
bitmask=bitmask,
))
return decompressed
# SPDX-License-Identifier: Apache-2.0
from abc import ABC, abstractmethod
from typing import Optional
......
# SPDX-License-Identifier: Apache-2.0
from typing import Callable, List, Optional
import torch
......
# SPDX-License-Identifier: Apache-2.0
from typing import Callable, List, Optional
import torch
......
# SPDX-License-Identifier: Apache-2.0
from typing import Callable, List, Optional
import torch
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment