"vllm/vscode:/vscode.git/clone" did not exist on "67fc16cd8cf778a30ad0f7619fe77bd85f1d1633"
Commit 4851c202 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.6.1' into v0.6.1-dev

parents 9b902f9e 3fd2b0d2
...@@ -29,7 +29,8 @@ WEIGHT_LOADER_V2_SUPPORTED = [ ...@@ -29,7 +29,8 @@ WEIGHT_LOADER_V2_SUPPORTED = [
"CompressedTensorsLinearMethod", "AWQMarlinLinearMethod", "CompressedTensorsLinearMethod", "AWQMarlinLinearMethod",
"AWQLinearMethod", "GPTQMarlinLinearMethod", "Fp8LinearMethod", "AWQLinearMethod", "GPTQMarlinLinearMethod", "Fp8LinearMethod",
"MarlinLinearMethod", "QQQLinearMethod", "GPTQMarlin24LinearMethod", "MarlinLinearMethod", "QQQLinearMethod", "GPTQMarlin24LinearMethod",
"TPUInt8LinearMethod", "GPTQLinearMethod", "FBGEMMFp8LinearMethod" "TPUInt8LinearMethod", "GPTQLinearMethod", "FBGEMMFp8LinearMethod",
"ModelOptFp8LinearMethod"
] ]
......
...@@ -22,10 +22,10 @@ from vllm.model_executor.layers.quantization.gptq_marlin import ( ...@@ -22,10 +22,10 @@ from vllm.model_executor.layers.quantization.gptq_marlin import (
from vllm.model_executor.layers.quantization.gptq_marlin_24 import ( from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
GPTQMarlin24Config) GPTQMarlin24Config)
from vllm.model_executor.layers.quantization.marlin import MarlinConfig from vllm.model_executor.layers.quantization.marlin import MarlinConfig
from vllm.model_executor.layers.quantization.modelopt import ModelOptFp8Config
from vllm.model_executor.layers.quantization.neuron_quant import ( from vllm.model_executor.layers.quantization.neuron_quant import (
NeuronQuantConfig) NeuronQuantConfig)
from vllm.model_executor.layers.quantization.qqq import QQQConfig from vllm.model_executor.layers.quantization.qqq import QQQConfig
from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig
from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig
QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = { QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
...@@ -35,6 +35,7 @@ QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = { ...@@ -35,6 +35,7 @@ QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
"tpu_int8": Int8TpuConfig, "tpu_int8": Int8TpuConfig,
"fp8": Fp8Config, "fp8": Fp8Config,
"fbgemm_fp8": FBGEMMFp8Config, "fbgemm_fp8": FBGEMMFp8Config,
"modelopt": ModelOptFp8Config,
# The order of gptq methods is important for config.py iteration over # The order of gptq methods is important for config.py iteration over
# override_quantization_method(..) # override_quantization_method(..)
"marlin": MarlinConfig, "marlin": MarlinConfig,
...@@ -43,7 +44,6 @@ QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = { ...@@ -43,7 +44,6 @@ QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
"gptq_marlin": GPTQMarlinConfig, "gptq_marlin": GPTQMarlinConfig,
"awq_marlin": AWQMarlinConfig, "awq_marlin": AWQMarlinConfig,
"gptq": GPTQConfig, "gptq": GPTQConfig,
"squeezellm": SqueezeLLMConfig,
"compressed-tensors": CompressedTensorsConfig, "compressed-tensors": CompressedTensorsConfig,
"bitsandbytes": BitsAndBytesConfig, "bitsandbytes": BitsAndBytesConfig,
"qqq": QQQConfig, "qqq": QQQConfig,
......
...@@ -22,7 +22,7 @@ def awq_dequantize_kernel( ...@@ -22,7 +22,7 @@ def awq_dequantize_kernel(
# Compute offsets and masks for qweight_ptr. # Compute offsets and masks for qweight_ptr.
offsets_y = pid_y * BLOCK_SIZE_Y + tl.arange(0, BLOCK_SIZE_Y) offsets_y = pid_y * BLOCK_SIZE_Y + tl.arange(0, BLOCK_SIZE_Y)
offsets_x = pid_x * BLOCK_SIZE_X + tl.arange(0, BLOCK_SIZE_X * 8) // 8 offsets_x = pid_x * BLOCK_SIZE_X + tl.arange(0, BLOCK_SIZE_X)
offsets = num_cols * offsets_y[:, None] + offsets_x[None, :] offsets = num_cols * offsets_y[:, None] + offsets_x[None, :]
masks_y = offsets_y < num_rows masks_y = offsets_y < num_rows
...@@ -43,6 +43,9 @@ def awq_dequantize_kernel( ...@@ -43,6 +43,9 @@ def awq_dequantize_kernel(
# Load the weights. # Load the weights.
iweights = tl.load(qweight_ptr + offsets, masks) iweights = tl.load(qweight_ptr + offsets, masks)
iweights = tl.interleave(iweights, iweights)
iweights = tl.interleave(iweights, iweights)
iweights = tl.interleave(iweights, iweights)
# Create reverse AWQ order as tensor: [0, 4, 1, 5, 2, 6, 3, 7] # Create reverse AWQ order as tensor: [0, 4, 1, 5, 2, 6, 3, 7]
# that will map given indices to the correct order. # that will map given indices to the correct order.
...@@ -59,9 +62,8 @@ def awq_dequantize_kernel( ...@@ -59,9 +62,8 @@ def awq_dequantize_kernel(
iweights = (iweights >> shifts) & 0xF iweights = (iweights >> shifts) & 0xF
# Compute zero offsets and masks. # Compute zero offsets and masks.
zero_offsets_y = (pid_y * BLOCK_SIZE_Y // group_size + zero_offsets_y = pid_y * BLOCK_SIZE_Y // group_size + tl.arange(0, 1)
tl.arange(0, BLOCK_SIZE_Y) // group_size) zero_offsets_x = pid_x * BLOCK_SIZE_X + tl.arange(0, BLOCK_SIZE_X)
zero_offsets_x = pid_x * BLOCK_SIZE_X + tl.arange(0, BLOCK_SIZE_X * 8) // 8
zero_offsets = num_cols * zero_offsets_y[:, None] + zero_offsets_x[None, :] zero_offsets = num_cols * zero_offsets_y[:, None] + zero_offsets_x[None, :]
zero_masks_y = zero_offsets_y < num_rows // group_size zero_masks_y = zero_offsets_y < num_rows // group_size
...@@ -70,13 +72,16 @@ def awq_dequantize_kernel( ...@@ -70,13 +72,16 @@ def awq_dequantize_kernel(
# Load the zeros. # Load the zeros.
zeros = tl.load(zeros_ptr + zero_offsets, zero_masks) zeros = tl.load(zeros_ptr + zero_offsets, zero_masks)
zeros = tl.interleave(zeros, zeros)
zeros = tl.interleave(zeros, zeros)
zeros = tl.interleave(zeros, zeros)
zeros = tl.broadcast_to(zeros, (BLOCK_SIZE_Y, BLOCK_SIZE_X * 8))
# Unpack and reorder: shift out the correct 4-bit value and mask. # Unpack and reorder: shift out the correct 4-bit value and mask.
zeros = (zeros >> shifts) & 0xF zeros = (zeros >> shifts) & 0xF
# Compute scale offsets and masks. # Compute scale offsets and masks.
scale_offsets_y = (pid_y * BLOCK_SIZE_Y // group_size + scale_offsets_y = pid_y * BLOCK_SIZE_Y // group_size + tl.arange(0, 1)
tl.arange(0, BLOCK_SIZE_Y) // group_size)
scale_offsets_x = (pid_x * BLOCK_SIZE_X * 8 + scale_offsets_x = (pid_x * BLOCK_SIZE_X * 8 +
tl.arange(0, BLOCK_SIZE_X * 8)) tl.arange(0, BLOCK_SIZE_X * 8))
scale_offsets = (num_cols * 8 * scale_offsets_y[:, None] + scale_offsets = (num_cols * 8 * scale_offsets_y[:, None] +
...@@ -87,6 +92,7 @@ def awq_dequantize_kernel( ...@@ -87,6 +92,7 @@ def awq_dequantize_kernel(
# Load the scales. # Load the scales.
scales = tl.load(scales_ptr + scale_offsets, scale_masks) scales = tl.load(scales_ptr + scale_offsets, scale_masks)
scales = tl.broadcast_to(scales, (BLOCK_SIZE_Y, BLOCK_SIZE_X * 8))
# Dequantize. # Dequantize.
iweights = (iweights - zeros) * scales iweights = (iweights - zeros) * scales
...@@ -137,12 +143,10 @@ def awq_gemm_kernel(a_ptr, b_ptr, c_ptr, zeros_ptr, scales_ptr, M, N, K, ...@@ -137,12 +143,10 @@ def awq_gemm_kernel(a_ptr, b_ptr, c_ptr, zeros_ptr, scales_ptr, M, N, K,
offsets_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offsets_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
masks_am = offsets_am < M masks_am = offsets_am < M
offsets_bn = (pid_n * (BLOCK_SIZE_N // 8) + offsets_bn = pid_n * (BLOCK_SIZE_N // 8) + tl.arange(0, BLOCK_SIZE_N // 8)
tl.arange(0, BLOCK_SIZE_N) // 8)
masks_bn = offsets_bn < N // 8 masks_bn = offsets_bn < N // 8
offsets_zn = (pid_n * (BLOCK_SIZE_N // 8) + offsets_zn = pid_n * (BLOCK_SIZE_N // 8) + tl.arange(0, BLOCK_SIZE_N // 8)
tl.arange(0, BLOCK_SIZE_N) // 8)
masks_zn = offsets_zn < N // 8 masks_zn = offsets_zn < N // 8
offsets_sn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) offsets_sn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
...@@ -165,22 +169,30 @@ def awq_gemm_kernel(a_ptr, b_ptr, c_ptr, zeros_ptr, scales_ptr, M, N, K, ...@@ -165,22 +169,30 @@ def awq_gemm_kernel(a_ptr, b_ptr, c_ptr, zeros_ptr, scales_ptr, M, N, K,
masks_b = masks_k[:, None] & masks_bn[None, :] masks_b = masks_k[:, None] & masks_bn[None, :]
b = tl.load(b_ptrs, mask=masks_b) b = tl.load(b_ptrs, mask=masks_b)
b = tl.interleave(b, b)
b = tl.interleave(b, b)
b = tl.interleave(b, b)
# Dequantize b. # Dequantize b.
offsets_szk = ( offsets_szk = (
(BLOCK_SIZE_K * SPLIT_K * k + pid_z * BLOCK_SIZE_K) // group_size + (BLOCK_SIZE_K * SPLIT_K * k + pid_z * BLOCK_SIZE_K) // group_size +
tl.arange(0, BLOCK_SIZE_K) // group_size) tl.arange(0, 1))
offsets_z = (N // 8) * offsets_szk[:, None] + offsets_zn[None, :] offsets_z = (N // 8) * offsets_szk[:, None] + offsets_zn[None, :]
masks_zk = offsets_szk < K // group_size masks_zk = offsets_szk < K // group_size
masks_z = masks_zk[:, None] & masks_zn[None, :] masks_z = masks_zk[:, None] & masks_zn[None, :]
zeros_ptrs = zeros_ptr + offsets_z zeros_ptrs = zeros_ptr + offsets_z
zeros = tl.load(zeros_ptrs, mask=masks_z) zeros = tl.load(zeros_ptrs, mask=masks_z)
zeros = tl.interleave(zeros, zeros)
zeros = tl.interleave(zeros, zeros)
zeros = tl.interleave(zeros, zeros)
zeros = tl.broadcast_to(zeros, (BLOCK_SIZE_K, BLOCK_SIZE_N))
offsets_s = N * offsets_szk[:, None] + offsets_sn[None, :] offsets_s = N * offsets_szk[:, None] + offsets_sn[None, :]
masks_sk = offsets_szk < K // group_size masks_sk = offsets_szk < K // group_size
masks_s = masks_sk[:, None] & masks_sn[None, :] masks_s = masks_sk[:, None] & masks_sn[None, :]
scales_ptrs = scales_ptr + offsets_s scales_ptrs = scales_ptr + offsets_s
scales = tl.load(scales_ptrs, mask=masks_s) scales = tl.load(scales_ptrs, mask=masks_s)
scales = tl.broadcast_to(scales, (BLOCK_SIZE_K, BLOCK_SIZE_N))
b = (b >> shifts) & 0xF b = (b >> shifts) & 0xF
zeros = (zeros >> shifts) & 0xF zeros = (zeros >> shifts) & 0xF
......
...@@ -116,15 +116,19 @@ class CompressedTensorsConfig(QuantizationConfig): ...@@ -116,15 +116,19 @@ class CompressedTensorsConfig(QuantizationConfig):
def _check_scheme_supported(self, def _check_scheme_supported(self,
min_capability: int, min_capability: int,
error: bool = True) -> bool: error: bool = True) -> bool:
capability = current_platform.get_device_capability() capability = current_platform.get_device_capability() # type: ignore
capability = capability[0] * 10 + capability[1]
supported = capability >= min_capability if capability is not None:
if error and not supported: capability = capability[0] * 10 + capability[1]
raise RuntimeError( supported = capability >= min_capability
"Quantization scheme is not supported for ", if error and not supported:
f"the current GPU. Min capability: {min_capability}. ", raise RuntimeError(
f"Current capability: {capability}.") "Quantization scheme is not supported for ",
return supported f"the current GPU. Min capability: {min_capability}. ",
f"Current capability: {capability}.")
return supported
else:
return False
def _is_static_tensor_w8a8(self, weight_quant: BaseModel, def _is_static_tensor_w8a8(self, weight_quant: BaseModel,
input_quant: BaseModel) -> bool: input_quant: BaseModel) -> bool:
...@@ -232,7 +236,8 @@ class CompressedTensorsConfig(QuantizationConfig): ...@@ -232,7 +236,8 @@ class CompressedTensorsConfig(QuantizationConfig):
return CompressedTensorsWNA16( return CompressedTensorsWNA16(
num_bits=weight_quant.num_bits, num_bits=weight_quant.num_bits,
strategy=weight_quant.strategy, strategy=weight_quant.strategy,
group_size=weight_quant.group_size) group_size=weight_quant.group_size,
actorder=weight_quant.actorder)
# Detect If Activation Quantization. # Detect If Activation Quantization.
# TODO @dsikka: clean-up conditions # TODO @dsikka: clean-up conditions
......
...@@ -5,9 +5,7 @@ from typing import Callable, List, Optional ...@@ -5,9 +5,7 @@ from typing import Callable, List, Optional
import torch import torch
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEMethodBase
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
WNA16_SUPPORTED_BITS)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
CompressionFormat) CompressionFormat)
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
...@@ -40,11 +38,10 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase): ...@@ -40,11 +38,10 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
if not (self.quant_config.quant_format if not (self.quant_config.quant_format
== CompressionFormat.pack_quantized.value == CompressionFormat.pack_quantized.value
and self.num_bits in WNA16_SUPPORTED_BITS): and self.num_bits == 4):
raise ValueError("For Fused MoE layers, only ", raise ValueError("For Fused MoE layers, only ",
f"{CompressionFormat.pack_quantized.value} ", f"{CompressionFormat.pack_quantized.value} ",
"is supported for the following bits: ", "is supported for 4 bits")
f"{WNA16_SUPPORTED_BITS}")
def create_weights(self, layer: torch.nn.Module, num_experts: int, def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size: int, hidden_size: int, intermediate_size: int,
...@@ -269,19 +266,30 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase): ...@@ -269,19 +266,30 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
custom_routing_function: Optional[Callable] = None, custom_routing_function: Optional[Callable] = None,
) -> torch.Tensor: ) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe.fused_moe import ( from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
fused_marlin_moe) fused_marlin_moe)
return fused_marlin_moe(x, topk_weights, topk_ids = FusedMoE.select_experts(
layer.w13_weight_packed, hidden_states=x,
layer.w2_weight_packed, router_logits=router_logits,
router_logits, use_grouped_topk=use_grouped_topk,
layer.w13_g_idx, top_k=top_k,
layer.w2_g_idx, renormalize=renormalize,
layer.w13_g_idx_sort_indices, topk_group=topk_group,
layer.w2_g_idx_sort_indices, num_expert_group=num_expert_group,
top_k, custom_routing_function=custom_routing_function)
custom_routing_function,
renormalize=renormalize, return fused_marlin_moe(
w1_scale=layer.w13_weight_scale, x,
w2_scale=layer.w2_weight_scale) layer.w13_weight_packed,
layer.w2_weight_packed,
router_logits,
layer.w13_g_idx,
layer.w2_g_idx,
layer.w13_g_idx_sort_indices,
layer.w2_g_idx_sort_indices,
topk_weights,
topk_ids,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
)
...@@ -5,20 +5,24 @@ import torch ...@@ -5,20 +5,24 @@ import torch
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme) CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
ActivationOrdering)
from vllm.model_executor.layers.quantization.utils.marlin_utils import ( from vllm.model_executor.layers.quantization.utils.marlin_utils import (
apply_gptq_marlin_linear, marlin_make_empty_g_idx, marlin_make_workspace, apply_gptq_marlin_linear, marlin_make_empty_g_idx, marlin_make_workspace,
marlin_permute_scales, replace_tensor, verify_marlin_supported, marlin_permute_scales, marlin_repeat_scales_on_all_ranks,
marlin_sort_g_idx, replace_tensor, verify_marlin_supported,
verify_marlin_supports_shape) verify_marlin_supports_shape)
from vllm.model_executor.parameter import (BasevLLMParameter, from vllm.model_executor.parameter import (BasevLLMParameter,
ChannelQuantScaleParameter, ChannelQuantScaleParameter,
GroupQuantScaleParameter, GroupQuantScaleParameter,
PackedvLLMParameter) PackedvLLMParameter,
RowvLLMParameter)
from vllm.scalar_type import scalar_types from vllm.scalar_type import scalar_types
__all__ = ["CompressedTensorsWNA16"] __all__ = ["CompressedTensorsWNA16"]
WNA16_SUPPORTED_TYPES_MAP = { WNA16_SUPPORTED_TYPES_MAP = {
4: scalar_types.uint4b8, 4: scalar_types.uint4b8,
8: scalar_types.uint8b128, 8: scalar_types.uint8b128
} }
WNA16_SUPPORTED_BITS = list(WNA16_SUPPORTED_TYPES_MAP.keys()) WNA16_SUPPORTED_BITS = list(WNA16_SUPPORTED_TYPES_MAP.keys())
...@@ -28,11 +32,13 @@ class CompressedTensorsWNA16(CompressedTensorsScheme): ...@@ -28,11 +32,13 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
def __init__(self, def __init__(self,
strategy: str, strategy: str,
num_bits: int, num_bits: int,
group_size: Optional[int] = None): group_size: Optional[int] = None,
actorder: Optional[ActivationOrdering] = None):
self.pack_factor = 32 // num_bits self.pack_factor = 32 // num_bits
self.strategy = strategy self.strategy = strategy
self.group_size = -1 if group_size is None else group_size self.group_size = -1 if group_size is None else group_size
self.has_g_idx = actorder == ActivationOrdering.GROUP
if self.group_size == -1 and self.strategy != "channel": if self.group_size == -1 and self.strategy != "channel":
raise ValueError("Marlin kernels require group quantization or " raise ValueError("Marlin kernels require group quantization or "
...@@ -64,12 +70,10 @@ class CompressedTensorsWNA16(CompressedTensorsScheme): ...@@ -64,12 +70,10 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
output_size_per_partition = sum(output_partition_sizes) output_size_per_partition = sum(output_partition_sizes)
# If group_size is -1, we are in channelwise case. # If group_size is -1, we are in channelwise case.
channelwise = (self.group_size == -1)
group_size = self.group_size if self.group_size != -1 else input_size group_size = self.group_size if self.group_size != -1 else input_size
row_parallel = (input_size != input_size_per_partition) row_parallel = (input_size != input_size_per_partition)
# In the case of channelwise quantization, we need to replicate the partition_scales = not marlin_repeat_scales_on_all_ranks(
# scales across all gpus. self.has_g_idx, self.group_size, row_parallel)
partition_scales = (row_parallel and not channelwise)
verify_marlin_supports_shape( verify_marlin_supports_shape(
output_size_per_partition=output_size_per_partition, output_size_per_partition=output_size_per_partition,
...@@ -123,6 +127,16 @@ class CompressedTensorsWNA16(CompressedTensorsScheme): ...@@ -123,6 +127,16 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
layer.register_parameter("weight_scale", weight_scale) layer.register_parameter("weight_scale", weight_scale)
layer.register_parameter("weight_shape", weight_shape) layer.register_parameter("weight_shape", weight_shape)
# group index (for activation reordering)
if self.has_g_idx:
weight_g_idx = RowvLLMParameter(data=torch.empty(
input_size_per_partition,
dtype=torch.int32,
),
input_dim=0,
weight_loader=weight_loader)
layer.register_parameter("weight_g_idx", weight_g_idx)
layer.input_size_per_partition = input_size_per_partition layer.input_size_per_partition = input_size_per_partition
layer.output_size_per_partition = output_size_per_partition layer.output_size_per_partition = output_size_per_partition
layer.input_size = input_size layer.input_size = input_size
...@@ -137,9 +151,14 @@ class CompressedTensorsWNA16(CompressedTensorsScheme): ...@@ -137,9 +151,14 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
layer.workspace = marlin_make_workspace( layer.workspace = marlin_make_workspace(
layer.output_size_per_partition, device) layer.output_size_per_partition, device)
# Act-order not supported in compressed-tensors yet, so set to empty. # Handle sorting for activation reordering if needed.
layer.g_idx = marlin_make_empty_g_idx(device) if self.has_g_idx:
layer.g_idx_sort_indices = marlin_make_empty_g_idx(device) g_idx, g_idx_sort_indices = marlin_sort_g_idx(layer.weight_g_idx)
layer.g_idx_sort_indices = g_idx_sort_indices
replace_tensor(layer, "weight_g_idx", g_idx)
else:
layer.weight_g_idx = marlin_make_empty_g_idx(device)
layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)
# No zero-point # No zero-point
layer.weight_zp = marlin_make_empty_g_idx(device) layer.weight_zp = marlin_make_empty_g_idx(device)
...@@ -159,9 +178,11 @@ class CompressedTensorsWNA16(CompressedTensorsScheme): ...@@ -159,9 +178,11 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
replace_tensor(layer, "weight_packed", marlin_qweight) replace_tensor(layer, "weight_packed", marlin_qweight)
# Permute scales from compressed-tensors format to marlin format. # Permute scales from compressed-tensors format to marlin format.
# scale is required on all partitions if activation reordering
marlin_scales = marlin_permute_scales( marlin_scales = marlin_permute_scales(
layer.weight_scale, layer.weight_scale,
size_k=layer.input_size_per_partition, size_k=(layer.input_size
if self.has_g_idx else layer.input_size_per_partition),
size_n=layer.output_size_per_partition, size_n=layer.output_size_per_partition,
group_size=layer.group_size) group_size=layer.group_size)
replace_tensor(layer, "weight_scale", marlin_scales) replace_tensor(layer, "weight_scale", marlin_scales)
...@@ -174,7 +195,7 @@ class CompressedTensorsWNA16(CompressedTensorsScheme): ...@@ -174,7 +195,7 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
weight=layer.weight_packed, weight=layer.weight_packed,
weight_scale=layer.weight_scale, weight_scale=layer.weight_scale,
weight_zp=layer.weight_zp, weight_zp=layer.weight_zp,
g_idx=layer.g_idx, g_idx=layer.weight_g_idx,
g_idx_sort_indices=layer.g_idx_sort_indices, g_idx_sort_indices=layer.g_idx_sort_indices,
workspace=layer.workspace, workspace=layer.workspace,
wtype=self.quant_type, wtype=self.quant_type,
......
import re import re
from enum import Enum from enum import Enum
from typing import Any, Dict, Iterable, Optional from typing import Any, Dict, Iterable, Optional, Union
from pydantic import BaseModel, Field from pydantic import BaseModel, Field, field_validator
from torch.nn import Module from torch.nn import Module
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
...@@ -40,6 +40,19 @@ class QuantizationStrategy(str, Enum): ...@@ -40,6 +40,19 @@ class QuantizationStrategy(str, Enum):
TOKEN = "token" TOKEN = "token"
class ActivationOrdering(str, Enum):
"""
Enum storing strategies for activation ordering
Group: reorder groups and weight\n
Weight: only reorder weight, not groups. Slightly lower latency and
accuracy compared to group actorder\n
"""
GROUP = "group"
WEIGHT = "weight"
class QuantizationArgs(BaseModel): class QuantizationArgs(BaseModel):
""" """
User facing arguments used to define a quantization config User facing arguments used to define a quantization config
...@@ -58,6 +71,8 @@ class QuantizationArgs(BaseModel): ...@@ -58,6 +71,8 @@ class QuantizationArgs(BaseModel):
observed with every sample. Defaults to False for static observed with every sample. Defaults to False for static
quantization. Note that enabling dynamic quantization quantization. Note that enabling dynamic quantization
will change the default observer to a memoryless one will change the default observer to a memoryless one
:param actorder: whether to apply group quantization in decreasing order of
activation. Defaults to None for arbitrary ordering
""" """
num_bits: int = 8 num_bits: int = 8
...@@ -67,6 +82,7 @@ class QuantizationArgs(BaseModel): ...@@ -67,6 +82,7 @@ class QuantizationArgs(BaseModel):
strategy: Optional[QuantizationStrategy] = None strategy: Optional[QuantizationStrategy] = None
block_structure: Optional[str] = None block_structure: Optional[str] = None
dynamic: bool = False dynamic: bool = False
actorder: Union[ActivationOrdering, bool, None] = None
observer: str = Field( observer: str = Field(
default="minmax", default="minmax",
description=("The class to use to compute the quantization param - " description=("The class to use to compute the quantization param - "
...@@ -79,6 +95,16 @@ class QuantizationArgs(BaseModel): ...@@ -79,6 +95,16 @@ class QuantizationArgs(BaseModel):
"Observers constructor excluding quantization range or symmetry"), "Observers constructor excluding quantization range or symmetry"),
) )
@field_validator("actorder", mode="before")
def validate_actorder(cls, value) -> Optional[ActivationOrdering]:
if isinstance(value, bool):
return ActivationOrdering.GROUP if value else None
if isinstance(value, str):
return ActivationOrdering(value.lower())
return value
def is_activation_quantization_format(format: str) -> bool: def is_activation_quantization_format(format: str) -> bool:
_ACTIVATION_QUANTIZATION_FORMATS = [ _ACTIVATION_QUANTIZATION_FORMATS = [
......
from typing import Any, Dict, List, Optional from typing import Any, Callable, Dict, List, Optional, Union
import torch import torch
from torch.nn import Parameter from torch.nn import Parameter
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.layers.quantization.utils.marlin_utils import ( from vllm.model_executor.layers.quantization.utils.marlin_utils import (
apply_gptq_marlin_linear, check_marlin_supported, marlin_is_k_full, apply_gptq_marlin_linear, check_marlin_supported, marlin_is_k_full,
marlin_make_empty_g_idx, marlin_make_workspace, marlin_permute_scales, marlin_make_empty_g_idx, marlin_make_workspace, marlin_moe_permute_scales,
marlin_repeat_scales_on_all_ranks, marlin_sort_g_idx, replace_tensor, marlin_permute_scales, marlin_repeat_scales_on_all_ranks,
verify_marlin_supported, verify_marlin_supports_shape) marlin_sort_g_idx, replace_tensor, verify_marlin_supported,
verify_marlin_supports_shape)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.parameter import (ChannelQuantScaleParameter, from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
GroupQuantScaleParameter, GroupQuantScaleParameter,
...@@ -33,8 +37,14 @@ class GPTQMarlinConfig(QuantizationConfig): ...@@ -33,8 +37,14 @@ class GPTQMarlinConfig(QuantizationConfig):
(8, True): scalar_types.uint8b128, (8, True): scalar_types.uint8b128,
} }
def __init__(self, weight_bits: int, group_size: int, desc_act: bool, def __init__(
is_sym: bool, lm_head_quantized: bool) -> None: self,
weight_bits: int,
group_size: int,
desc_act: bool,
is_sym: bool,
lm_head_quantized: bool,
) -> None:
if desc_act and group_size == -1: if desc_act and group_size == -1:
# In this case, act_order == True is the same as act_order == False # In this case, act_order == True is the same as act_order == False
# (since we have only one group per output channel) # (since we have only one group per output channel)
...@@ -51,10 +61,6 @@ class GPTQMarlinConfig(QuantizationConfig): ...@@ -51,10 +61,6 @@ class GPTQMarlinConfig(QuantizationConfig):
self.quant_type = self.TYPE_MAP[(weight_bits, is_sym)] self.quant_type = self.TYPE_MAP[(weight_bits, is_sym)]
# Verify supported on platform.
verify_marlin_supported(quant_type=self.quant_type,
group_size=self.group_size)
def __repr__(self) -> str: def __repr__(self) -> str:
return (f"GPTQMarlinConfig(quant_type={self.quant_type}, " return (f"GPTQMarlinConfig(quant_type={self.quant_type}, "
f"group_size={self.group_size}, " f"group_size={self.group_size}, "
...@@ -109,11 +115,14 @@ class GPTQMarlinConfig(QuantizationConfig): ...@@ -109,11 +115,14 @@ class GPTQMarlinConfig(QuantizationConfig):
" faster inference") " faster inference")
return None return None
def get_quant_method(self, layer: torch.nn.Module, def get_quant_method(
prefix: str) -> Optional["GPTQMarlinLinearMethod"]: self, layer: torch.nn.Module, prefix: str
if (isinstance(layer, LinearBase) or ) -> Optional[Union["GPTQMarlinLinearMethod", "GPTQMarlinMoEMethod"]]:
(isinstance(layer, ParallelLMHead) and self.lm_head_quantized)): if isinstance(layer, LinearBase) or (isinstance(layer, ParallelLMHead)
and self.lm_head_quantized):
return GPTQMarlinLinearMethod(self) return GPTQMarlinLinearMethod(self)
elif isinstance(layer, FusedMoE):
return GPTQMarlinMoEMethod(self)
return None return None
def get_scaled_act_names(self) -> List[str]: def get_scaled_act_names(self) -> List[str]:
...@@ -153,6 +162,10 @@ class GPTQMarlinLinearMethod(LinearMethodBase): ...@@ -153,6 +162,10 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
def __init__(self, quant_config: GPTQMarlinConfig) -> None: def __init__(self, quant_config: GPTQMarlinConfig) -> None:
self.quant_config = quant_config self.quant_config = quant_config
# Verify supported on platform.
verify_marlin_supported(quant_type=self.quant_config.quant_type,
group_size=self.quant_config.group_size)
def create_weights( def create_weights(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
...@@ -179,7 +192,8 @@ class GPTQMarlinLinearMethod(LinearMethodBase): ...@@ -179,7 +192,8 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
output_size_per_partition=output_size_per_partition, output_size_per_partition=output_size_per_partition,
input_size_per_partition=input_size_per_partition, input_size_per_partition=input_size_per_partition,
input_size=input_size, input_size=input_size,
group_size=group_size) group_size=group_size,
)
# Determine sharding # Determine sharding
if marlin_repeat_scales_on_all_ranks(self.quant_config.desc_act, if marlin_repeat_scales_on_all_ranks(self.quant_config.desc_act,
...@@ -299,7 +313,8 @@ class GPTQMarlinLinearMethod(LinearMethodBase): ...@@ -299,7 +313,8 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
perm=layer.g_idx_sort_indices, perm=layer.g_idx_sort_indices,
size_k=layer.input_size_per_partition, size_k=layer.input_size_per_partition,
size_n=layer.output_size_per_partition, size_n=layer.output_size_per_partition,
num_bits=self.quant_config.quant_type.size_bits) num_bits=self.quant_config.quant_type.size_bits,
)
replace_tensor(layer, "qweight", marlin_qweight) replace_tensor(layer, "qweight", marlin_qweight)
# Permute scales from autogptq format to marlin format. # Permute scales from autogptq format to marlin format.
...@@ -308,7 +323,8 @@ class GPTQMarlinLinearMethod(LinearMethodBase): ...@@ -308,7 +323,8 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
size_k=(layer.input_size if self.quant_config.desc_act else size_k=(layer.input_size if self.quant_config.desc_act else
layer.input_size_per_partition), layer.input_size_per_partition),
size_n=layer.output_size_per_partition, size_n=layer.output_size_per_partition,
group_size=self.quant_config.group_size) group_size=self.quant_config.group_size,
)
replace_tensor(layer, "scales", marlin_scales) replace_tensor(layer, "scales", marlin_scales)
def apply( def apply(
...@@ -329,4 +345,270 @@ class GPTQMarlinLinearMethod(LinearMethodBase): ...@@ -329,4 +345,270 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
output_size_per_partition=layer.output_size_per_partition, output_size_per_partition=layer.output_size_per_partition,
input_size_per_partition=layer.input_size_per_partition, input_size_per_partition=layer.input_size_per_partition,
is_k_full=layer.is_k_full, is_k_full=layer.is_k_full,
bias=bias) bias=bias,
)
class GPTQMarlinMoEMethod(FusedMoEMethodBase):
"""MoE Marlin method with quantization."""
def __init__(self, quant_config: GPTQMarlinConfig) -> None:
self.quant_config = quant_config
def create_weights(
self,
layer: torch.nn.Module,
num_experts: int,
hidden_size: int,
intermediate_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
# Currently assuming is_k_full is always True
# (input size per partition is the same as full input size)
# Supports only sym for now (no zp)
if self.quant_config.group_size != -1:
scales_size13 = hidden_size // self.quant_config.group_size
scales_size2 = intermediate_size // self.quant_config.group_size
strategy = FusedMoeWeightScaleSupported.GROUP.value
else:
scales_size13 = 1
scales_size2 = 1
strategy = FusedMoeWeightScaleSupported.CHANNEL.value
extra_weight_attrs.update({
"quant_method": strategy,
"is_transposed": True
})
# Fused gate_up_proj (column parallel)
w13_qweight = torch.nn.Parameter(
torch.empty(
num_experts,
hidden_size // self.quant_config.pack_factor,
2 * intermediate_size,
dtype=torch.int32,
),
requires_grad=False,
)
layer.register_parameter("w13_qweight", w13_qweight)
set_weight_attrs(w13_qweight, extra_weight_attrs)
# down_proj (row parallel)
w2_qweight = torch.nn.Parameter(
torch.empty(
num_experts,
intermediate_size // self.quant_config.pack_factor,
hidden_size,
dtype=torch.int32,
),
requires_grad=False,
)
layer.register_parameter("w2_qweight", w2_qweight)
set_weight_attrs(w2_qweight, extra_weight_attrs)
# up_proj scales
w13_scales = torch.nn.Parameter(
torch.empty(num_experts,
scales_size13,
2 * intermediate_size,
dtype=torch.half),
requires_grad=False,
)
layer.register_parameter("w13_scales", w13_scales)
set_weight_attrs(w13_scales, extra_weight_attrs)
# down_proj scales
w2_scales = torch.nn.Parameter(
torch.empty(num_experts,
scales_size2,
hidden_size,
dtype=torch.half),
requires_grad=False,
)
layer.register_parameter("w2_scales", w2_scales)
set_weight_attrs(w2_scales, extra_weight_attrs)
# up_proj scales
w13_qzeros = torch.nn.Parameter(
torch.empty(num_experts,
scales_size13,
2 * intermediate_size // self.quant_config.pack_factor,
dtype=params_dtype),
requires_grad=False,
)
layer.register_parameter("w13_qzeros", w13_qzeros)
set_weight_attrs(w13_qzeros, extra_weight_attrs)
# down_proj scales
w2_qzeros = torch.nn.Parameter(
torch.empty(num_experts,
scales_size2,
hidden_size // self.quant_config.pack_factor,
dtype=params_dtype),
requires_grad=False,
)
layer.register_parameter("w2_qzeros", w2_qzeros)
set_weight_attrs(w2_qzeros, extra_weight_attrs)
w13_g_idx = torch.nn.Parameter(
torch.empty(
num_experts,
hidden_size,
dtype=torch.int32,
),
requires_grad=False,
)
layer.register_parameter("w13_g_idx", w13_g_idx)
set_weight_attrs(w13_g_idx, extra_weight_attrs)
w2_g_idx = torch.nn.Parameter(
torch.empty(
num_experts,
intermediate_size,
dtype=torch.int32,
),
requires_grad=False,
)
layer.register_parameter("w2_g_idx", w2_g_idx)
set_weight_attrs(w2_g_idx, extra_weight_attrs)
w13_g_idx_sort_indices = torch.nn.Parameter(
torch.empty(
num_experts,
hidden_size,
dtype=torch.int32,
),
requires_grad=False,
)
layer.register_parameter("w13_g_idx_sort_indices",
w13_g_idx_sort_indices)
set_weight_attrs(w13_g_idx_sort_indices, extra_weight_attrs)
w2_g_idx_sort_indices = torch.nn.Parameter(
torch.empty(
num_experts,
intermediate_size,
dtype=torch.int32,
),
requires_grad=False,
)
layer.register_parameter("w2_g_idx_sort_indices",
w2_g_idx_sort_indices)
set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# Process act_order
if self.quant_config.desc_act:
# Get sorting based on g_idx
num_experts = layer.w13_g_idx.shape[0]
w13_g_idx_sort_indices = torch.empty_like(layer.w13_g_idx)
w2_g_idx_sort_indices = torch.empty_like(layer.w2_g_idx)
w13_sorted_g_idx = torch.empty_like(layer.w13_g_idx)
w2_sorted_g_idx = torch.empty_like(layer.w2_g_idx)
for e in range(num_experts):
w13_g_idx_sort_indices[e] = torch.argsort(
layer.w13_g_idx[e]).to(torch.int32)
w2_g_idx_sort_indices[e] = torch.argsort(layer.w2_g_idx[e]).to(
torch.int32)
w13_sorted_g_idx[e] = layer.w13_g_idx[e][
w13_g_idx_sort_indices[e]]
w2_sorted_g_idx[e] = layer.w2_g_idx[e][
w2_g_idx_sort_indices[e]]
replace_tensor(layer, "w13_g_idx", w13_sorted_g_idx)
replace_tensor(layer, "w2_g_idx", w2_sorted_g_idx)
replace_tensor(layer, "w13_g_idx_sort_indices",
w13_g_idx_sort_indices)
replace_tensor(layer, "w2_g_idx_sort_indices",
w2_g_idx_sort_indices)
else:
# Reset g_idx related tensors
num_experts = layer.w13_g_idx.shape[0]
device = layer.w13_g_idx.device
layer.w13_g_idx = torch.nn.Parameter(
torch.empty((num_experts, 0), dtype=torch.int32,
device=device),
requires_grad=False,
)
layer.w2_g_idx = torch.nn.Parameter(
torch.empty((num_experts, 0), dtype=torch.int32,
device=device),
requires_grad=False,
)
layer.w13_g_idx_sort_indices = torch.nn.Parameter(
torch.empty((num_experts, 0), dtype=torch.int32,
device=device),
requires_grad=False,
)
layer.w2_g_idx_sort_indices = torch.nn.Parameter(
torch.empty((num_experts, 0), dtype=torch.int32,
device=device),
requires_grad=False,
)
# Repack weights
marlin_w13_qweight = ops.gptq_marlin_moe_repack(
layer.w13_qweight,
layer.w13_g_idx_sort_indices,
layer.w13_qweight.shape[1] * self.quant_config.pack_factor,
layer.w13_qweight.shape[2],
self.quant_config.quant_type.size_bits,
)
replace_tensor(layer, "w13_qweight", marlin_w13_qweight)
marlin_w2_qweight = ops.gptq_marlin_moe_repack(
layer.w2_qweight,
layer.w2_g_idx_sort_indices,
layer.w2_qweight.shape[1] * self.quant_config.pack_factor,
layer.w2_qweight.shape[2],
self.quant_config.quant_type.size_bits,
)
replace_tensor(layer, "w2_qweight", marlin_w2_qweight)
# Repack scales
marlin_w13_scales = marlin_moe_permute_scales(
s=layer.w13_scales,
size_k=layer.intermediate_size_per_partition,
size_n=layer.w13_scales.shape[2],
group_size=self.quant_config.group_size,
)
replace_tensor(layer, "w13_scales", marlin_w13_scales)
marlin_w2_scales = marlin_moe_permute_scales(
s=layer.w2_scales,
size_k=layer.w2_scales.shape[1] * self.quant_config.pack_factor,
size_n=layer.w2_scales.shape[2],
group_size=self.quant_config.group_size,
)
replace_tensor(layer, "w2_scales", marlin_w2_scales)
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool = True,
use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None,
topk_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
fused_marlin_moe)
# The input must currently be float16
orig_dtype = x.dtype
x = x.half()
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=None)
return fused_marlin_moe(
x,
layer.w13_qweight,
layer.w2_qweight,
router_logits,
layer.w13_g_idx,
layer.w2_g_idx,
layer.w13_g_idx_sort_indices,
layer.w2_g_idx_sort_indices,
topk_weights,
topk_ids,
w1_scale=layer.w13_scales,
w2_scale=layer.w2_scales,
).to(orig_dtype)
from typing import Any, Dict, List, Optional
import torch
from torch.nn import Module
from torch.nn.parameter import Parameter
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
apply_fp8_linear, cutlass_fp8_supported, requantize_with_max_scale)
from vllm.model_executor.parameter import (ModelWeightParameter,
PerTensorScaleParameter)
logger = init_logger(__name__)
ACTIVATION_SCHEMES = ["static"]
class ModelOptFp8Config(QuantizationConfig):
"""Config class for ModelOpt FP8."""
def __init__(
self,
is_checkpoint_fp8_serialized: bool = False,
) -> None:
self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
if is_checkpoint_fp8_serialized:
logger.warning("Detected ModelOpt fp8 checkpoint. Please note that"
" the format is experimental and could change.")
@classmethod
def get_name(cls) -> str:
return "modelopt"
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
return [torch.bfloat16, torch.half]
@classmethod
def get_min_capability(cls) -> int:
return 89
@classmethod
def get_config_filenames(cls) -> List[str]:
return ["hf_quant_config.json"]
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "ModelOptFp8Config":
quant_config = cls.get_from_keys(config, ["quantization"])
quant_method = quant_config["quant_algo"]
is_checkpoint_fp8_serialized = ("FP8" in quant_method)
if not is_checkpoint_fp8_serialized:
raise ValueError("ModelOpt currently only supports static FP8"
"quantization in vLLM. Please check the "
"`hf_quant_config.json` file for your model's "
"quant configuration.")
return cls(is_checkpoint_fp8_serialized)
def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["QuantizeMethodBase"]:
from vllm.attention.layer import Attention # Avoid circular import
if isinstance(layer, LinearBase):
return ModelOptFp8LinearMethod(self)
elif isinstance(layer, Attention):
return ModelOptFp8KVCacheMethod(self)
return None
def get_scaled_act_names(self) -> List[str]:
return []
class ModelOptFp8KVCacheMethod(BaseKVCacheMethod):
"""
Supports loading kv-cache scaling factors from FP8 checkpoints.
"""
def __init__(self, quant_config: ModelOptFp8Config):
super().__init__(quant_config)
class ModelOptFp8LinearMethod(LinearMethodBase):
"""Linear method for Model Optimizer static quantization.
Supports loading FP8 checkpoints with static weight scale and
activation scale. Future support might be added for dynamic
scales.
Limitations:
1. Only support per-tensor quantization due to torch._scaled_mm support.
2. Only support float8_e4m3fn datatype
Args: quant_config: The ModelOpt quantization config.
"""
def __init__(self, quant_config: ModelOptFp8Config):
self.quant_config = quant_config
self.cutlass_fp8_supported = cutlass_fp8_supported()
def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
del input_size, output_size
output_size_per_partition = sum(output_partition_sizes)
weight_loader = extra_weight_attrs.get("weight_loader")
layer.logical_widths = output_partition_sizes
layer.input_size_per_partition = input_size_per_partition
layer.output_size_per_partition = output_size_per_partition
weight_dtype = (torch.float8_e4m3fn
if self.quant_config.is_checkpoint_fp8_serialized else
params_dtype)
weight = ModelWeightParameter(data=torch.empty(
output_size_per_partition,
input_size_per_partition,
dtype=weight_dtype),
input_dim=1,
output_dim=0,
weight_loader=weight_loader)
layer.register_parameter("weight", weight)
if self.quant_config.is_checkpoint_fp8_serialized:
# WEIGHT SCALE
weight_scale = PerTensorScaleParameter(data=torch.empty(
len(output_partition_sizes), dtype=torch.float32),
weight_loader=weight_loader)
weight_scale[:] = torch.finfo(torch.float32).min
layer.register_parameter("weight_scale", weight_scale)
# INPUT SCALE
scale = PerTensorScaleParameter(data=torch.empty(
len(output_partition_sizes), dtype=torch.float32),
weight_loader=weight_loader)
scale[:] = torch.finfo(torch.float32).min
layer.register_parameter("input_scale", scale)
def process_weights_after_loading(self, layer: Module) -> None:
max_w_scale, weight = requantize_with_max_scale(
layer.weight, layer.weight_scale, layer.logical_widths)
layer.weight = Parameter(weight.t(), requires_grad=False)
layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
layer.input_scale = Parameter(layer.input_scale.max(),
requires_grad=False)
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return apply_fp8_linear(
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
input_scale=layer.input_scale,
bias=bias,
cutlass_fp8_supported=self.cutlass_fp8_supported)
from typing import Any, Dict, List, Optional
import torch
from torch.nn.parameter import Parameter
from vllm import _custom_ops as ops
from vllm.model_executor.layers.linear import LinearBase
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.utils import set_weight_attrs
from vllm.utils import is_hip
class SqueezeLLMConfig(QuantizationConfig):
"""Config class for SqueezeLLM.
Reference: https://arxiv.org/pdf/2306.07629
"""
def __init__(
self,
weight_bits: int,
) -> None:
self.weight_bits = weight_bits
if self.weight_bits != 4:
raise ValueError(
"Currently, only 4-bit weight quantization is supported for "
f"SqueezeLLM, but got {self.weight_bits} bits.")
self.pack_factor = 32 // self.weight_bits
def __repr__(self) -> str:
return f"SqueezeLLMConfig(weight_bits={self.weight_bits})"
def get_name(self) -> str:
return "squeezellm"
def get_supported_act_dtypes(self) -> List[torch.dtype]:
return [torch.half]
@classmethod
def get_min_capability(cls) -> int:
return 70
@staticmethod
def get_config_filenames() -> List[str]:
return ["quant_config.json"]
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "SqueezeLLMConfig":
weight_bits = cls.get_from_keys(config, ["wbits"])
return cls(weight_bits)
def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional[QuantizeMethodBase]:
if isinstance(layer, LinearBase):
return SqueezeLLMLinearMethod(self)
return None
def get_scaled_act_names(self) -> List[str]:
return []
class SqueezeLLMLinearMethod(QuantizeMethodBase):
"""Linear method for SqueezeLLM.
Args:
quant_config: The SqueezeLLM quantization config.
"""
def __init__(self, quant_config: SqueezeLLMConfig):
self.quant_config = quant_config
def create_weights(self, layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int], input_size: int,
output_size: int, params_dtype: torch.dtype,
**extra_weight_attrs):
if input_size_per_partition % self.quant_config.pack_factor != 0:
raise ValueError(
"The input size is not aligned with the quantized "
"weight shape. This can be caused by too large "
"tensor parallel size.")
output_size_per_partition = sum(output_partition_sizes)
qweight = Parameter(
torch.empty(
input_size_per_partition // self.quant_config.pack_factor,
output_size_per_partition,
dtype=torch.int32,
),
requires_grad=False,
)
set_weight_attrs(
qweight, {
"input_dim": 0,
"output_dim": 1,
"packed_dim": 0,
"pack_factor": self.quant_config.pack_factor,
})
lookup_table = Parameter(
torch.empty(
output_size,
self.quant_config.weight_bits**2,
dtype=params_dtype,
),
requires_grad=False,
)
set_weight_attrs(lookup_table, {
"output_dim": 0,
})
layer.register_parameter("qweight", qweight)
set_weight_attrs(qweight, extra_weight_attrs)
layer.register_parameter("lookup_table", lookup_table)
set_weight_attrs(lookup_table, extra_weight_attrs)
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
qweight = layer.qweight
lookup_table = layer.lookup_table
out_shape = x.shape[:-1] + (qweight.shape[-1], )
reshaped_x = x.reshape(-1, x.shape[-1])
if is_hip():
out_f = torch.zeros(out_shape, dtype=torch.float)
ops.squeezellm_gemm(reshaped_x, qweight, out_f, lookup_table)
out = out_f.to(dtype=torch.float16)
else:
# NOTE: The output tensor should be zero-initialized.
out = torch.zeros(out_shape, dtype=torch.float16)
ops.squeezellm_gemm(reshaped_x, qweight, out, lookup_table)
if bias is not None:
out.add_(bias)
return out.reshape(out_shape)
...@@ -176,6 +176,23 @@ def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int, ...@@ -176,6 +176,23 @@ def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int,
return s return s
def marlin_moe_permute_scales(
s: torch.Tensor,
size_k: int,
size_n: int,
group_size: int,
):
num_experts = s.shape[0]
output = torch.empty(
(num_experts, s.shape[1], s.shape[2]),
device=s.device,
dtype=s.dtype,
)
for e in range(num_experts):
output[e] = marlin_permute_scales(s[e], size_k, size_n, group_size)
return output
def marlin_zero_points(zp: torch.Tensor, size_k: int, size_n: int, def marlin_zero_points(zp: torch.Tensor, size_k: int, size_n: int,
num_bits: int) -> torch.Tensor: num_bits: int) -> torch.Tensor:
# Permute zero-points in a similar way to scales, but do not use the # Permute zero-points in a similar way to scales, but do not use the
......
"""Utility functions used for tests and benchmarks""" """Utility functions used for tests and benchmarks"""
from typing import List from typing import List, Optional
import numpy as np import numpy as np
import torch import torch
...@@ -92,8 +92,11 @@ def get_weight_perm(num_bits: int): ...@@ -92,8 +92,11 @@ def get_weight_perm(num_bits: int):
return perm return perm
def marlin_quantize(w: torch.Tensor, quant_type: ScalarType, group_size: int, def marlin_quantize(w: torch.Tensor,
act_order: bool): quant_type: ScalarType,
group_size: int,
act_order: bool,
test_perm: Optional[torch.Tensor] = None):
size_k, size_n = w.shape size_k, size_n = w.shape
num_bits = quant_type.size_bits num_bits = quant_type.size_bits
...@@ -104,7 +107,7 @@ def marlin_quantize(w: torch.Tensor, quant_type: ScalarType, group_size: int, ...@@ -104,7 +107,7 @@ def marlin_quantize(w: torch.Tensor, quant_type: ScalarType, group_size: int,
# Quantize (and apply act_order if provided) # Quantize (and apply act_order if provided)
w_ref, q_w, s, g_idx, rand_perm = gptq_quantize_weights( w_ref, q_w, s, g_idx, rand_perm = gptq_quantize_weights(
w, quant_type, group_size, act_order) w, quant_type, group_size, act_order, test_perm)
# For act_order, sort the "weights" and "g_idx" so that group ids are # For act_order, sort the "weights" and "g_idx" so that group ids are
# increasing # increasing
......
"""This file is used for /tests and /benchmarks""" """This file is used for /tests and /benchmarks"""
from typing import List from typing import List, Optional
import numpy import numpy
import torch import torch
...@@ -53,7 +53,10 @@ def get_pack_factor(num_bits): ...@@ -53,7 +53,10 @@ def get_pack_factor(num_bits):
return 32 // num_bits return 32 // num_bits
def permute_rows(q_w: torch.Tensor, w_ref: torch.Tensor, group_size: int): def permute_rows(q_w: torch.Tensor,
w_ref: torch.Tensor,
group_size: int,
test_perm: Optional[torch.Tensor] = None):
assert q_w.shape == w_ref.shape assert q_w.shape == w_ref.shape
orig_device = q_w.device orig_device = q_w.device
...@@ -64,7 +67,7 @@ def permute_rows(q_w: torch.Tensor, w_ref: torch.Tensor, group_size: int): ...@@ -64,7 +67,7 @@ def permute_rows(q_w: torch.Tensor, w_ref: torch.Tensor, group_size: int):
g_idx[i] = i // group_size g_idx[i] = i // group_size
# Simulate act_order by doing a random permutation on K # Simulate act_order by doing a random permutation on K
rand_perm = torch.randperm(k_size) rand_perm = test_perm if test_perm is not None else torch.randperm(k_size)
g_idx = g_idx[rand_perm].contiguous() g_idx = g_idx[rand_perm].contiguous()
q_w = q_w[rand_perm, :].contiguous() q_w = q_w[rand_perm, :].contiguous()
...@@ -164,8 +167,11 @@ def quantize_weights(w: torch.Tensor, ...@@ -164,8 +167,11 @@ def quantize_weights(w: torch.Tensor,
) )
def gptq_quantize_weights(w: torch.Tensor, quant_type: ScalarType, def gptq_quantize_weights(w: torch.Tensor,
group_size: int, act_order: bool): quant_type: ScalarType,
group_size: int,
act_order: bool,
test_perm: Optional[torch.Tensor] = None):
size_k, _ = w.shape size_k, _ = w.shape
assert w.is_floating_point(), "w must be float" assert w.is_floating_point(), "w must be float"
...@@ -186,7 +192,8 @@ def gptq_quantize_weights(w: torch.Tensor, quant_type: ScalarType, ...@@ -186,7 +192,8 @@ def gptq_quantize_weights(w: torch.Tensor, quant_type: ScalarType,
), "For act_order, groupsize = {} must be less than size_k = {}".format( ), "For act_order, groupsize = {} must be less than size_k = {}".format(
group_size, size_k) group_size, size_k)
w_ref, w_q, g_idx, rand_perm = permute_rows(w_q, w_ref, group_size) w_ref, w_q, g_idx, rand_perm = permute_rows(w_q, w_ref, group_size,
test_perm)
return w_ref, w_q, w_s, g_idx, rand_perm return w_ref, w_q, w_s, g_idx, rand_perm
......
# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# https://huggingface.co/Qwen/Qwen-7B/blob/main/modeling_qwen.py
# https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20
#
# Copyright 2023 The Qwen team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Shared resampler perceiver network used in multimodal models and
related helpers for sincos positional embeddings.
Example models: Qwen (Qwen-VL), Minicpmv2.0
"""
import math
from functools import partial
from typing import Callable, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.init import trunc_normal_
from vllm.model_executor.layers.linear import ReplicatedLinear
DEFAULT_LN = partial(nn.LayerNorm, eps=1e-6)
def get_abs_pos(abs_pos: torch.Tensor, tgt_size: Union[torch.Tensor,
int]) -> torch.Tensor:
# abs_pos: L, C
# tgt_size: (H, W)
# return: M, C
src_size = int(math.sqrt(abs_pos.size(0)))
dtype = abs_pos.dtype
if isinstance(tgt_size, int):
tgt_size = (tgt_size, tgt_size)
if (src_size == tgt_size[0] and src_size == tgt_size[1]):
return abs_pos
return (F.interpolate(
abs_pos.float().reshape(1, src_size, src_size, -1).permute(0, 3, 1, 2),
size=(tgt_size[0], tgt_size[1]),
mode="bicubic",
align_corners=False,
).permute(0, 2, 3, 1).flatten(0, 2).to(dtype=dtype))
# sin/cos positional embedding helpers are adapted from:
# https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20
def get_1d_sincos_pos_embed_from_grid(
embed_dim: int, pos: np.ndarray,
version: Tuple[int, int] = (2, 0)) -> torch.Tensor:
"""
embed_dim: output dimension for each position
pos: a list of positions to be encoded: size (M,) / (H, W)
out: (M, D) / (H, W, D)
"""
assert embed_dim % 2 == 0
omega = np.arange(embed_dim // 2, dtype=np.float32)
omega /= embed_dim / 2.0
omega = 1.0 / 10000**omega # (D/2,)
if version == (2, 0):
pos = pos.reshape(-1) # (M,)
out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
emb_sin = np.sin(out) # (M, D/2)
emb_cos = np.cos(out) # (M, D/2)
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
else:
out = np.einsum("hw,d->hwd", pos, omega) # (H, W, D/2), outer product
emb_sin = np.sin(out) # (H, W, D/2)
emb_cos = np.cos(out) # (H, W, D/2)
emb = np.concatenate([emb_sin, emb_cos], axis=-1) # (H, W, D)
return emb
def get_2d_sincos_pos_embed_from_grid(
embed_dim: int, grid: np.ndarray,
version: Tuple[int, int] = (2, 0)) -> torch.Tensor:
assert embed_dim % 2 == 0
# use half of dimensions to encode grid_h
emb_h = get_1d_sincos_pos_embed_from_grid(
embed_dim // 2, grid[0], version) # (H*W, D/2) or (H, W, D/2)
emb_w = get_1d_sincos_pos_embed_from_grid(
embed_dim // 2, grid[1], version) # (H*W, D/2) or (H, W, D/2)
if version == (2, 0):
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
else:
emb = np.concatenate([emb_h, emb_w], axis=-1) # (H, W, D)
return emb
def get_2d_sincos_pos_embed(
embed_dim: int,
grid_size: Union[int, Tuple[int, int]],
cls_token: bool = False,
version: Tuple[int, int] = (2, 0),
) -> torch.Tensor:
"""
grid_size: int of the grid height and width
return:
pos_embed: [grid_size*grid_size, embed_dim] or
[1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
"""
if isinstance(grid_size, int):
grid_h_size, grid_w_size = grid_size, grid_size
else:
grid_h_size, grid_w_size = grid_size[0], grid_size[1]
grid_h = np.arange(grid_h_size, dtype=np.float32)
grid_w = np.arange(grid_w_size, dtype=np.float32)
grid = np.meshgrid(grid_w, grid_h) # here w goes first
grid = np.stack(grid, axis=0)
assert isinstance(grid, np.ndarray) and \
grid.shape == (2, grid_h_size, grid_w_size)
if version == (2, 0):
grid = grid.reshape([2, 1, grid_h_size, grid_w_size])
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid, version)
if cls_token:
pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed],
axis=0)
else:
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid, version)
return pos_embed
class BaseResampler(nn.Module):
"""
A 2D perceiver-resampler network with one cross attention layers by
(grid_size**2) learnable queries and 2d sincos pos_emb.
Outputs:
A tensor with the shape of (grid_size**2, embed_dim)
"""
def __init__(
self,
num_queries: int,
embed_dim: int,
num_heads: int,
kv_dim: Optional[int] = None,
norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN,
do_post_projection: bool = True,
) -> None:
super().__init__()
self.num_queries = num_queries
self.embed_dim = embed_dim
self.num_heads = num_heads
self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim))
trunc_normal_(self.query, std=0.02)
if kv_dim is not None and kv_dim != embed_dim:
self.kv_proj = ReplicatedLinear(kv_dim, embed_dim, bias=False)
else:
# Maintain the same return value with ReplicatedLinear.forward
self.kv_proj = lambda *args, **kwargs: ( # type: ignore # noqa
nn.Identity()(*args, **kwargs),
None,
)
self.attn = nn.MultiheadAttention(embed_dim, num_heads)
self.ln_q = norm_layer(embed_dim)
self.ln_kv = norm_layer(embed_dim)
self.do_post_projection = do_post_projection
self.ln_post = norm_layer(embed_dim) if do_post_projection else None
self.proj = nn.Parameter(
(embed_dim**-0.5) *
torch.randn(embed_dim, embed_dim)) if do_post_projection else None
def _init_weights(self, m: nn.Module) -> None:
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=0.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def _repeat(self, query, N: int):
return query.unsqueeze(1).repeat(1, N, 1)
class Resampler2(BaseResampler):
"""Resampler-perceiver network to be used for a variety of model types,
e.g., Qwen-vl / Minicpmv 2.0. The main difference is the addition of the
do_post_projection arg, which indicates whether or not there should be
a post layer normalization and projector after the attention. This is
present in minicpmv2.0, but not qwen-vl.
"""
def __init__(
self,
grid_size: int,
embed_dim: int,
num_heads: int,
kv_dim: Optional[int] = None,
norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN,
adaptive: bool = False,
do_post_projection: bool = True,
) -> None:
super().__init__(grid_size**2,
embed_dim,
num_heads,
kv_dim,
norm_layer,
do_post_projection=do_post_projection)
self.adaptive = adaptive
pos_embed_arr = get_2d_sincos_pos_embed(embed_dim,
grid_size,
version=(2, 0))
self.pos_embed = nn.Parameter(
torch.from_numpy(pos_embed_arr).requires_grad_(False))
self.apply(self._init_weights)
def forward(
self,
x: torch.Tensor,
tgt_sizes: Optional[torch.Tensor] = None,
attn_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if tgt_sizes is None:
tgt_sizes = int(math.sqrt(x.size(1)))
if self.adaptive:
pos_embed_arr = get_2d_sincos_pos_embed(self.embed_dim,
tgt_sizes,
version=(2, 0))
pos_embed = torch.from_numpy(pos_embed_arr).to(device=x.device,
dtype=x.dtype)
else:
pos_embed = get_abs_pos(self.pos_embed,
tgt_sizes).to(device=x.device,
dtype=x.dtype)
x, _ = self.kv_proj(x)
x = self.ln_kv(x).permute(1, 0, 2)
N = x.shape[1]
q = self.ln_q(self.query)
out = self.attn(
self._repeat(q, N) + self.pos_embed.unsqueeze(1),
x + pos_embed.unsqueeze(1),
x,
attn_mask=attn_mask,
)[0]
x = out.permute(1, 0, 2)
if self.do_post_projection:
x = self.ln_post(x)
x = x @ self.proj
return x
...@@ -28,7 +28,6 @@ import torch ...@@ -28,7 +28,6 @@ import torch
import torch.nn as nn import torch.nn as nn
from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.custom_op import CustomOp
from vllm.platforms import current_platform
def _rotate_neox(x: torch.Tensor) -> torch.Tensor: def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
...@@ -48,21 +47,29 @@ def _apply_rotary_emb( ...@@ -48,21 +47,29 @@ def _apply_rotary_emb(
x: torch.Tensor, x: torch.Tensor,
cos: torch.Tensor, cos: torch.Tensor,
sin: torch.Tensor, sin: torch.Tensor,
is_neox_style: bool,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Args: Args:
x: [num_tokens, num_heads, head_size] x: [num_tokens, num_heads, head_size]
cos: [num_tokens, head_size // 2] cos: [num_tokens, head_size // 2]
sin: [num_tokens, head_size // 2] sin: [num_tokens, head_size // 2]
is_neox_style: Whether to use the Neox-style or GPT-J-style rotary
positional embeddings.
""" """
orig_dtype = x.dtype cos = cos.unsqueeze(-2).to(x.dtype)
x = x.float() sin = sin.unsqueeze(-2).to(x.dtype)
x1, x2 = torch.chunk(x, 2, dim=-1) if is_neox_style:
cos = cos.unsqueeze(-2) x1, x2 = torch.chunk(x, 2, dim=-1)
sin = sin.unsqueeze(-2) else:
x1 = x[..., ::2]
x2 = x[..., 1::2]
o1 = x1 * cos - x2 * sin o1 = x1 * cos - x2 * sin
o2 = x2 * cos + x1 * sin o2 = x2 * cos + x1 * sin
return torch.cat((o1, o2), dim=-1).to(orig_dtype) if is_neox_style:
return torch.cat((o1, o2), dim=-1)
else:
return torch.stack((o1, o2), dim=-1).flatten(-2)
class RotaryEmbedding(CustomOp): class RotaryEmbedding(CustomOp):
...@@ -87,10 +94,9 @@ class RotaryEmbedding(CustomOp): ...@@ -87,10 +94,9 @@ class RotaryEmbedding(CustomOp):
cache = self._compute_cos_sin_cache() cache = self._compute_cos_sin_cache()
cache = cache.to(dtype) cache = cache.to(dtype)
self.cos_sin_cache: torch.Tensor
self.register_buffer("cos_sin_cache", cache, persistent=False) self.register_buffer("cos_sin_cache", cache, persistent=False)
self.use_native2 = current_platform.is_tpu() and is_neox_style
def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
"""Compute the inverse frequency.""" """Compute the inverse frequency."""
# NOTE(woosuk): To exactly match the HF implementation, we need to # NOTE(woosuk): To exactly match the HF implementation, we need to
...@@ -119,59 +125,7 @@ class RotaryEmbedding(CustomOp): ...@@ -119,59 +125,7 @@ class RotaryEmbedding(CustomOp):
key: torch.Tensor, key: torch.Tensor,
offsets: Optional[torch.Tensor] = None, offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
"""A PyTorch-native implementation equivalent to forward(). """A PyTorch-native implementation of forward()."""
This method mimics the implementation of the custom CUDA kernel
used in `forward_cuda()`.
"""
query = query.view(*query.shape[:-1], -1, self.head_size)
key = key.view(*key.shape[:-1], -1, self.head_size)
query_rot = query[..., :self.rotary_dim]
key_rot = key[..., :self.rotary_dim]
if self.rotary_dim < self.head_size:
query_pass = query[..., self.rotary_dim:]
key_pass = key[..., self.rotary_dim:]
self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(
positions.device, dtype=query.dtype)
cos_sin = self.cos_sin_cache[torch.add(positions, offsets)
if offsets is not None else positions]
cos, sin = cos_sin.chunk(2, dim=-1)
if self.is_neox_style:
# NOTE(woosuk): Here we assume that the positions tensor has the
# shape [batch_size, seq_len].
cos = cos.repeat(1, 1, 2).unsqueeze(-2)
sin = sin.repeat(1, 1, 2).unsqueeze(-2)
else:
cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj
query_rot = query_rot * cos + rotate_fn(query_rot) * sin
key_rot = key_rot * cos + rotate_fn(key_rot) * sin
if self.rotary_dim < self.head_size:
query = torch.cat((query_rot, query_pass), dim=-1)
key = torch.cat((key_rot, key_pass), dim=-1)
else:
query = query_rot
key = key_rot
query = query.flatten(-2)
key = key.flatten(-2)
return query, key
def forward_native2(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Another PyTorch-native implementation of forward().
This method might perform better than `forward_native()` when compiled.
"""
if offsets is not None: if offsets is not None:
positions = positions + offsets positions = positions + offsets
positions = positions.flatten() positions = positions.flatten()
...@@ -183,14 +137,14 @@ class RotaryEmbedding(CustomOp): ...@@ -183,14 +137,14 @@ class RotaryEmbedding(CustomOp):
query = query.view(num_tokens, -1, self.head_size) query = query.view(num_tokens, -1, self.head_size)
query_rot = query[..., :self.rotary_dim] query_rot = query[..., :self.rotary_dim]
query_pass = query[..., self.rotary_dim:] query_pass = query[..., self.rotary_dim:]
query_rot = _apply_rotary_emb(query_rot, cos, sin) query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style)
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
key_shape = key.shape key_shape = key.shape
key = key.view(num_tokens, -1, self.head_size) key = key.view(num_tokens, -1, self.head_size)
key_rot = key[..., :self.rotary_dim] key_rot = key[..., :self.rotary_dim]
key_pass = key[..., self.rotary_dim:] key_pass = key[..., self.rotary_dim:]
key_rot = _apply_rotary_emb(key_rot, cos, sin) key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style)
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
return query, key return query, key
...@@ -203,7 +157,7 @@ class RotaryEmbedding(CustomOp): ...@@ -203,7 +157,7 @@ class RotaryEmbedding(CustomOp):
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
self.cos_sin_cache = self.cos_sin_cache.to(positions.device, self.cos_sin_cache = self.cos_sin_cache.to(query.device,
dtype=query.dtype) dtype=query.dtype)
# ops.rotary_embedding()/batched_rotary_embedding() # ops.rotary_embedding()/batched_rotary_embedding()
# are in-place operations that update the query and key tensors. # are in-place operations that update the query and key tensors.
...@@ -240,17 +194,6 @@ class RotaryEmbedding(CustomOp): ...@@ -240,17 +194,6 @@ class RotaryEmbedding(CustomOp):
self.cos_sin_cache, self.is_neox_style) self.cos_sin_cache, self.is_neox_style)
return query, key return query, key
def forward_tpu(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
forward_fn = (self.forward_native2
if self.use_native2 else self.forward_native)
return forward_fn(positions, query, key, offsets)
def extra_repr(self) -> str: def extra_repr(self) -> str:
s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}" s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}"
s += f", max_position_embeddings={self.max_position_embeddings}" s += f", max_position_embeddings={self.max_position_embeddings}"
...@@ -769,6 +712,179 @@ class Llama3RotaryEmbedding(RotaryEmbedding): ...@@ -769,6 +712,179 @@ class Llama3RotaryEmbedding(RotaryEmbedding):
return new_freqs return new_freqs
class MRotaryEmbedding(RotaryEmbedding):
"""Rotary Embedding with Multimodal Sections."""
def __init__(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: int,
is_neox_style: bool,
dtype: torch.dtype,
mrope_section: Optional[List[int]] = None,
) -> None:
super().__init__(head_size, rotary_dim, max_position_embeddings, base,
is_neox_style, dtype)
self.mrope_section = mrope_section
if self.mrope_section:
assert sum(self.mrope_section) == rotary_dim // 2
def forward(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""PyTorch-native implementation equivalent to forward().
Args:
positions:
[num_tokens,] (text only) or
[3, num_tokens] (T/H/W positions with multimodal inputs)
query: [num_tokens, num_heads * head_size]
key: [num_tokens, num_kv_heads * head_size]
"""
assert positions.ndim == 1 or positions.ndim == 2
num_tokens = positions.shape[-1]
cos_sin = self.cos_sin_cache[positions]
cos, sin = cos_sin.chunk(2, dim=-1)
if positions.ndim == 2:
assert self.mrope_section
cos = torch.cat([
m[i]
for i, m in enumerate(cos.split(self.mrope_section, dim=-1))
],
dim=-1)
sin = torch.cat([
m[i]
for i, m in enumerate(sin.split(self.mrope_section, dim=-1))
],
dim=-1)
query_shape = query.shape
query = query.view(num_tokens, -1, self.head_size)
query_rot = query[..., :self.rotary_dim]
query_pass = query[..., self.rotary_dim:]
query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style)
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
key_shape = key.shape
key = key.view(num_tokens, -1, self.head_size)
key_rot = key[..., :self.rotary_dim]
key_pass = key[..., self.rotary_dim:]
key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style)
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
return query, key
@staticmethod
def get_input_positions(
input_tokens: List[int],
image_grid_thw: Union[List[List[int]], torch.Tensor],
video_grid_thw: Union[List[List[int]], torch.Tensor],
image_token_id: int,
video_token_id: int,
vision_start_token_id: int,
vision_end_token_id: int,
spatial_merge_size: int,
context_len: int = 0,
) -> Tuple[List[List[int]], int]:
"""Get mrope input positions and delta value."""
if isinstance(image_grid_thw, torch.Tensor):
image_grid_thw = image_grid_thw.tolist()
if isinstance(video_grid_thw, torch.Tensor):
video_grid_thw = video_grid_thw.tolist()
input_tokens_tensor = torch.tensor(input_tokens)
vision_start_indices = torch.argwhere(
input_tokens_tensor == vision_start_token_id).squeeze(1)
vision_tokens = input_tokens_tensor[vision_start_indices + 1]
image_nums = (vision_tokens == image_token_id).sum()
video_nums = (vision_tokens == video_token_id).sum()
llm_pos_ids_list: list = []
st = 0
remain_images, remain_videos = image_nums, video_nums
image_index, video_index = 0, 0
for _ in range(image_nums + video_nums):
if image_token_id in input_tokens and remain_images > 0:
ed_image = input_tokens.index(image_token_id, st)
else:
ed_image = len(input_tokens) + 1
if video_token_id in input_tokens and remain_videos > 0:
ed_video = input_tokens.index(video_token_id, st)
else:
ed_video = len(input_tokens) + 1
if ed_image < ed_video:
t, h, w = (
image_grid_thw[image_index][0],
image_grid_thw[image_index][1],
image_grid_thw[image_index][2],
)
image_index += 1
remain_images -= 1
ed = ed_image
else:
t, h, w = (
video_grid_thw[video_index][0],
video_grid_thw[video_index][1],
video_grid_thw[video_index][2],
)
video_index += 1
remain_videos -= 1
ed = ed_video
llm_grid_t, llm_grid_h, llm_grid_w = \
t, h // spatial_merge_size, w // spatial_merge_size
text_len = ed - st
st_idx = llm_pos_ids_list[-1].max() + 1 if len(
llm_pos_ids_list) > 0 else 0
llm_pos_ids_list.append(
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
t_index = torch.arange(llm_grid_t).view(-1, 1).expand(
-1, llm_grid_h * llm_grid_w).flatten()
h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(
llm_grid_t, -1, llm_grid_w).flatten()
w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(
llm_grid_t, llm_grid_h, -1).flatten()
llm_pos_ids_list.append(
torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
st = ed + llm_grid_t * llm_grid_h * llm_grid_w
if st < len(input_tokens):
st_idx = llm_pos_ids_list[-1].max() + 1 if len(
llm_pos_ids_list) > 0 else 0
text_len = len(input_tokens) - st
llm_pos_ids_list.append(
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
llm_positions = llm_positions[:, context_len:]
mrope_position_delta = (llm_positions.max() + 1 -
len(input_tokens)).item()
return llm_positions.tolist(), mrope_position_delta
@staticmethod
def get_next_input_positions(
mrope_position_delta: int,
context_len: int,
seq_len: int,
) -> List[List[int]]:
return [
list(
range(context_len + mrope_position_delta,
seq_len + mrope_position_delta)) for _ in range(3)
]
_ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {} _ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {}
...@@ -809,7 +925,7 @@ def get_rope( ...@@ -809,7 +925,7 @@ def get_rope(
# The correct one should be "longrope" but keep "su" here # The correct one should be "longrope" but keep "su" here
# for backward compatible # for backward compatible
if scaling_type not in {"su", "longrope"}: if scaling_type not in {"su", "longrope"}:
scaling_factor = rope_scaling["factor"] scaling_factor = rope_scaling.get("factor", 1.0)
if scaling_type == "llama3": if scaling_type == "llama3":
low_freq_factor = rope_scaling["low_freq_factor"] low_freq_factor = rope_scaling["low_freq_factor"]
high_freq_factor = rope_scaling["high_freq_factor"] high_freq_factor = rope_scaling["high_freq_factor"]
...@@ -873,6 +989,16 @@ def get_rope( ...@@ -873,6 +989,16 @@ def get_rope(
head_size, rotary_dim, max_position, original_max_position, head_size, rotary_dim, max_position, original_max_position,
base, is_neox_style, dtype, short_factor, long_factor, base, is_neox_style, dtype, short_factor, long_factor,
**extra_kwargs) **extra_kwargs)
elif scaling_type == "mrope":
return MRotaryEmbedding(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
dtype,
mrope_section=rope_scaling["mrope_section"],
)
else: else:
raise ValueError(f"Unknown RoPE scaling type {scaling_type}") raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
_ROPE_DICT[key] = rotary_emb _ROPE_DICT[key] = rotary_emb
......
...@@ -17,6 +17,7 @@ import torch ...@@ -17,6 +17,7 @@ import torch
from huggingface_hub import HfApi, hf_hub_download from huggingface_hub import HfApi, hf_hub_download
from torch import nn from torch import nn
from transformers import AutoModelForCausalLM, PretrainedConfig from transformers import AutoModelForCausalLM, PretrainedConfig
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoadFormat, from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoadFormat,
LoRAConfig, ModelConfig, MultiModalConfig, LoRAConfig, ModelConfig, MultiModalConfig,
...@@ -94,8 +95,9 @@ def _get_quantization_config( ...@@ -94,8 +95,9 @@ def _get_quantization_config(
"""Get the quantization config.""" """Get the quantization config."""
if model_config.quantization is not None: if model_config.quantization is not None:
quant_config = get_quant_config(model_config, load_config) quant_config = get_quant_config(model_config, load_config)
if not current_platform.is_tpu(): capability = current_platform.get_device_capability() # type: ignore
capability = current_platform.get_device_capability()
if capability is not None:
capability = capability[0] * 10 + capability[1] capability = capability[0] * 10 + capability[1]
if capability < quant_config.get_min_capability(): if capability < quant_config.get_min_capability():
raise ValueError( raise ValueError(
...@@ -187,6 +189,11 @@ class BaseModelLoader(ABC): ...@@ -187,6 +189,11 @@ class BaseModelLoader(ABC):
def __init__(self, load_config: LoadConfig): def __init__(self, load_config: LoadConfig):
self.load_config = load_config self.load_config = load_config
@abstractmethod
def download_model(self, model_config: ModelConfig) -> None:
"""Download a model so that it can be immediately loaded."""
raise NotImplementedError
@abstractmethod @abstractmethod
def load_model(self, *, model_config: ModelConfig, def load_model(self, *, model_config: ModelConfig,
device_config: DeviceConfig, device_config: DeviceConfig,
...@@ -195,7 +202,7 @@ class BaseModelLoader(ABC): ...@@ -195,7 +202,7 @@ class BaseModelLoader(ABC):
scheduler_config: SchedulerConfig, scheduler_config: SchedulerConfig,
cache_config: CacheConfig) -> nn.Module: cache_config: CacheConfig) -> nn.Module:
"""Load a model with the given configurations.""" """Load a model with the given configurations."""
... raise NotImplementedError
class DefaultModelLoader(BaseModelLoader): class DefaultModelLoader(BaseModelLoader):
...@@ -244,12 +251,17 @@ class DefaultModelLoader(BaseModelLoader): ...@@ -244,12 +251,17 @@ class DefaultModelLoader(BaseModelLoader):
is_local = os.path.isdir(model_name_or_path) is_local = os.path.isdir(model_name_or_path)
load_format = self.load_config.load_format load_format = self.load_config.load_format
use_safetensors = False use_safetensors = False
index_file = SAFE_WEIGHTS_INDEX_NAME
# Some quantized models use .pt files for storing the weights. # Some quantized models use .pt files for storing the weights.
if load_format == LoadFormat.AUTO: if load_format == LoadFormat.AUTO:
allow_patterns = ["*.safetensors", "*.bin"] allow_patterns = ["*.safetensors", "*.bin"]
elif load_format == LoadFormat.SAFETENSORS: elif load_format == LoadFormat.SAFETENSORS:
use_safetensors = True use_safetensors = True
allow_patterns = ["*.safetensors"] allow_patterns = ["*.safetensors"]
elif load_format == LoadFormat.MISTRAL:
use_safetensors = True
allow_patterns = ["consolidated*.safetensors"]
index_file = "consolidated.safetensors.index.json"
elif load_format == LoadFormat.PT: elif load_format == LoadFormat.PT:
allow_patterns = ["*.pt"] allow_patterns = ["*.pt"]
elif load_format == LoadFormat.NPCACHE: elif load_format == LoadFormat.NPCACHE:
...@@ -287,10 +299,10 @@ class DefaultModelLoader(BaseModelLoader): ...@@ -287,10 +299,10 @@ class DefaultModelLoader(BaseModelLoader):
# any files not found in the index. # any files not found in the index.
if not is_local: if not is_local:
download_safetensors_index_file_from_hf( download_safetensors_index_file_from_hf(
model_name_or_path, self.load_config.download_dir, model_name_or_path, index_file,
revision) self.load_config.download_dir, revision)
hf_weights_files = filter_duplicate_safetensors_files( hf_weights_files = filter_duplicate_safetensors_files(
hf_weights_files, hf_folder) hf_weights_files, hf_folder, index_file)
else: else:
hf_weights_files = filter_files_not_needed_for_inference( hf_weights_files = filter_files_not_needed_for_inference(
hf_weights_files) hf_weights_files)
...@@ -332,6 +344,11 @@ class DefaultModelLoader(BaseModelLoader): ...@@ -332,6 +344,11 @@ class DefaultModelLoader(BaseModelLoader):
weights_iterator = _xla_weights_iterator(weights_iterator) weights_iterator = _xla_weights_iterator(weights_iterator)
return weights_iterator return weights_iterator
def download_model(self, model_config: ModelConfig) -> None:
self._prepare_weights(model_config.model,
model_config.revision,
fall_back_to_pt=True)
def load_model(self, *, model_config: ModelConfig, def load_model(self, *, model_config: ModelConfig,
device_config: DeviceConfig, device_config: DeviceConfig,
lora_config: Optional[LoRAConfig], lora_config: Optional[LoRAConfig],
...@@ -374,6 +391,9 @@ class DummyModelLoader(BaseModelLoader): ...@@ -374,6 +391,9 @@ class DummyModelLoader(BaseModelLoader):
raise ValueError(f"Model loader extra config is not supported for " raise ValueError(f"Model loader extra config is not supported for "
f"load format {load_config.load_format}") f"load format {load_config.load_format}")
def download_model(self, model_config: ModelConfig) -> None:
pass # Nothing to download
def load_model(self, *, model_config: ModelConfig, def load_model(self, *, model_config: ModelConfig,
device_config: DeviceConfig, device_config: DeviceConfig,
lora_config: Optional[LoRAConfig], lora_config: Optional[LoRAConfig],
...@@ -464,6 +484,12 @@ class TensorizerLoader(BaseModelLoader): ...@@ -464,6 +484,12 @@ class TensorizerLoader(BaseModelLoader):
model = load_with_tensorizer(tensorizer_config, **extra_kwargs) model = load_with_tensorizer(tensorizer_config, **extra_kwargs)
return model.eval() return model.eval()
def download_model(self, model_config: ModelConfig) -> None:
self.tensorizer_config.verify_with_model_config(model_config)
with self.tensorizer_config.open_stream():
pass
def load_model(self, *, model_config: ModelConfig, def load_model(self, *, model_config: ModelConfig,
device_config: DeviceConfig, device_config: DeviceConfig,
lora_config: Optional[LoRAConfig], lora_config: Optional[LoRAConfig],
...@@ -565,6 +591,9 @@ class ShardedStateLoader(BaseModelLoader): ...@@ -565,6 +591,9 @@ class ShardedStateLoader(BaseModelLoader):
ignore_patterns=self.load_config.ignore_patterns, ignore_patterns=self.load_config.ignore_patterns,
) )
def download_model(self, model_config: ModelConfig) -> None:
self._prepare_weights(model_config.model, model_config.revision)
def load_model(self, *, model_config: ModelConfig, def load_model(self, *, model_config: ModelConfig,
device_config: DeviceConfig, device_config: DeviceConfig,
lora_config: Optional[LoRAConfig], lora_config: Optional[LoRAConfig],
...@@ -992,6 +1021,9 @@ class BitsAndBytesModelLoader(BaseModelLoader): ...@@ -992,6 +1021,9 @@ class BitsAndBytesModelLoader(BaseModelLoader):
set_weight_attrs( set_weight_attrs(
param, {"matmul_state": [None] * len(quant_states)}) param, {"matmul_state": [None] * len(quant_states)})
def download_model(self, model_config: ModelConfig) -> None:
self._prepare_weights(model_config.model, model_config.revision)
def load_model(self, *, model_config: ModelConfig, def load_model(self, *, model_config: ModelConfig,
device_config: DeviceConfig, device_config: DeviceConfig,
lora_config: Optional[LoRAConfig], lora_config: Optional[LoRAConfig],
...@@ -1067,6 +1099,9 @@ class GGUFModelLoader(BaseModelLoader): ...@@ -1067,6 +1099,9 @@ class GGUFModelLoader(BaseModelLoader):
return gguf_quant_weights_iterator(model_name_or_path, return gguf_quant_weights_iterator(model_name_or_path,
gguf_to_hf_name_map) gguf_to_hf_name_map)
def download_model(self, model_config: ModelConfig) -> None:
self._prepare_weights(model_config.model)
def load_model(self, *, model_config: ModelConfig, def load_model(self, *, model_config: ModelConfig,
device_config: DeviceConfig, device_config: DeviceConfig,
lora_config: Optional[LoRAConfig], lora_config: Optional[LoRAConfig],
......
...@@ -99,6 +99,13 @@ class TensorizerConfig: ...@@ -99,6 +99,13 @@ class TensorizerConfig:
"Loading a model using Tensorizer with quantization on vLLM" "Loading a model using Tensorizer with quantization on vLLM"
" is unstable and may lead to errors.") " is unstable and may lead to errors.")
def open_stream(self, tensorizer_args: Optional["TensorizerArgs"] = None):
if tensorizer_args is None:
tensorizer_args = self._construct_tensorizer_args()
return open_stream(self.tensorizer_uri,
**tensorizer_args.stream_params)
def load_with_tensorizer(tensorizer_config: TensorizerConfig, def load_with_tensorizer(tensorizer_config: TensorizerConfig,
**extra_kwargs) -> nn.Module: **extra_kwargs) -> nn.Module:
......
...@@ -43,10 +43,18 @@ def get_model_architecture( ...@@ -43,10 +43,18 @@ def get_model_architecture(
# Special handling for quantized Mixtral. # Special handling for quantized Mixtral.
# FIXME(woosuk): This is a temporary hack. # FIXME(woosuk): This is a temporary hack.
mixtral_supported = ["fp8", "compressed-tensors"] mixtral_supported = ["fp8", "compressed-tensors"]
# for gptq_marlin, only run fused MoE for int4
if model_config.quantization == "gptq_marlin":
hf_quant_config = getattr(model_config.hf_config,
"quantization_config", None)
if hf_quant_config and hf_quant_config.get("bits") == 4:
mixtral_supported.append("gptq_marlin")
if (model_config.quantization is not None if (model_config.quantization is not None
and model_config.quantization not in mixtral_supported and model_config.quantization not in mixtral_supported
and "MixtralForCausalLM" in architectures): and "MixtralForCausalLM" in architectures):
architectures = ["QuantMixtralForCausalLM"] architectures = ["QuantMixtralForCausalLM"]
return ModelRegistry.resolve_model_cls(architectures) return ModelRegistry.resolve_model_cls(architectures)
......
...@@ -16,7 +16,6 @@ import torch ...@@ -16,7 +16,6 @@ import torch
from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download
from safetensors.torch import load_file, safe_open, save_file from safetensors.torch import load_file, safe_open, save_file
from tqdm.auto import tqdm from tqdm.auto import tqdm
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
from vllm.config import LoadConfig, ModelConfig from vllm.config import LoadConfig, ModelConfig
from vllm.distributed import get_tensor_model_parallel_rank from vllm.distributed import get_tensor_model_parallel_rank
...@@ -193,6 +192,13 @@ def get_quant_config(model_config: ModelConfig, ...@@ -193,6 +192,13 @@ def get_quant_config(model_config: ModelConfig,
if model_config.quantization == "bitsandbytes": if model_config.quantization == "bitsandbytes":
config["adapter_name_or_path"] = model_name_or_path config["adapter_name_or_path"] = model_name_or_path
elif model_config.quantization == "modelopt":
if config["producer"]["name"] == "modelopt":
return quant_cls.from_config(config)
else:
raise ValueError(
f"Unsupported quantization config"
f" found for {model_config.quantization} in {f}.")
return quant_cls.from_config(config) return quant_cls.from_config(config)
...@@ -251,6 +257,7 @@ def download_weights_from_hf( ...@@ -251,6 +257,7 @@ def download_weights_from_hf(
def download_safetensors_index_file_from_hf( def download_safetensors_index_file_from_hf(
model_name_or_path: str, model_name_or_path: str,
index_file: str,
cache_dir: Optional[str], cache_dir: Optional[str],
revision: Optional[str] = None, revision: Optional[str] = None,
) -> None: ) -> None:
...@@ -269,36 +276,37 @@ def download_safetensors_index_file_from_hf( ...@@ -269,36 +276,37 @@ def download_safetensors_index_file_from_hf(
# Download the safetensors index file. # Download the safetensors index file.
hf_hub_download( hf_hub_download(
repo_id=model_name_or_path, repo_id=model_name_or_path,
filename=SAFE_WEIGHTS_INDEX_NAME, filename=index_file,
cache_dir=cache_dir, cache_dir=cache_dir,
revision=revision, revision=revision,
local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE, local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
) )
# If file not found on remote or locally, we should not fail since # If file not found on remote or locally, we should not fail since
# only some models will have SAFE_WEIGHTS_INDEX_NAME. # only some models will have index_file.
except huggingface_hub.utils.EntryNotFoundError: except huggingface_hub.utils.EntryNotFoundError:
logger.info("No %s found in remote.", SAFE_WEIGHTS_INDEX_NAME) logger.info("No %s found in remote.", index_file)
except huggingface_hub.utils.LocalEntryNotFoundError: except huggingface_hub.utils.LocalEntryNotFoundError:
logger.info("No %s found in local cache.", SAFE_WEIGHTS_INDEX_NAME) logger.info("No %s found in local cache.", index_file)
# For models like Mistral-7B-v0.3, there are both sharded # For models like Mistral-7B-v0.3, there are both sharded
# safetensors files and a consolidated safetensors file. # safetensors files and a consolidated safetensors file.
# Passing both of these to the weight loader functionality breaks. # Passing both of these to the weight loader functionality breaks.
# So, we use the SAFE_WEIGHTS_INDEX_NAME to # So, we use the index_file to
# look up which safetensors files should be used. # look up which safetensors files should be used.
def filter_duplicate_safetensors_files(hf_weights_files: List[str], def filter_duplicate_safetensors_files(hf_weights_files: List[str],
hf_folder: str) -> List[str]: hf_folder: str,
index_file: str) -> List[str]:
# model.safetensors.index.json is a mapping from keys in the # model.safetensors.index.json is a mapping from keys in the
# torch state_dict to safetensors file holding that weight. # torch state_dict to safetensors file holding that weight.
index_file_name = os.path.join(hf_folder, SAFE_WEIGHTS_INDEX_NAME) index_file_name = os.path.join(hf_folder, index_file)
if not os.path.isfile(index_file_name): if not os.path.isfile(index_file_name):
return hf_weights_files return hf_weights_files
# Iterate through the weight_map (weight_name: safetensors files) # Iterate through the weight_map (weight_name: safetensors files)
# to identify weights that we should use. # to identify weights that we should use.
with open(index_file_name) as index_file: with open(index_file_name, "r") as f:
weight_map = json.load(index_file)["weight_map"] weight_map = json.load(f)["weight_map"]
weight_files_in_index = set() weight_files_in_index = set()
for weight_name in weight_map: for weight_name in weight_map:
weight_files_in_index.add( weight_files_in_index.add(
......
...@@ -51,9 +51,10 @@ _GENERATION_MODELS = { ...@@ -51,9 +51,10 @@ _GENERATION_MODELS = {
"PhiForCausalLM": ("phi", "PhiForCausalLM"), "PhiForCausalLM": ("phi", "PhiForCausalLM"),
"Phi3ForCausalLM": ("llama", "LlamaForCausalLM"), "Phi3ForCausalLM": ("llama", "LlamaForCausalLM"),
"PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"), "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
"QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
"Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"), "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
"Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"), "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
"Qwen2VLForConditionalGeneration":
("qwen2_vl", "Qwen2VLForConditionalGeneration"),
"RWForCausalLM": ("falcon", "FalconForCausalLM"), "RWForCausalLM": ("falcon", "FalconForCausalLM"),
"StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"), "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
"StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"), "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
...@@ -81,13 +82,20 @@ _MULTIMODAL_MODELS = { ...@@ -81,13 +82,20 @@ _MULTIMODAL_MODELS = {
"InternVLChatModel": ("internvl", "InternVLChatModel"), "InternVLChatModel": ("internvl", "InternVLChatModel"),
"LlavaForConditionalGeneration": "LlavaForConditionalGeneration":
("llava", "LlavaForConditionalGeneration"), ("llava", "LlavaForConditionalGeneration"),
"LlavaNextForConditionalGeneration": "LlavaNextForConditionalGeneration": ("llava_next",
("llava_next", "LlavaNextForConditionalGeneration"), "LlavaNextForConditionalGeneration"),
"LlavaNextVideoForConditionalGeneration":
("llava_next_video", "LlavaNextVideoForConditionalGeneration"),
"MiniCPMV": ("minicpmv", "MiniCPMV"), "MiniCPMV": ("minicpmv", "MiniCPMV"),
"PaliGemmaForConditionalGeneration": ("paligemma", "PaliGemmaForConditionalGeneration": ("paligemma",
"PaliGemmaForConditionalGeneration"), "PaliGemmaForConditionalGeneration"),
"Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"), "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
"UltravoxModel": ("ultravox", "UltravoxModel"), "UltravoxModel": ("ultravox", "UltravoxModel"),
"QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
"PixtralForConditionalGeneration": ("pixtral",
"PixtralForConditionalGeneration"),
"Qwen2VLForConditionalGeneration": ("qwen2_vl",
"Qwen2VLForConditionalGeneration"),
} }
_CONDITIONAL_GENERATION_MODELS = { _CONDITIONAL_GENERATION_MODELS = {
"BartModel": ("bart", "BartForConditionalGeneration"), "BartModel": ("bart", "BartForConditionalGeneration"),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment