Commit dcb5624a authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.8.5' into v0.8.5-dev

parents 55880ca2 ba41cc90
# SPDX-License-Identifier: Apache-2.0
from typing import Dict, List, Optional, Tuple
import torch
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.bitblas_utils import (
BITBLAS_OPTIMIZE_FEATURES, BITBLAS_SUPPORTED_GROUP_SIZES,
MINIMUM_BITBLAS_VERSION, bitblas_make_empty_g_idx, bitblas_sort_g_idx,
check_bitblas_supports_shape, query_bitblas_supported_quant_types,
unpack_gptq_qweight, unpack_gptq_qzeros)
from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig
logger = init_logger(__name__)
class BitBLASLinearKernel(MPLinearKernel):
OPT_FEATURES: List[int] = BITBLAS_OPTIMIZE_FEATURES
ENABLE_TUNING: bool = True
MATMUL_LAYOUT: str = "nt"
BITBLAS_DTYPES: Dict[torch.dtype, str] = {
torch.float32: "float32",
torch.float16: "float16",
torch.bfloat16: "bfloat16",
torch.half: "float16",
torch.int8: "int8",
}
bitblas_matmul: object = None
def __init__(
self,
c: MPLinearLayerConfig,
w_q_param_name: str,
w_s_param_name: str,
w_zp_param_name: Optional[str] = None,
w_gidx_param_name: Optional[str] = None,
bitblas_quant_config: Optional[QuantizationConfig] = None,
):
self.quant_config = bitblas_quant_config
super().__init__(c, w_q_param_name, w_s_param_name, w_zp_param_name,
w_gidx_param_name)
def repack_bitblas_from_gptq(
self,
b_q_weight: torch.Tensor,
scales: torch.Tensor,
qzeros: Optional[torch.Tensor] = None,
):
from bitblas.quantization.utils import general_compress
assert self.bitblas_matmul is not None, "bitblas_matmul is None"
quant_config = self.quant_config
# qweight in gptq old quant linear stored with
# (outfeatures, infeatures), should be transposed.
qweight = b_q_weight.T.contiguous().view(
quant_config.torch_storage_dtype) # type: ignore[union-attr]
intweight = unpack_gptq_qweight(
qweight,
quant_config.weight_bits).contiguous() # type: ignore[union-attr]
if self.bitblas_matmul.weight_transform is not None: # type: ignore[attr-defined]
qweight = self.bitblas_matmul.weight_transform( # type: ignore[attr-defined]
intweight.cpu()).cuda()
# scales in gptq old quant linear stored with
# (infeatures // group_size, outfeatures), should be transposed.
scales = scales.T.contiguous()
if qzeros is None:
return qweight, scales, None
# qzeros should be de-quantized to int zeros.
weight_bits = quant_config.weight_bits # type: ignore[union-attr]
intzeros = unpack_gptq_qzeros(qzeros, weight_bits).T.contiguous()
zeros: Optional[torch.Tensor] = None
zeros_mode = self.bitblas_matmul.config.zeros_mode # type: ignore[attr-defined]
if zeros_mode == "original":
zeros = intzeros.to(torch.float16).contiguous()
elif zeros_mode == "rescale":
assert zeros is not None, "zeros should not be None"
zeros[:, :] = intzeros.to(torch.float16)[:, :] * scales[:, :]
elif zeros_mode == "quantized":
zeros = (
torch.Tensor(
general_compress(
intzeros.T.contiguous().cpu().numpy(),
weight_bits,
)).to(qweight.device).
to(quant_config.torch_storage_dtype # type: ignore[union-attr]
).contiguous())
else:
raise ValueError("Unsupported zeros type: {}".format(zeros_mode))
return qweight, scales, zeros
@classmethod
def get_min_capability(cls) -> int:
return 70
@classmethod
def can_implement(cls,
c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]:
is_bitblas_installed = True
try:
import bitblas
if bitblas.__version__ < MINIMUM_BITBLAS_VERSION:
raise ImportError(
"bitblas version is wrong. Please "
f"install bitblas>={MINIMUM_BITBLAS_VERSION}")
except ImportError:
is_bitblas_installed = False
if not is_bitblas_installed:
return False, "bitblas is not installed. Please install bitblas "\
"by running `pip install bitblas>="\
f"{MINIMUM_BITBLAS_VERSION}`"
quant_types = query_bitblas_supported_quant_types(c.zero_points)
if c.weight_type not in quant_types:
return False, (f"Quant type ({c.weight_type}) not supported by"
f" BitBLAS, supported types are: {quant_types}")
if c.group_size not in BITBLAS_SUPPORTED_GROUP_SIZES:
return False, (f"Group size ({c.group_size}) not supported by "
"BitBLAS, supported group sizes are: "
f"{BITBLAS_SUPPORTED_GROUP_SIZES}")
return check_bitblas_supports_shape(
c.partition_weight_shape[1], # out_features
c.partition_weight_shape[0], # in_features
c.full_weight_shape[0], # in_features
c.group_size)
# note assumes that
# `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
# `weight_scale` is: {input_dim = 0, output_dim = 1}
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
device = getattr(layer, self.w_q_name).device
c = self.config
quant_config = self.quant_config
# Default names since bitblas requires empty parameters for these,
# TODO: remove this requirement from bitblas (allow optional tensors)
if self.w_gidx_name is None:
self.w_gidx_name = "g_idx"
if self.w_zp_name is None:
self.w_zp_name = "qzeros"
if c.has_g_idx:
g_idx, g_idx_sort_indices = bitblas_sort_g_idx(
getattr(layer, self.w_gidx_name))
self._transform_param(layer, self.w_gidx_name, lambda _: g_idx)
layer.g_idx_sort_indices = g_idx_sort_indices
else:
setattr(layer, self.w_gidx_name, bitblas_make_empty_g_idx(device))
layer.g_idx_sort_indices = bitblas_make_empty_g_idx(device)
if c.zero_points:
raise NotImplementedError("Zero points not supported by BitBLAS")
else:
setattr(layer, self.w_zp_name, bitblas_make_empty_g_idx(device))
# Repack weights
bitblas_qweight, bitblas_scales, bitblas_qzeros = (
self.repack_bitblas_from_gptq(
layer.qweight,
layer.scales,
None if quant_config.is_sym else # type: ignore[union-attr]
layer.qzeros, # type: ignore[union-attr]
))
replace_parameter(layer, self.w_q_name, bitblas_qweight)
replace_parameter(layer, self.w_s_name, bitblas_scales)
if bitblas_qzeros is not None:
replace_parameter(layer, self.w_zp_name, bitblas_qzeros)
def configure_bitblas_matmul(
self,
infeatures: int,
outfeatures: int,
params_dtype: torch.dtype,
bias: bool,
) -> None:
enable_tuning = self.ENABLE_TUNING
layout = self.MATMUL_LAYOUT
bits = self.quant_config.weight_bits # type: ignore[union-attr]
self._configure_bitblas_matmul(
infeatures,
outfeatures,
params_dtype,
enable_tuning,
bias,
layout,
bits,
)
def _configure_bitblas_matmul(
self,
infeatures,
outfeatures,
params_dtype,
enable_tuning,
bias,
layout,
bits,
):
from bitblas import MatmulConfig
bitblas_dtype = self.BITBLAS_DTYPES[params_dtype]
quant_config = self.quant_config
with_scaling = False
with_zeros = False
group_size = quant_config.group_size # type: ignore[union-attr]
zeros_mode = quant_config.zeros_mode # type: ignore[union-attr]
if quant_config.quant_method == "gptq": # type: ignore[union-attr]
with_scaling = True
with_zeros = True
W_dtype = f"uint{bits}"
if quant_config.is_sym: # type: ignore[union-attr]
with_zeros = False
W_dtype = f"int{bits}"
else:
raise ValueError(
f"Unsupported quant_method {quant_config.quant_method}" # type: ignore[union-attr]
) # type: ignore[union-attr]
matmul_config = MatmulConfig(
M=self.OPT_FEATURES,
N=outfeatures,
K=infeatures,
A_dtype=bitblas_dtype,
W_dtype=W_dtype,
out_dtype=bitblas_dtype,
accum_dtype="int32" if bitblas_dtype == "int8" else bitblas_dtype,
storage_dtype=quant_config. # type: ignore[union-attr]
storage_dtype, # type: ignore[union-attr]
with_scaling=with_scaling,
with_zeros=with_zeros,
group_size=group_size,
with_bias=bias,
layout=layout,
zeros_mode=zeros_mode,
)
self.bitblas_matmul = self._get_or_create_bitblas_operator(
matmul_config, enable_tuning)
def _get_or_create_bitblas_operator(self, config, enable_tuning):
from bitblas import Matmul, auto_detect_nvidia_target
from bitblas.cache import get_database_path, global_operator_cache
BITBLAS_DATABASE_PATH = get_database_path()
BITBLAS_TARGET = auto_detect_nvidia_target()
if global_operator_cache.size() == 0:
global_operator_cache.load_from_database(BITBLAS_DATABASE_PATH,
BITBLAS_TARGET)
bitblas_matmul = global_operator_cache.get(config)
if bitblas_matmul is None:
bitblas_matmul = Matmul(config,
target=BITBLAS_TARGET,
enable_tuning=False)
if enable_tuning:
bitblas_matmul.hardware_aware_finetune(topk=20)
global_operator_cache.add(config, bitblas_matmul)
global_operator_cache.save_into_database(
BITBLAS_DATABASE_PATH, BITBLAS_TARGET)
TUNING_MESSAGE = (
f"BitBLAS Operator {config} tuned and saved to database.")
logger.info(TUNING_MESSAGE)
else:
_message = f"BitBLAS Operator {config} created without tuning. "
logger.info(_message)
else:
_message = f"BitBLAS Operator {config} retrieved from cache."
logger.info(_message)
return bitblas_matmul
def apply_gptq_bitblas_linear(
self,
layer: torch.nn.Module,
x: torch.Tensor,
) -> torch.Tensor:
output_size_per_partition = self.config.partition_weight_shape[1]
out_shape = x.shape[:-1] + (output_size_per_partition, )
args = [x, layer.qweight, layer.scales]
if self.bitblas_matmul.config.with_zeros: # type: ignore[attr-defined]
args.append(layer.qzeros)
output = self.bitblas_matmul(*args) # type: ignore[operator]
return output.view(out_shape)
def apply_weights(self, layer, x, bias=None):
NOT_IMPLEMENT_MESSAGE = (
f"{self.__class__.__name__}.apply_weights is not implemented. "
"Please use BitBLASLinearKernel.apply_gptq_bitblas_linear instead")
raise NotImplementedError(NOT_IMPLEMENT_MESSAGE)
...@@ -26,17 +26,14 @@ class MacheteLinearKernel(MPLinearKernel): ...@@ -26,17 +26,14 @@ class MacheteLinearKernel(MPLinearKernel):
@classmethod @classmethod
def can_implement(cls, def can_implement(cls,
c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]: c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]:
if c.has_g_idx and\ if c.has_g_idx and\
c.partition_weight_shape[0] != c.full_weight_shape[0]: c.partition_weight_shape[0] != c.full_weight_shape[0]:
return False, "Act reordering currently not supported by Machete, "\ return False, "Act reordering currently not supported by Machete, "\
"when the input features are partitioned across "\ "when the input features are partitioned across "\
"devices" "devices"
if c.zero_points: if c.zero_points:
return False, "Zero points currently not supported by "\ return False, "Zero points currently not supported by Machete"
" Compressed Tensors + Machete. (Kernel supports it"\
" but CompressedTensorsWNA16 does not so support has"\
" not been added to MacheteWNA16Kernel yet"
if c.weight_type not in query_machete_supported_quant_types( if c.weight_type not in query_machete_supported_quant_types(
c.zero_points): c.zero_points):
......
...@@ -9,7 +9,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import ( ...@@ -9,7 +9,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
MARLIN_SUPPORTED_GROUP_SIZES, apply_gptq_marlin_linear, MARLIN_SUPPORTED_GROUP_SIZES, apply_gptq_marlin_linear,
check_marlin_supports_shape, marlin_is_k_full, marlin_make_empty_g_idx, check_marlin_supports_shape, marlin_is_k_full, marlin_make_empty_g_idx,
marlin_make_workspace, marlin_permute_scales, marlin_sort_g_idx, marlin_make_workspace, marlin_permute_scales, marlin_sort_g_idx,
query_marlin_supported_quant_types) marlin_zero_points, query_marlin_supported_quant_types, unpack_cols)
from vllm.model_executor.parameter import (BasevLLMParameter, from vllm.model_executor.parameter import (BasevLLMParameter,
permute_param_layout_) permute_param_layout_)
...@@ -25,10 +25,6 @@ class MarlinLinearKernel(MPLinearKernel): ...@@ -25,10 +25,6 @@ class MarlinLinearKernel(MPLinearKernel):
@classmethod @classmethod
def can_implement(cls, def can_implement(cls,
c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]: c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]:
if c.zero_points:
return False, "Zero points currently not supported by "\
" MarlinLinearKernel. Will be added when AWQMarlin "\
"is migrated over to using MPLinearKernel backend"
quant_types = query_marlin_supported_quant_types(c.zero_points) quant_types = query_marlin_supported_quant_types(c.zero_points)
if c.weight_type not in quant_types: if c.weight_type not in quant_types:
...@@ -67,28 +63,6 @@ class MarlinLinearKernel(MPLinearKernel): ...@@ -67,28 +63,6 @@ class MarlinLinearKernel(MPLinearKernel):
if self.w_zp_name is None: if self.w_zp_name is None:
self.w_zp_name = "w_zp" self.w_zp_name = "w_zp"
if c.has_g_idx:
g_idx, g_idx_sort_indices = marlin_sort_g_idx(
getattr(layer, self.w_gidx_name))
self._transform_param(layer, self.w_gidx_name, lambda _: g_idx)
layer.g_idx_sort_indices = g_idx_sort_indices
else:
setattr(layer, self.w_gidx_name, marlin_make_empty_g_idx(device))
layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)
if c.zero_points:
pass
# TODO (lucas): add the following when AWQMarlin is migrated over to
# using MPLinearKernel backend
# self._transform_param(layer, self.w_zp_name, lambda x: \
# marlin_zero_points(
# x,
# size_k=c.partition_weight_shape[0],
# size_n=c.partition_weight_shape[1],
# num_bits=c.weight_type.size_bits))
else:
setattr(layer, self.w_zp_name, marlin_make_empty_g_idx(device))
def transform_w_q(x): def transform_w_q(x):
assert isinstance(x, BasevLLMParameter) assert isinstance(x, BasevLLMParameter)
permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0) permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
...@@ -108,6 +82,28 @@ class MarlinLinearKernel(MPLinearKernel): ...@@ -108,6 +82,28 @@ class MarlinLinearKernel(MPLinearKernel):
group_size=c.group_size) group_size=c.group_size)
return x return x
if c.has_g_idx:
g_idx, g_idx_sort_indices = marlin_sort_g_idx(
getattr(layer, self.w_gidx_name))
self._transform_param(layer, self.w_gidx_name, lambda _: g_idx)
layer.g_idx_sort_indices = g_idx_sort_indices
else:
setattr(layer, self.w_gidx_name, marlin_make_empty_g_idx(device))
layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)
if c.zero_points:
grouped_k = (c.partition_weight_shape[0] //
c.group_size if c.group_size != -1 else 1)
self._transform_param(layer, self.w_zp_name, lambda x: \
marlin_zero_points(
unpack_cols(x.t(), c.weight_type.size_bits,
grouped_k,
c.partition_weight_shape[1]),
size_k=grouped_k,
size_n=c.partition_weight_shape[1],
num_bits=c.weight_type.size_bits))
else:
setattr(layer, self.w_zp_name, marlin_make_empty_g_idx(device))
self._transform_param(layer, self.w_q_name, transform_w_q) self._transform_param(layer, self.w_q_name, transform_w_q)
self._transform_param(layer, self.w_s_name, transform_w_s) self._transform_param(layer, self.w_s_name, transform_w_s)
...@@ -131,5 +127,6 @@ class MarlinLinearKernel(MPLinearKernel): ...@@ -131,5 +127,6 @@ class MarlinLinearKernel(MPLinearKernel):
wtype=c.weight_type, wtype=c.weight_type,
input_size_per_partition=c.partition_weight_shape[0], input_size_per_partition=c.partition_weight_shape[0],
output_size_per_partition=c.partition_weight_shape[1], output_size_per_partition=c.partition_weight_shape[1],
has_zp=self.config.zero_points,
is_k_full=self.is_k_full, is_k_full=self.is_k_full,
bias=bias) bias=bias)
...@@ -38,6 +38,9 @@ class BaseKVCacheMethod(QuantizeMethodBase): ...@@ -38,6 +38,9 @@ class BaseKVCacheMethod(QuantizeMethodBase):
requires_grad=False) requires_grad=False)
layer.v_scale = torch.nn.Parameter(torch.tensor(-1.0), layer.v_scale = torch.nn.Parameter(torch.tensor(-1.0),
requires_grad=False) requires_grad=False)
# Initialize P = softmax(QK^T) scales
layer.prob_scale = torch.nn.Parameter(torch.tensor(-1.0),
requires_grad=False)
def apply(self, layer: torch.nn.Module) -> torch.Tensor: def apply(self, layer: torch.nn.Module) -> torch.Tensor:
raise RuntimeError( raise RuntimeError(
...@@ -97,5 +100,38 @@ class BaseKVCacheMethod(QuantizeMethodBase): ...@@ -97,5 +100,38 @@ class BaseKVCacheMethod(QuantizeMethodBase):
"may cause accuracy issues. Please make sure k/v_scale " "may cause accuracy issues. Please make sure k/v_scale "
"scaling factors are available in the fp8 checkpoint.") "scaling factors are available in the fp8 checkpoint.")
if layer.q_scale > 0.0:
q_scale = layer.q_scale
if current_platform.is_fp8_fnuz():
q_scale *= 2
layer.calculate_kv_scales = False
else:
q_scale = 1.0
if layer.prob_scale > 0.0:
prob_scale = layer.prob_scale
if current_platform.is_fp8_fnuz():
prob_scale *= 2
else:
prob_scale = 1.0
is_singleton_float = lambda x: isinstance(x, float) or isinstance(
x, torch.Tensor) and x.numel() == 1 and x.is_floating_point()
if not is_singleton_float(q_scale) or not is_singleton_float(
prob_scale):
raise ValueError("Only support per-tensor scaling factor"
"for fp8-quantized Q/prob")
# These are used in the final Attention.forward()
layer._q_scale.copy_(q_scale)
layer._prob_scale.copy_(prob_scale)
if q_scale == 1.0 or prob_scale == 1.0:
logger.warning_once(
f"Using Q scale {q_scale} and prob scale {prob_scale} "
"with fp8 attention. This may cause accuracy issues. "
"Please make sure Q/prob scaling factors are "
"available in the fp8 checkpoint.")
del layer.k_scale del layer.k_scale
del layer.v_scale del layer.v_scale
del layer.q_scale
del layer.prob_scale
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import fnmatch import fnmatch
import re
from typing import Any, Dict, List, Optional, cast from typing import Any, Dict, List, Optional, cast
import torch import torch
...@@ -125,6 +124,13 @@ class QuarkConfig(QuantizationConfig): ...@@ -125,6 +124,13 @@ class QuarkConfig(QuantizationConfig):
for q_config in q_configs: for q_config in q_configs:
q_config["output_tensors"] = None q_config["output_tensors"] = None
# In case q_proj output is also quantized, remove the configuration
# to keep qkv consistency.
q_proj_q_config = cast(Dict[str, Any],
layer_quant_config.get("*q_proj"))
if q_proj_q_config is not None:
q_proj_q_config["output_tensors"] = None
return cls(quant_config=config, return cls(quant_config=config,
kv_cache_group=kv_cache_group, kv_cache_group=kv_cache_group,
kv_cache_config=kv_cache_config, kv_cache_config=kv_cache_config,
...@@ -289,25 +295,14 @@ class QuarkConfig(QuantizationConfig): ...@@ -289,25 +295,14 @@ class QuarkConfig(QuantizationConfig):
:param name: param name :param name: param name
:return: matching param name for KV cache scale in vLLM :return: matching param name for KV cache scale in vLLM
""" """
if self.kv_cache_group is None or len(self.kv_cache_group) == 0: if name.endswith(".output_scale") and ".k_proj" in name:
return None return name.replace(".k_proj.output_scale", ".attn.k_scale")
if name.endswith(".output_scale") and ".v_proj" in name:
kv_proj_names = [ return name.replace(".v_proj.output_scale", ".attn.v_scale")
re.split(r"[*.]", kv_cache)[-1] for kv_cache in self.kv_cache_group if name.endswith(".output_scale") and ".q_proj" in name:
] return name.replace(".q_proj.output_scale", ".attn.q_scale")
if name.endswith(".output_scale"): if name.endswith("self_attn.prob_output_scale"):
if len(kv_proj_names) == 1 and kv_proj_names[0] in name: return name.replace(".prob_output_scale", ".attn.prob_scale")
kv_output_scale_name = "." + kv_proj_names[0] + ".output_scale"
return name.replace(kv_output_scale_name, ".attn.k_scale")
elif len(kv_proj_names) == 2:
for kv_proj_name in kv_proj_names:
if kv_proj_name in name and kv_proj_name == "k_proj":
return name.replace(".k_proj.output_scale",
".attn.k_scale")
elif kv_proj_name in name and kv_proj_name == "v_proj":
return name.replace(".v_proj.output_scale",
".attn.v_scale")
# If no matches, return None # If no matches, return None
return None return None
......
# SPDX-License-Identifier: Apache-2.0
from typing import Optional, Tuple
import torch
from vllm.platforms import current_platform
from vllm.scalar_type import ScalarType, scalar_types
MINIMUM_BITBLAS_VERSION = "0.1.0"
BITBLAS_MIN_WEIGHT_SIZE_N = 16
BITBLAS_MIN_WEIGHT_SIZE_K = 16
GPTQ_BITBLAS_MAX_PARALLEL = 16
BITBLAS_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
# For dynamic shape code generation
BITBLAS_OPTIMIZE_FEATURES = [1, 16, 32, 64, 128, 256, 512, 1024]
# If want to enable high performance for contiguous batching
# Please use the following values
BITBLAS_OPTIMIZE_FEATURES_CONTIGUOUS = [16, 32, 64, 128, 256, 512, 1024]
BITBLAS_SUPPORTED_NUM_BITS = [1, 2, 4, 8]
BITBLAS_SUPPORTED_SYM = [False, True]
# Determines the supported quantization types for BitBLAS based on the
# device's capability and whether zero-point (zp) is used.
def query_bitblas_supported_quant_types(has_zp: bool,
device_capability: Optional[int] = None
):
if device_capability is None:
capability_tuple = current_platform.get_device_capability()
device_capability = (-1 if capability_tuple is None else
capability_tuple.to_int())
if device_capability < 70:
return []
if has_zp:
# AWQ style, unsigned + runtime zero-point
return [scalar_types.uint4, scalar_types.uint8]
else:
# GPTQ style, unsigned + symmetric bias
# TODO: once fp8_bitblas is merged into "gptq_bitblas" we should be able
# to add `scalar_types.float8_e4m3fn` here
return [scalar_types.uint4b8, scalar_types.uint8b128]
def _check_bitblas_supported(
quant_type: ScalarType,
group_size: Optional[int],
has_zp: bool,
device_capability: Optional[int] = None) -> Tuple[bool, Optional[str]]:
if device_capability is None:
capability_tuple = current_platform.get_device_capability()
device_capability = (-1 if capability_tuple is None else
capability_tuple.to_int())
supported_types = query_bitblas_supported_quant_types(
has_zp, device_capability)
if quant_type not in supported_types:
return (False, f"BitBLAS does not support weight_bits = {quant_type}. "
f"Only types = {supported_types} "
f"are supported (for group_size = {group_size}, "
f"device_capability = {device_capability}, zp = {has_zp}).")
if (group_size is None or group_size not in BITBLAS_SUPPORTED_GROUP_SIZES):
return (False, f"BitBLAS does not support group_size = {group_size}. "
f"Only group_sizes = {BITBLAS_SUPPORTED_GROUP_SIZES} "
"are supported.")
return True, None
def check_bitblas_supported(quant_type: ScalarType,
group_size: int,
has_zp: bool = False,
device_capability: Optional[int] = None) -> bool:
cond, _ = _check_bitblas_supported(quant_type, group_size, has_zp,
device_capability)
return cond
def verify_bitblas_supported(quant_type: ScalarType,
group_size: int,
has_zp: bool = False) -> None:
cond, err_msg = _check_bitblas_supported(quant_type, group_size, has_zp)
if not cond:
assert err_msg is not None
raise ValueError(err_msg)
def verify_bitblas_supports_shape(output_size_per_partition: int,
input_size_per_partition: int,
input_size: int, group_size: int) -> None:
# Validate output_size_per_partition
if output_size_per_partition % BITBLAS_MIN_WEIGHT_SIZE_N != 0:
raise ValueError(f"Weight output_size_per_partition = "
f"{output_size_per_partition} is not divisible by "
f" min_thread_n = {BITBLAS_MIN_WEIGHT_SIZE_N}. "
"Consider reducing tensor_parallel_size or running "
"with --quantization gptq.")
# Validate input_size_per_partition
if input_size_per_partition % BITBLAS_MIN_WEIGHT_SIZE_K != 0:
raise ValueError(f"Weight input_size_per_partition = "
f"{input_size_per_partition} is not divisible "
f"by min_thread_k = {BITBLAS_MIN_WEIGHT_SIZE_K}. "
"Consider reducing tensor_parallel_size or running "
"with --quantization gptq.")
if (group_size < input_size
and input_size_per_partition % group_size != 0):
raise ValueError(
f"Weight input_size_per_partition = {input_size_per_partition}"
f" is not divisible by group_size = {group_size}."
"Consider reducing tensor_parallel_size or running "
"with --quantization gptq.")
def check_bitblas_supports_shape(output_size_per_partition: int,
input_size_per_partition: int,
input_size: int, group_size: int) \
-> Tuple[bool, Optional[str]]:
try:
verify_bitblas_supports_shape(output_size_per_partition,
input_size_per_partition, input_size,
group_size)
except ValueError as e:
return False, e.__str__()
return True, None
def bitblas_is_k_full(act_order: bool, is_row_parallel: bool) -> bool:
return (not act_order) or (act_order and not is_row_parallel)
def bitblas_repeat_scales_on_all_ranks(act_order: bool, group_size: int,
is_row_parallel: bool) -> bool:
# Need to repeat scales on every rank if act_ordering or
# channelwise and RowParallelLinear
is_channelwise = group_size == -1
return act_order or (is_channelwise and is_row_parallel)
def bitblas_make_empty_g_idx(device: torch.device) -> torch.Tensor:
return torch.nn.Parameter(torch.empty(0, dtype=torch.int, device=device),
requires_grad=False)
def bitblas_make_empty_zp(device: torch.device) -> torch.Tensor:
return torch.nn.Parameter(torch.empty(0, dtype=torch.int, device=device),
requires_grad=False)
def bitblas_sort_g_idx(
g_idx: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
g_idx_sort_indices = torch.argsort(g_idx).to(torch.int)
return g_idx[g_idx_sort_indices], g_idx_sort_indices
def unpack_gptq_qzeros(qzeros, bits, is_gptq_v2=False) -> torch.Tensor:
qzeros = qzeros.view(torch.int32)
elems_per_int32 = 32 // bits
unpacked_zeros = torch.zeros(
(qzeros.shape[0], qzeros.shape[1] * elems_per_int32),
dtype=torch.int8,
device=qzeros.device,
requires_grad=False,
)
for col in range(unpacked_zeros.shape[1]):
i = col % elems_per_int32
unpacked_zeros[:, col] = (qzeros[:, col // elems_per_int32] >>
(bits * i)) & 0xF
if not is_gptq_v2:
return unpacked_zeros + 1
return unpacked_zeros
def unpack_gptq_qweight(qweight, bits):
qweight = qweight.view(torch.int8)
elems_per_int8 = 8 // bits
unpacked_weight = torch.zeros(
(qweight.shape[0], qweight.shape[1] * elems_per_int8),
dtype=torch.int8,
device=qweight.device,
requires_grad=False,
)
for col in range(unpacked_weight.shape[1]):
i = col % elems_per_int8
unpacked_weight[:, col] = (qweight[:, col // elems_per_int8] >>
(bits * i))
return torch.bitwise_and(unpacked_weight, 2**bits - 1)
...@@ -151,6 +151,19 @@ def check_marlin_supports_layer(layer: LinearBase, group_size: int) \ ...@@ -151,6 +151,19 @@ def check_marlin_supports_layer(layer: LinearBase, group_size: int) \
group_size=group_size)[0] group_size=group_size)[0]
def check_moe_marlin_supports_layer(layer: LinearBase, group_size: int) \
-> bool:
hidden_size = layer.hidden_size
intermediate_size_per_partition = layer.intermediate_size_per_partition
# gate-up: (n, k) = (intermediate_size_per_partition * 2, hidden_size)
# down: (n, k) = (hidden_size, intermediate_size_per_partition)
# moe marlin requires n % 128 == 0 and k % 64 == 0
return hidden_size % 128 == 0 and \
intermediate_size_per_partition % max(64, group_size) == 0 and \
group_size in [-1, 32, 64, 128]
def marlin_make_workspace(output_size_per_partition: int, def marlin_make_workspace(output_size_per_partition: int,
device: torch.device) -> torch.Tensor: device: torch.device) -> torch.Tensor:
max_workspace_size = (output_size_per_partition // max_workspace_size = (output_size_per_partition //
...@@ -319,6 +332,7 @@ def apply_gptq_marlin_linear( ...@@ -319,6 +332,7 @@ def apply_gptq_marlin_linear(
wtype: ScalarType, wtype: ScalarType,
output_size_per_partition: int, output_size_per_partition: int,
input_size_per_partition: int, input_size_per_partition: int,
has_zp: bool,
is_k_full: bool, is_k_full: bool,
bias: Optional[torch.Tensor] = None, bias: Optional[torch.Tensor] = None,
use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor: use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor:
...@@ -343,8 +357,8 @@ def apply_gptq_marlin_linear( ...@@ -343,8 +357,8 @@ def apply_gptq_marlin_linear(
size_n=output_size_per_partition, size_n=output_size_per_partition,
size_k=input_size_per_partition, size_k=input_size_per_partition,
is_k_full=is_k_full, is_k_full=is_k_full,
has_zp=False,
use_atomic_add=use_atomic_add, use_atomic_add=use_atomic_add,
has_zp=has_zp,
use_fp32_reduce=use_fp32_reduce, use_fp32_reduce=use_fp32_reduce,
is_zp_float=False) is_zp_float=False)
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import List, Optional, Tuple, Union from typing import Callable, List, Optional, Tuple, Union
import torch import torch
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm import envs
from vllm.config import CompilationLevel, get_current_vllm_config from vllm.config import CompilationLevel, get_current_vllm_config
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import W8a8GetCacheJSON from vllm.utils import W8a8GetCacheJSON
...@@ -19,6 +20,7 @@ W8A8_TRITONJSON=W8a8GetCacheJSON() ...@@ -19,6 +20,7 @@ W8A8_TRITONJSON=W8a8GetCacheJSON()
# The condition is determined once as the operations # The condition is determined once as the operations
# are time consuming. # are time consuming.
USE_ROWWISE_TORCH_SCALED_MM = (current_platform.is_rocm() USE_ROWWISE_TORCH_SCALED_MM = (current_platform.is_rocm()
and torch.__version__[0:3] >= "2.7"
and current_platform.has_device_capability(94)) and current_platform.has_device_capability(94))
def sparse_cutlass_supported() -> bool: def sparse_cutlass_supported() -> bool:
...@@ -132,6 +134,160 @@ def maybe_create_device_identity(): ...@@ -132,6 +134,160 @@ def maybe_create_device_identity():
TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32) TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32)
def cutlass_w8a8_scaled_mm(*, qinput: torch.Tensor, weight: torch.Tensor,
out_dtype: torch.dtype, scale_a: torch.Tensor,
scale_b: torch.Tensor, bias: torch.Tensor,
output_shape: List, **kwargs) -> torch.Tensor:
# Fused GEMM_DQ
output = ops.cutlass_scaled_mm(qinput,
weight,
out_dtype=out_dtype,
scale_a=scale_a,
scale_b=scale_b,
bias=bias)
return output.view(*output_shape)
def rocm_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor,
weight: torch.Tensor,
out_dtype: torch.dtype,
scale_a: torch.Tensor,
scale_b: torch.Tensor, bias: torch.Tensor,
input_2d: torch.Tensor,
output_shape: List) -> torch.Tensor:
from vllm.platforms.rocm import on_mi250_mi300
if envs.VLLM_ROCM_USE_SKINNY_GEMM and on_mi250_mi300(
) and qinput.shape[0] == 1 and qinput.shape[1] % 16 == 0:
output = ops.wvSplitKQ(weight.t(), qinput, out_dtype, scale_a, scale_b,
current_platform.get_cu_count())
else:
output = torch._scaled_mm(qinput,
weight,
out_dtype=out_dtype,
scale_a=scale_a,
scale_b=scale_b,
bias=bias)
return torch.narrow(output, 0, 0, input_2d.shape[0]).view(*output_shape)
def torch_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor,
weight: torch.Tensor,
out_dtype: torch.dtype,
scale_a: torch.Tensor,
scale_b: torch.Tensor, bias: torch.Tensor,
input_2d: torch.Tensor,
output_shape: List) -> torch.Tensor:
output = torch._scaled_mm(qinput,
weight,
out_dtype=out_dtype,
scale_a=scale_a,
scale_b=scale_b,
bias=bias)
# A fix for discrepancy in scaled_mm which returns tuple
# for torch < 2.5 and a single value in torch >= 2.5
if type(output) is tuple and len(output) == 2:
output = output[0]
return torch.narrow(output, 0, 0, input_2d.shape[0]).view(*output_shape)
def torch_per_token_w8a8_scaled_mm(*, qinput: torch.Tensor,
weight: torch.Tensor,
out_dtype: torch.dtype,
scale_a: torch.Tensor,
scale_b: torch.Tensor, bias: torch.Tensor,
input_2d: torch.Tensor,
output_shape: List) -> torch.Tensor:
# Note: Callers of this function should check USE_ROWWISE_TORCH_SCALED_MM
# when using it.
# For now it has only been validated on ROCm platform.
# fp8 rowwise scaling in torch._scaled_mm is introduced in
# https://github.com/pytorch/pytorch/pull/144432 using
# hipBLASLt and ROCm 6.3, which only exists in torch 2.7 and above.
#
# For CUDA platform please validate if the torch._scaled_mm supports
# rowwise scaled GEMM before using it
# Fused GEMM_DQ Rowwise GEMM
output = torch._scaled_mm(qinput,
weight,
out_dtype=out_dtype,
scale_a=scale_a,
scale_b=scale_b.t(),
bias=bias)
output = torch.narrow(output, 0, 0, input_2d.shape[0])
output = output.view(*output_shape)
return output
def torch_channelwise_w8a8_scaled_mm(*, qinput: torch.Tensor,
weight: torch.Tensor,
out_dtype: torch.dtype,
scale_a: torch.Tensor,
scale_b: torch.Tensor, bias: torch.Tensor,
input_2d: torch.Tensor,
output_shape: List,
**kwargs) -> torch.Tensor:
# Use unfused DQ due to limitations with scaled_mm
# Symmetric quantized GEMM by definition computes the following:
# C = (s_x * X) (s_w * W) + bias
# This is equivalent to dequantizing the weights and activations
# before applying a GEMM.
#
# In order to compute quantized operands, a quantized kernel
# will rewrite the above like so:
# C = s_w * s_x * (X * W) + bias
#
# For the scaled_mm fallback case, we break this down, since it
# does not support s_w being a vector.
# GEMM
# This computes C = (X * W).
# Output in fp32 to allow subsequent ops to happen in-place
output = torch._scaled_mm(qinput,
weight,
scale_a=TORCH_DEVICE_IDENTITY,
scale_b=TORCH_DEVICE_IDENTITY,
out_dtype=torch.float32)
# A fix for discrepancy in scaled_mm which returns tuple
# for torch < 2.5 and a single value in torch >= 2.5
if type(output) is tuple and len(output) == 2:
output = output[0]
# Unpad (undo num_token_padding)
output = torch.narrow(output, 0, 0, input_2d.shape[0])
x_scale = torch.narrow(scale_a, 0, 0, input_2d.shape[0])
# DQ
# C = sw * sx * (X * W) + bias
output = output * x_scale * scale_b.t()
if bias is not None:
output = output + bias
return output.to(out_dtype).view(*output_shape)
def dispatch_w8a8_scaled_mm(
cutlass_fp8_supported: bool, per_tensor_weights: bool,
per_tensor_activations: bool, use_per_token_if_dynamic: Optional[bool]
) -> Callable[..., torch.Tensor]:
if cutlass_fp8_supported:
return cutlass_w8a8_scaled_mm
if per_tensor_weights and per_tensor_activations:
if current_platform.is_rocm():
return rocm_per_tensor_w8a8_scaled_mm
return torch_per_tensor_w8a8_scaled_mm
# torch.scaled_mm supports per tensor weights + activations only
# so fallback to naive if per channel or per token
if (use_per_token_if_dynamic and not per_tensor_weights
and not per_tensor_activations and USE_ROWWISE_TORCH_SCALED_MM):
return torch_per_token_w8a8_scaled_mm
return torch_channelwise_w8a8_scaled_mm
# TODO(luka): follow similar pattern for marlin and block-fp8-linear # TODO(luka): follow similar pattern for marlin and block-fp8-linear
# https://github.com/vllm-project/vllm/issues/14397 # https://github.com/vllm-project/vllm/issues/14397
class Fp8LinearOp: class Fp8LinearOp:
...@@ -157,7 +313,8 @@ class Fp8LinearOp: ...@@ -157,7 +313,8 @@ class Fp8LinearOp:
if pad_output is None: if pad_output is None:
config = get_current_vllm_config().compilation_config config = get_current_vllm_config().compilation_config
pad_output = config.level < CompilationLevel.PIECEWISE pad_output = config.level < CompilationLevel.PIECEWISE
self.output_padding = 17 if pad_output else None self.output_padding = 17 if (
pad_output and not current_platform.is_rocm()) else None
def apply( def apply(
self, self,
...@@ -196,18 +353,6 @@ class Fp8LinearOp: ...@@ -196,18 +353,6 @@ class Fp8LinearOp:
input_scale, input_scale,
scale_ub=input_scale_ub, scale_ub=input_scale_ub,
use_per_token_if_dynamic=use_per_token_if_dynamic) use_per_token_if_dynamic=use_per_token_if_dynamic)
# Fused GEMM_DQ
output = ops.cutlass_scaled_mm(qinput,
weight,
out_dtype=out_dtype,
scale_a=x_scale,
scale_b=weight_scale,
bias=bias)
return output.view(*output_shape)
# torch.scaled_mm supports per tensor weights + activations only
# so fallback to naive if per channel or per token
else: else:
if input.dtype != current_platform.fp8_dtype(): if input.dtype != current_platform.fp8_dtype():
# Maybe apply padding to output, see comment in __init__ # Maybe apply padding to output, see comment in __init__
...@@ -219,84 +364,21 @@ class Fp8LinearOp: ...@@ -219,84 +364,21 @@ class Fp8LinearOp:
else: else:
qinput, x_scale = input_2d, input_scale qinput, x_scale = input_2d, input_scale
per_tensor_weights = (weight_scale.numel() == 1) per_tensor_weights = (weight_scale.numel() == 1)
per_tensor_activations = (x_scale.numel() == 1) per_tensor_activations = (x_scale.numel() == 1)
if per_tensor_weights and per_tensor_activations: w8a8_scaled_mm_func = dispatch_w8a8_scaled_mm(
# Fused GEMM_DQ self.cutlass_fp8_supported, per_tensor_weights,
output = torch._scaled_mm(qinput, per_tensor_activations, use_per_token_if_dynamic)
weight,
out_dtype=out_dtype, return w8a8_scaled_mm_func(qinput=qinput,
scale_a=x_scale, weight=weight,
scale_b=weight_scale, out_dtype=out_dtype,
bias=bias) scale_a=x_scale,
# A fix for discrepancy in scaled_mm which returns tuple scale_b=weight_scale,
# for torch < 2.5 and a single value in torch >= 2.5 bias=bias,
if type(output) is tuple and len(output) == 2: input_2d=input_2d,
output = output[0] output_shape=output_shape)
return torch.narrow(output, 0, 0,
input_2d.shape[0]).view(*output_shape)
elif (use_per_token_if_dynamic and not per_tensor_weights
and not per_tensor_activations
and USE_ROWWISE_TORCH_SCALED_MM):
# For now validated on ROCm platform
# fp8 rowwise scaling in torch._scaled_mm is introduced in
# https://github.com/pytorch/pytorch/pull/144432 using hipBLASLt
# and ROCm 6.3, which only exists in torch 2.7 and above.
# For CUDA platform please validate if the
# torch._scaled_mm support rowwise scaled GEMM
# Fused GEMM_DQ Rowwise GEMM
output = torch._scaled_mm(qinput,
weight,
out_dtype=out_dtype,
scale_a=x_scale,
scale_b=weight_scale.t(),
bias=bias)
output = torch.narrow(output, 0, 0, input_2d.shape[0])
output = output.view(*output_shape)
return output
else:
# Fallback for channelwise case, where we use unfused DQ
# due to limitations with scaled_mm
# Symmetric quantized GEMM by definition computes the following:
# C = (s_x * X) (s_w * W) + bias
# This is equivalent to dequantizing the weights and activations
# before applying a GEMM.
#
# In order to compute quantized operands, a quantized kernel
# will rewrite the above like so:
# C = s_w * s_x * (X * W) + bias
#
# For the scaled_mm fallback case, we break this down, since it
# does not support s_w being a vector.
# GEMM
# This computes C = (X * W).
# Output in fp32 to allow subsequent ops to happen in-place
output = torch._scaled_mm(qinput,
weight,
scale_a=TORCH_DEVICE_IDENTITY,
scale_b=TORCH_DEVICE_IDENTITY,
out_dtype=torch.float32)
# A fix for discrepancy in scaled_mm which returns tuple
# for torch < 2.5 and a single value in torch >= 2.5
if type(output) is tuple and len(output) == 2:
output = output[0]
# Unpad (undo num_token_padding)
output = torch.narrow(output, 0, 0, input_2d.shape[0])
x_scale = torch.narrow(x_scale, 0, 0, input_2d.shape[0])
# DQ
# C = sw * sx * (X * W) + bias
output = output * x_scale * weight_scale.t()
if bias is not None:
output = output + bias
return output.to(dtype=input.dtype).view(*output_shape)
def apply_int8_linear( def apply_int8_linear(
......
...@@ -46,20 +46,12 @@ def _rotate_gptj(x: torch.Tensor) -> torch.Tensor: ...@@ -46,20 +46,12 @@ def _rotate_gptj(x: torch.Tensor) -> torch.Tensor:
return x.flatten(-2) return x.flatten(-2)
def _apply_rotary_emb( def _apply_rotary_emb_torch(
x: torch.Tensor, x: torch.Tensor,
cos: torch.Tensor, cos: torch.Tensor,
sin: torch.Tensor, sin: torch.Tensor,
is_neox_style: bool, is_neox_style: bool,
) -> torch.Tensor: ) -> torch.Tensor:
"""
Args:
x: [num_tokens, num_heads, head_size]
cos: [num_tokens, head_size // 2]
sin: [num_tokens, head_size // 2]
is_neox_style: Whether to use the Neox-style or GPT-J-style rotary
positional embeddings.
"""
cos = cos.unsqueeze(-2).to(x.dtype) cos = cos.unsqueeze(-2).to(x.dtype)
sin = sin.unsqueeze(-2).to(x.dtype) sin = sin.unsqueeze(-2).to(x.dtype)
if is_neox_style: if is_neox_style:
...@@ -75,6 +67,24 @@ def _apply_rotary_emb( ...@@ -75,6 +67,24 @@ def _apply_rotary_emb(
return torch.stack((o1, o2), dim=-1).flatten(-2) return torch.stack((o1, o2), dim=-1).flatten(-2)
def _apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor,
is_neox_style: bool) -> torch.Tensor:
"""
Args:
x: [num_tokens, num_heads, head_size]
cos: [num_tokens, head_size // 2]
sin: [num_tokens, head_size // 2]
is_neox_style: Whether to use the Neox-style or GPT-J-style rotary
positional embeddings.
"""
if current_platform.is_cuda_alike():
from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
return apply_rotary_emb(x.unsqueeze(0), cos, sin,
not is_neox_style).squeeze(0)
else:
return _apply_rotary_emb_torch(x, cos, sin, is_neox_style)
@CustomOp.register("rotary_embedding") @CustomOp.register("rotary_embedding")
class RotaryEmbedding(CustomOp): class RotaryEmbedding(CustomOp):
"""Original rotary positional embedding.""" """Original rotary positional embedding."""
...@@ -141,14 +151,16 @@ class RotaryEmbedding(CustomOp): ...@@ -141,14 +151,16 @@ class RotaryEmbedding(CustomOp):
query = query.view(num_tokens, -1, self.head_size) query = query.view(num_tokens, -1, self.head_size)
query_rot = query[..., :self.rotary_dim] query_rot = query[..., :self.rotary_dim]
query_pass = query[..., self.rotary_dim:] query_pass = query[..., self.rotary_dim:]
query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style) query_rot = _apply_rotary_emb_torch(query_rot, cos, sin,
self.is_neox_style)
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
key_shape = key.shape key_shape = key.shape
key = key.view(num_tokens, -1, self.head_size) key = key.view(num_tokens, -1, self.head_size)
key_rot = key[..., :self.rotary_dim] key_rot = key[..., :self.rotary_dim]
key_pass = key[..., self.rotary_dim:] key_pass = key[..., self.rotary_dim:]
key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style) key_rot = _apply_rotary_emb_torch(key_rot, cos, sin,
self.is_neox_style)
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
return query, key return query, key
...@@ -988,8 +1000,9 @@ class MRotaryEmbedding(RotaryEmbedding): ...@@ -988,8 +1000,9 @@ class MRotaryEmbedding(RotaryEmbedding):
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
return query, key return query, key
@staticmethod @classmethod
def get_input_positions( def get_input_positions(
cls,
input_tokens: List[int], input_tokens: List[int],
hf_config: PretrainedConfig, hf_config: PretrainedConfig,
image_grid_thw: Optional[Union[List[List[int]], torch.Tensor]], image_grid_thw: Optional[Union[List[List[int]], torch.Tensor]],
...@@ -997,6 +1010,8 @@ class MRotaryEmbedding(RotaryEmbedding): ...@@ -997,6 +1010,8 @@ class MRotaryEmbedding(RotaryEmbedding):
second_per_grid_ts: Optional[List[float]], second_per_grid_ts: Optional[List[float]],
context_len: int = 0, context_len: int = 0,
seq_len: Optional[int] = None, seq_len: Optional[int] = None,
audio_feature_lengths: Optional[torch.Tensor] = None,
use_audio_in_video: bool = False,
) -> Tuple[List[List[int]], int]: ) -> Tuple[List[List[int]], int]:
"""Get mrope input positions and delta value.""" """Get mrope input positions and delta value."""
...@@ -1006,7 +1021,7 @@ class MRotaryEmbedding(RotaryEmbedding): ...@@ -1006,7 +1021,7 @@ class MRotaryEmbedding(RotaryEmbedding):
second_per_grid_ts second_per_grid_ts
llm_positions, mrope_position_delta = \ llm_positions, mrope_position_delta = \
MRotaryEmbedding.get_input_positions_tensor( cls.get_input_positions_tensor(
input_tokens=input_tokens, input_tokens=input_tokens,
hf_config=hf_config, hf_config=hf_config,
image_grid_thw=image_grid_thw, image_grid_thw=image_grid_thw,
...@@ -1014,12 +1029,52 @@ class MRotaryEmbedding(RotaryEmbedding): ...@@ -1014,12 +1029,52 @@ class MRotaryEmbedding(RotaryEmbedding):
second_per_grid_ts=second_per_grid_ts, second_per_grid_ts=second_per_grid_ts,
context_len=context_len, context_len=context_len,
seq_len=seq_len, seq_len=seq_len,
audio_feature_lengths=audio_feature_lengths,
use_audio_in_video=use_audio_in_video,
) )
return llm_positions.tolist(), mrope_position_delta return llm_positions.tolist(), mrope_position_delta
@staticmethod @classmethod
def get_input_positions_tensor( def get_input_positions_tensor(
cls,
input_tokens: List[int],
hf_config: PretrainedConfig,
image_grid_thw: Union[List[List[int]], torch.Tensor],
video_grid_thw: Union[List[List[int]], torch.Tensor],
second_per_grid_ts: List[float],
context_len: int = 0,
seq_len: Optional[int] = None,
audio_feature_lengths: Optional[torch.Tensor] = None,
use_audio_in_video: bool = False,
) -> Tuple[torch.Tensor, int]:
from vllm.transformers_utils.config import thinker_uses_mrope
if thinker_uses_mrope(hf_config):
return cls._omni_get_input_positions_tensor(
input_tokens=input_tokens,
hf_config=hf_config,
image_grid_thw=image_grid_thw,
video_grid_thw=video_grid_thw,
second_per_grid_ts=second_per_grid_ts,
context_len=context_len,
seq_len=seq_len,
audio_feature_lengths=audio_feature_lengths,
use_audio_in_video=use_audio_in_video,
)
else:
return cls._vl_get_input_positions_tensor(
input_tokens=input_tokens,
hf_config=hf_config,
image_grid_thw=image_grid_thw,
video_grid_thw=video_grid_thw,
second_per_grid_ts=second_per_grid_ts,
context_len=context_len,
seq_len=seq_len,
)
@classmethod
def _vl_get_input_positions_tensor(
cls,
input_tokens: List[int], input_tokens: List[int],
hf_config: PretrainedConfig, hf_config: PretrainedConfig,
image_grid_thw: Union[List[List[int]], torch.Tensor], image_grid_thw: Union[List[List[int]], torch.Tensor],
...@@ -1037,11 +1092,6 @@ class MRotaryEmbedding(RotaryEmbedding): ...@@ -1037,11 +1092,6 @@ class MRotaryEmbedding(RotaryEmbedding):
tokens_per_second = getattr(hf_config.vision_config, tokens_per_second = getattr(hf_config.vision_config,
"tokens_per_second", 1.0) "tokens_per_second", 1.0)
if isinstance(image_grid_thw, torch.Tensor):
image_grid_thw = image_grid_thw.tolist()
if isinstance(video_grid_thw, torch.Tensor):
video_grid_thw = video_grid_thw.tolist()
input_tokens_tensor = torch.tensor(input_tokens) input_tokens_tensor = torch.tensor(input_tokens)
vision_start_indices = torch.argwhere( vision_start_indices = torch.argwhere(
input_tokens_tensor == vision_start_token_id).squeeze(1) input_tokens_tensor == vision_start_token_id).squeeze(1)
...@@ -1121,6 +1171,224 @@ class MRotaryEmbedding(RotaryEmbedding): ...@@ -1121,6 +1171,224 @@ class MRotaryEmbedding(RotaryEmbedding):
return llm_positions, mrope_position_delta return llm_positions, mrope_position_delta
@classmethod
def _omni_get_input_positions_tensor(
cls,
input_tokens: List[int],
hf_config: PretrainedConfig,
image_grid_thw: Union[List[List[int]], torch.Tensor],
video_grid_thw: Union[List[List[int]], torch.Tensor],
second_per_grid_ts: Optional[List[float]] = None,
context_len: int = 0,
seq_len: Optional[int] = None,
audio_feature_lengths: Optional[torch.Tensor] = None,
use_audio_in_video: bool = False,
) -> Tuple[torch.Tensor, int]:
"""Get mrope input positions and delta value (Qwen2.5-Omni version).
Differences from MRotaryEmbedding:
1. Add audio support (and related `audio_feature_lengths`).
2. Add `use_audio_in_video` option to read audio from video inputs.
In this case, audio and vision position ids will be split into
chunks and interleaved.
Example:
(V_i are vision position ids, A_i are audio position ids)
|V_1 ... V_n|A_1 ... A_n|V_n+1 ... V_2n|A_n+1 ... A_2n|...
|vision chunk 1|audio chunk 1|vision chunk 2|audio chunk 2 |...
"""
# TODO(fyabc): refactor and share more code with
# _vl_get_input_positions_tensor.
thinker_config = hf_config.thinker_config
audio_token_id = thinker_config.audio_token_index
image_token_id = thinker_config.image_token_index
video_token_id = thinker_config.video_token_index
audio_start_token_id = thinker_config.audio_start_token_id
audio_end_token_id = thinker_config.audio_end_token_id
vision_start_token_id = thinker_config.vision_start_token_id
vision_end_token_id = thinker_config.vision_end_token_id
seconds_per_chunk = thinker_config.seconds_per_chunk
spatial_merge_size = thinker_config.vision_config.spatial_merge_size
tokens_per_second = getattr(thinker_config.vision_config,
"tokens_per_second", 25)
if isinstance(image_grid_thw, list):
image_grid_thw = torch.tensor(image_grid_thw)
if isinstance(video_grid_thw, list):
video_grid_thw = torch.tensor(video_grid_thw)
src_item = input_tokens
audio_seqlens = audio_feature_lengths
if not second_per_grid_ts:
second_per_grid_ts = [1] * video_grid_thw.shape[0]
audio_idx = 0
video_idx = 0
image_idx = 0
new_src_item: list[int] = []
llm_pos_ids_list: list[torch.Tensor] = []
idx = 0
while idx < len(src_item):
new_src_item_len = len(new_src_item)
start_idx = llm_pos_ids_list[-1].max() + 1 if len(
llm_pos_ids_list) > 0 else 0
if src_item[idx] not in [
audio_token_id, video_token_id, image_token_id
]:
if use_audio_in_video and idx > 0:
if src_item[idx] == vision_end_token_id and \
src_item[idx - 1] == audio_end_token_id:
# processing the <|audio_eos|> before <|vision_eos|>
start_idx -= 1
elif src_item[idx] == audio_start_token_id and \
src_item[idx - 1] == vision_start_token_id:
# processing the <|audio_bos|> after <|vision_eos|>
start_idx -= 1
new_src_item.append(src_item[idx])
llm_pos_ids = torch.tensor([start_idx],
dtype=torch.long).expand(3, -1)
llm_pos_ids_list.append(llm_pos_ids)
elif src_item[idx] == audio_token_id:
assert audio_seqlens is not None
audio_seqlen = audio_seqlens[audio_idx]
place_num = (((audio_seqlen - 1) // 2 + 1 - 2) // 2 + 1)
new_src_item.extend([audio_token_id] * place_num)
llm_pos_ids = torch.arange(place_num).expand(3, -1) + start_idx
llm_pos_ids_list.append(llm_pos_ids)
audio_idx += 1
elif src_item[idx] == image_token_id:
grid_t = image_grid_thw[image_idx][0]
grid_hs = image_grid_thw[:, 1]
grid_ws = image_grid_thw[:, 2]
t_index = (torch.arange(grid_t) * 1 * tokens_per_second).long()
llm_pos_ids = cls._get_llm_pos_ids_for_vision(
start_idx, image_idx, spatial_merge_size, t_index, grid_hs,
grid_ws)
llm_pos_ids_list.append(llm_pos_ids)
vision_seqlen = image_grid_thw[image_idx].prod() // (
spatial_merge_size**2)
new_src_item.extend([image_token_id] * vision_seqlen)
image_idx += 1
elif src_item[idx] == video_token_id and not use_audio_in_video:
grid_t = video_grid_thw[video_idx][0]
grid_hs = video_grid_thw[:, 1]
grid_ws = video_grid_thw[:, 2]
t_index = (torch.arange(grid_t) *
second_per_grid_ts[video_idx] *
tokens_per_second).long()
llm_pos_ids = cls._get_llm_pos_ids_for_vision(
start_idx, video_idx, spatial_merge_size, t_index, grid_hs,
grid_ws)
llm_pos_ids_list.append(llm_pos_ids)
vision_seqlen = video_grid_thw[video_idx].prod() // (
spatial_merge_size**2)
new_src_item.extend([video_token_id] * vision_seqlen)
video_idx += 1
else:
# read audio from video
assert audio_seqlens is not None
audio_seqlen = audio_seqlens[audio_idx]
vision_seqlen = video_grid_thw[video_idx].prod() // (
spatial_merge_size**2)
grid_t = video_grid_thw[video_idx][0]
grid_h = video_grid_thw[video_idx][1]
grid_w = video_grid_thw[video_idx][2]
grid_hs = video_grid_thw[:, 1]
grid_ws = video_grid_thw[:, 2]
t_ntoken_per_chunk = int(tokens_per_second * seconds_per_chunk)
t_index = (torch.arange(grid_t) *
second_per_grid_ts[video_idx] *
tokens_per_second).long()
t_index_split_chunk = cls._split_list_into_ranges(
t_index, t_ntoken_per_chunk)
place_num = (((audio_seqlen - 1) // 2 + 1 - 2) // 2 + 1) + 2
pure_audio_len = place_num - 2
added_audio_len = 0
audio_llm_pos_ids_list: List[torch.Tensor] = []
for t_chunk in t_index_split_chunk:
vision_ntoken_per_chunk = len(
t_chunk) * grid_h * grid_w // (spatial_merge_size**2)
new_src_item.extend([video_token_id] *
vision_ntoken_per_chunk)
vision_llm_pos_ids_list = cls._get_llm_pos_ids_for_vision(
start_idx, video_idx, spatial_merge_size, t_chunk,
grid_hs, grid_ws).split(1, dim=1)
llm_pos_ids_list.extend(vision_llm_pos_ids_list)
new_src_item.extend(
min(t_ntoken_per_chunk, pure_audio_len -
added_audio_len) * [audio_token_id])
audio_start_idx = start_idx if len(
audio_llm_pos_ids_list
) == 0 else audio_llm_pos_ids_list[-1][0].item() + 1
if min(t_ntoken_per_chunk,
pure_audio_len - added_audio_len) > 0:
audio_llm_pos_ids_list = (torch.arange(
min(t_ntoken_per_chunk, pure_audio_len -
added_audio_len)).expand(3, -1) +
audio_start_idx).split(1,
dim=1)
else:
audio_llm_pos_ids_list = []
added_audio_len += min(t_ntoken_per_chunk,
pure_audio_len - added_audio_len)
llm_pos_ids_list.extend(audio_llm_pos_ids_list)
if added_audio_len < pure_audio_len:
new_src_item.extend(
(pure_audio_len - added_audio_len) * [audio_token_id])
audio_llm_pos_ids_list = (
torch.arange(pure_audio_len - added_audio_len).expand(
3, -1) + llm_pos_ids_list[-1].max() + 1).split(
1, dim=1)
llm_pos_ids_list.extend(audio_llm_pos_ids_list)
audio_idx += 1
video_idx += 1
# move to the next token
idx += len(new_src_item) - new_src_item_len
llm_positions = torch.cat(llm_pos_ids_list, dim=1)
mrope_position_delta = torch.cat(llm_pos_ids_list,
dim=1).max() + 1 - len(src_item)
llm_positions = llm_positions[:, context_len:seq_len]
return llm_positions, mrope_position_delta
@staticmethod
def _get_llm_pos_ids_for_vision(
start_idx: int,
vision_idx: int,
spatial_merge_size: int,
t_index: List[int],
grid_hs: torch.Tensor,
grid_ws: torch.Tensor,
) -> torch.Tensor:
llm_pos_ids_list = []
llm_grid_h = grid_hs[vision_idx] // spatial_merge_size
llm_grid_w = grid_ws[vision_idx] // spatial_merge_size
h_index = (torch.arange(llm_grid_h).view(1, -1, 1).expand(
len(t_index), -1, llm_grid_w).flatten())
w_index = (torch.arange(llm_grid_w).view(1, 1, -1).expand(
len(t_index), llm_grid_h, -1).flatten())
t_index_tensor = torch.Tensor(t_index).to(llm_grid_h.device).view(
-1, 1).expand(-1, llm_grid_h * llm_grid_w).long().flatten()
_llm_pos_ids = torch.stack([t_index_tensor, h_index, w_index])
llm_pos_ids_list.append(_llm_pos_ids + start_idx)
llm_pos_ids = torch.cat(llm_pos_ids_list, dim=1)
return llm_pos_ids
@staticmethod
def _split_list_into_ranges(lst: torch.Tensor,
interval: int) -> List[List[int]]:
ranges: List[List[int]] = [[]
for _ in range((max(lst) // interval) + 1)]
for num in lst:
index = num // interval
ranges[index].append(num)
return ranges
@staticmethod @staticmethod
def get_next_input_positions( def get_next_input_positions(
mrope_position_delta: int, mrope_position_delta: int,
...@@ -1144,6 +1412,58 @@ class MRotaryEmbedding(RotaryEmbedding): ...@@ -1144,6 +1412,58 @@ class MRotaryEmbedding(RotaryEmbedding):
mrope_position_delta + seq_len, mrope_position_delta + seq_len,
).expand(3, -1) ).expand(3, -1)
@classmethod
def omni_get_updates_use_audio_in_video(
cls,
thinker_config: PretrainedConfig,
audio_len: int,
video_grid_thw: Union[List[int], torch.Tensor],
video_second_per_grid_t: float,
) -> List[int]:
"""Get video prompt updates when `use_audio_in_video` is True.
In this case, audio and vision update ids will be split into
chunks and interleaved (details in `_omni_get_input_positions_tensor`).
<|video_bos|><|VIDEO|><|video_eos|> =>
<|video_bos|><|audio_bos|>(... chunks ...)<|audio_eos|><|video_eos|>
"""
audio_token_id = thinker_config.audio_token_index
video_token_id = thinker_config.video_token_index
audio_start_token_id = thinker_config.audio_start_token_id
audio_end_token_id = thinker_config.audio_end_token_id
seconds_per_chunk = thinker_config.seconds_per_chunk
spatial_merge_size = thinker_config.vision_config.spatial_merge_size
tokens_per_second = getattr(thinker_config.vision_config,
"tokens_per_second", 25)
grid_t = video_grid_thw[0]
grid_h = video_grid_thw[1]
grid_w = video_grid_thw[2]
t_ntoken_per_chunk = int(tokens_per_second * seconds_per_chunk)
t_index = (torch.arange(grid_t) * video_second_per_grid_t *
tokens_per_second).long()
t_index_split_chunk = cls._split_list_into_ranges(
t_index, t_ntoken_per_chunk)
updates = [audio_start_token_id]
added_audio_len = 0
for t_chunk in t_index_split_chunk:
vision_ntoken_per_chunk = len(t_chunk) * grid_h * grid_w // (
spatial_merge_size**2)
updates.extend([video_token_id] * vision_ntoken_per_chunk)
audio_chunk_size = min(t_ntoken_per_chunk,
audio_len - added_audio_len)
updates.extend(audio_chunk_size * [audio_token_id])
added_audio_len += audio_chunk_size
if added_audio_len < audio_len:
updates.extend((audio_len - added_audio_len) * [audio_token_id])
updates.extend([audio_end_token_id])
return updates
_ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {} _ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {}
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
"""Utility methods for model layers.""" """Utility methods for model layers."""
from typing import Tuple from typing import Callable, Optional, Tuple
import torch import torch
from vllm import _custom_ops as ops
from vllm import envs
from vllm.platforms import current_platform
def get_token_bin_counts_and_mask( def get_token_bin_counts_and_mask(
tokens: torch.Tensor, tokens: torch.Tensor,
...@@ -47,12 +51,49 @@ def apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor, ...@@ -47,12 +51,49 @@ def apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
output_tokens_tensor, vocab_size, num_seqs) output_tokens_tensor, vocab_size, num_seqs)
repetition_penalties = repetition_penalties.unsqueeze(dim=1).repeat( repetition_penalties = repetition_penalties.unsqueeze(dim=1).repeat(
1, vocab_size) 1, vocab_size)
logits[logits > 0] /= torch.where(prompt_mask | output_mask,
repetition_penalties, 1.0)[logits > 0] # If token appears in prompt or output, apply, otherwise use 1.0 for no-op.
logits[logits <= 0] *= torch.where(prompt_mask | output_mask, penalties = torch.where(prompt_mask | output_mask, repetition_penalties,
repetition_penalties, 1.0)[logits <= 0] 1.0)
# If logits are positive, divide by penalty, otherwise multiply by penalty.
scaling = torch.where(logits > 0, 1.0 / penalties, penalties)
logits *= scaling
# We follow the definition in OpenAI API. # We follow the definition in OpenAI API.
# Refer to https://platform.openai.com/docs/api-reference/parameter-details # Refer to https://platform.openai.com/docs/api-reference/parameter-details
logits -= frequency_penalties.unsqueeze(dim=1) * output_bin_counts logits -= frequency_penalties.unsqueeze(dim=1) * output_bin_counts
logits -= presence_penalties.unsqueeze(dim=1) * output_mask logits -= presence_penalties.unsqueeze(dim=1) * output_mask
return logits return logits
def rocm_unquantized_gemm(x: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor] = None):
from vllm.platforms.rocm import on_mi250_mi300
k = weight.shape[1]
use_skinny = (envs.VLLM_ROCM_USE_SKINNY_GEMM and on_mi250_mi300() and \
x.dtype in [torch.float16, torch.bfloat16] \
and k % 8 == 0 and bias is None)
if use_skinny is not True:
return torch.nn.functional.linear(x, weight, bias)
x_view = x.view(-1, x.size(-1))
n = x_view.shape[0]
m = weight.shape[0]
cu_count = current_platform.get_cu_count()
if m > 8 and 0 < n < 4:
out = ops.wvSplitK(weight, x_view, cu_count)
return out.view(*x.shape[:-1], weight.shape[0])
elif m % 4 == 0 and n == 1 and k <= 8192:
out = ops.LLMM1(weight, x_view, 4)
return out.view(*x.shape[:-1], weight.shape[0])
return torch.nn.functional.linear(x, weight, bias)
def dispatch_unquantized_gemm() -> Callable[..., torch.Tensor]:
# if current_platform.is_rocm():
# return rocm_unquantized_gemm
return torch.nn.functional.linear
...@@ -13,6 +13,7 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank, ...@@ -13,6 +13,7 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
tensor_model_parallel_all_reduce) tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase, method_has_implemented_embedding) QuantizationConfig, QuantizeMethodBase, method_has_implemented_embedding)
from vllm.model_executor.layers.utils import dispatch_unquantized_gemm
from vllm.model_executor.parameter import BasevLLMParameter from vllm.model_executor.parameter import BasevLLMParameter
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
...@@ -55,8 +56,8 @@ class UnquantizedEmbeddingMethod(QuantizeMethodBase): ...@@ -55,8 +56,8 @@ class UnquantizedEmbeddingMethod(QuantizeMethodBase):
else: else:
return torch.matmul(x, layer.weight) return torch.matmul(x, layer.weight)
else: else:
return F.linear(x, layer.weight, bias) return dispatch_unquantized_gemm()(x, layer.weight, bias)
def embedding(self, layer: torch.nn.Module, def embedding(self, layer: torch.nn.Module,
input_: torch.Tensor) -> torch.Tensor: input_: torch.Tensor) -> torch.Tensor:
return F.embedding(input_, layer.weight) return F.embedding(input_, layer.weight)
......
...@@ -613,8 +613,12 @@ class ShardedStateLoader(BaseModelLoader): ...@@ -613,8 +613,12 @@ class ShardedStateLoader(BaseModelLoader):
DEFAULT_PATTERN = "model-rank-{rank}-part-{part}.safetensors" DEFAULT_PATTERN = "model-rank-{rank}-part-{part}.safetensors"
def __init__(self, load_config: LoadConfig): def __init__(self,
load_config: LoadConfig,
runai_model_streamer: bool = False):
super().__init__(load_config) super().__init__(load_config)
self.runai_model_streamer = runai_model_streamer
extra_config = ({} if load_config.model_loader_extra_config is None extra_config = ({} if load_config.model_loader_extra_config is None
else load_config.model_loader_extra_config.copy()) else load_config.model_loader_extra_config.copy())
self.pattern = extra_config.pop("pattern", self.DEFAULT_PATTERN) self.pattern = extra_config.pop("pattern", self.DEFAULT_PATTERN)
...@@ -661,7 +665,7 @@ class ShardedStateLoader(BaseModelLoader): ...@@ -661,7 +665,7 @@ class ShardedStateLoader(BaseModelLoader):
def _prepare_weights(self, model_name_or_path: str, def _prepare_weights(self, model_name_or_path: str,
revision: Optional[str]): revision: Optional[str]):
if os.path.isdir(model_name_or_path): if is_s3(model_name_or_path) or os.path.isdir(model_name_or_path):
return model_name_or_path return model_name_or_path
else: else:
allow_patterns = ["*.safetensors"] allow_patterns = ["*.safetensors"]
...@@ -680,12 +684,13 @@ class ShardedStateLoader(BaseModelLoader): ...@@ -680,12 +684,13 @@ class ShardedStateLoader(BaseModelLoader):
device_config = vllm_config.device_config device_config = vllm_config.device_config
model_config = vllm_config.model_config model_config = vllm_config.model_config
target_device = torch.device(device_config.device) target_device = torch.device(device_config.device)
from safetensors.torch import safe_open
from vllm.distributed import get_tensor_model_parallel_rank from vllm.distributed import get_tensor_model_parallel_rank
local_model_path = self._prepare_weights(model_config.model, model_weights = model_config.model
model_config.revision) if hasattr(model_config, "model_weights"):
model_weights = model_config.model_weights
local_model_path = model_weights
with set_default_torch_dtype(model_config.dtype): with set_default_torch_dtype(model_config.dtype):
with target_device: with target_device:
...@@ -697,40 +702,56 @@ class ShardedStateLoader(BaseModelLoader): ...@@ -697,40 +702,56 @@ class ShardedStateLoader(BaseModelLoader):
local_model_path, local_model_path,
self.pattern.format(rank=rank, part="*"), self.pattern.format(rank=rank, part="*"),
) )
filepaths = glob.glob(pattern)
filepaths = []
if is_s3(local_model_path):
file_pattern = f"*{self.pattern.format(rank=rank, part=' * ')}"
filepaths = s3_glob(path=local_model_path,
allow_pattern=[file_pattern])
else:
filepaths = glob.glob(pattern)
if not filepaths: if not filepaths:
# TODO: support un-sharded checkpoints too # TODO: support un-sharded checkpoints too
raise ValueError( raise ValueError(
f"Could not find checkpoint files '{pattern}', only " f"Could not find checkpoint files '{pattern}', only "
f"pre-sharded checkpoints are currently supported!") f"pre-sharded checkpoints are currently supported!")
state_dict = self._filter_subtensors(model.state_dict()) state_dict = self._filter_subtensors(model.state_dict())
for path in filepaths: for key, tensor in self.iterate_over_files(filepaths):
with safe_open(path, framework="pt") as f: # If loading with LoRA enabled, additional padding may
for key in f.keys(): # noqa: SIM118 # be added to certain parameters. We only load into a
tensor = f.get_tensor(key) # narrowed view of the parameter data.
# If loading with LoRA enabled, additional padding may param_data = state_dict[key].data
# be added to certain parameters. We only load into a param_shape = state_dict[key].shape
# narrowed view of the parameter data. for dim, size in enumerate(tensor.shape):
param_data = state_dict[key].data if size < param_shape[dim]:
param_shape = state_dict[key].shape param_data = param_data.narrow(dim, 0, size)
for dim, size in enumerate(tensor.shape): if tensor.shape != param_shape:
if size < param_shape[dim]: logger.warning(
param_data = param_data.narrow(dim, 0, size) "loading tensor of shape %s into "
if tensor.shape != param_shape: "parameter '%s' of shape %s",
logger.warning( tensor.shape,
"loading tensor of shape %s into " key,
"parameter '%s' of shape %s", param_shape,
tensor.shape, )
key, param_data.copy_(tensor)
param_shape, state_dict.pop(key)
)
param_data.copy_(tensor)
state_dict.pop(key)
if state_dict: if state_dict:
raise ValueError( raise ValueError(
f"Missing keys {tuple(state_dict)} in loaded state!") f"Missing keys {tuple(state_dict)} in loaded state!")
return model.eval() return model.eval()
def iterate_over_files(
self, paths) -> Generator[Tuple[str, torch.Tensor], None, None]:
if self.runai_model_streamer:
yield from runai_safetensors_weights_iterator(paths, True)
else:
from safetensors.torch import safe_open
for path in paths:
with safe_open(path, framework="pt") as f:
for key in f.keys(): # noqa: SIM118
tensor = f.get_tensor(key)
yield key, tensor
@staticmethod @staticmethod
def save_model( def save_model(
model: torch.nn.Module, model: torch.nn.Module,
...@@ -1517,4 +1538,7 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader: ...@@ -1517,4 +1538,7 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
if load_config.load_format == LoadFormat.RUNAI_STREAMER: if load_config.load_format == LoadFormat.RUNAI_STREAMER:
return RunaiModelStreamerLoader(load_config) return RunaiModelStreamerLoader(load_config)
if load_config.load_format == LoadFormat.RUNAI_STREAMER_SHARDED:
return ShardedStateLoader(load_config, runai_model_streamer=True)
return DefaultModelLoader(load_config) return DefaultModelLoader(load_config)
...@@ -31,15 +31,6 @@ def set_default_torch_dtype(dtype: torch.dtype): ...@@ -31,15 +31,6 @@ def set_default_torch_dtype(dtype: torch.dtype):
torch.set_default_dtype(old_dtype) torch.set_default_dtype(old_dtype)
def is_transformers_impl_compatible(
arch: str,
module: Optional["transformers.PreTrainedModel"] = None) -> bool:
mod = module or getattr(transformers, arch, None)
if mod is None:
return False
return mod.is_backend_compatible()
def resolve_transformers_arch(model_config: ModelConfig, def resolve_transformers_arch(model_config: ModelConfig,
architectures: list[str]): architectures: list[str]):
for i, arch in enumerate(architectures): for i, arch in enumerate(architectures):
...@@ -56,20 +47,32 @@ def resolve_transformers_arch(model_config: ModelConfig, ...@@ -56,20 +47,32 @@ def resolve_transformers_arch(model_config: ModelConfig,
# "AutoModelFor<Task>": "<your-repo-name>--<config-name>", # "AutoModelFor<Task>": "<your-repo-name>--<config-name>",
# }, # },
auto_modules = { auto_modules = {
name: get_class_from_dynamic_module(module, model_config.model) name:
get_class_from_dynamic_module(module,
model_config.model,
revision=model_config.revision)
for name, module in sorted(auto_map.items(), key=lambda x: x[0]) for name, module in sorted(auto_map.items(), key=lambda x: x[0])
} }
custom_model_module = auto_modules.get("AutoModel") model_module = getattr(transformers, arch, None)
if model_module is None:
if "AutoModel" not in auto_map:
raise ValueError(
f"Cannot find model module. '{arch}' is not a registered "
"model in the Transformers library (only relevant if the "
"model is meant to be in Transformers) and 'AutoModel' is "
"not present in the model config's 'auto_map' (relevant "
"if the model is custom).")
model_module = auto_modules["AutoModel"]
# TODO(Isotr0py): Further clean up these raises. # TODO(Isotr0py): Further clean up these raises.
# perhaps handled them in _ModelRegistry._raise_for_unsupported? # perhaps handled them in _ModelRegistry._raise_for_unsupported?
if model_config.model_impl == ModelImpl.TRANSFORMERS: if model_config.model_impl == ModelImpl.TRANSFORMERS:
if not is_transformers_impl_compatible(arch, custom_model_module): if not model_module.is_backend_compatible():
raise ValueError( raise ValueError(
f"The Transformers implementation of {arch} is not " f"The Transformers implementation of {arch} is not "
"compatible with vLLM.") "compatible with vLLM.")
architectures[i] = "TransformersForCausalLM" architectures[i] = "TransformersForCausalLM"
if model_config.model_impl == ModelImpl.AUTO: if model_config.model_impl == ModelImpl.AUTO:
if not is_transformers_impl_compatible(arch, custom_model_module): if not model_module.is_backend_compatible():
raise ValueError( raise ValueError(
f"{arch} has no vLLM implementation and the Transformers " f"{arch} has no vLLM implementation and the Transformers "
"implementation is not compatible with vLLM. Try setting " "implementation is not compatible with vLLM. Try setting "
...@@ -132,10 +135,10 @@ def get_model_architecture( ...@@ -132,10 +135,10 @@ def get_model_architecture(
architectures = ["QuantMixtralForCausalLM"] architectures = ["QuantMixtralForCausalLM"]
vllm_supported_archs = ModelRegistry.get_supported_archs() vllm_supported_archs = ModelRegistry.get_supported_archs()
is_vllm_supported = any(arch in vllm_supported_archs vllm_not_supported = not any(arch in vllm_supported_archs
for arch in architectures) for arch in architectures)
if (not is_vllm_supported if (model_config.model_impl == ModelImpl.TRANSFORMERS or
or model_config.model_impl == ModelImpl.TRANSFORMERS): model_config.model_impl != ModelImpl.VLLM and vllm_not_supported):
architectures = resolve_transformers_arch(model_config, architectures) architectures = resolve_transformers_arch(model_config, architectures)
model_cls, arch = ModelRegistry.resolve_model_cls(architectures) model_cls, arch = ModelRegistry.resolve_model_cls(architectures)
......
...@@ -24,7 +24,6 @@ from vllm.model_executor.layers.quantization import QuantizationConfig ...@@ -24,7 +24,6 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.deepspeedfp import ( from vllm.model_executor.layers.quantization.deepspeedfp import (
DeepSpeedFPConfig, DeepSpeedFPParameter) DeepSpeedFPConfig, DeepSpeedFPParameter)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
...@@ -435,7 +434,6 @@ class ArcticForCausalLM(nn.Module, SupportsPP, SupportsQuant): ...@@ -435,7 +434,6 @@ class ArcticForCausalLM(nn.Module, SupportsPP, SupportsQuant):
self.unpadded_vocab_size = config.vocab_size self.unpadded_vocab_size = config.vocab_size
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size) config.vocab_size)
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors) self.model.make_empty_intermediate_tensors)
...@@ -462,14 +460,6 @@ class ArcticForCausalLM(nn.Module, SupportsPP, SupportsQuant): ...@@ -462,14 +460,6 @@ class ArcticForCausalLM(nn.Module, SupportsPP, SupportsQuant):
sampling_metadata) sampling_metadata)
return logits return logits
def sample(
self,
logits: Optional[torch.Tensor],
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
......
...@@ -15,11 +15,10 @@ from vllm.model_executor.layers.fused_moe import FusedMoE ...@@ -15,11 +15,10 @@ from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import (SamplerOutput,
SamplingMetadata, get_sampler)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name) default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargs) MultiModalKwargs)
...@@ -527,7 +526,6 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -527,7 +526,6 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
logit_scale = getattr(config, "logit_scale", 1.0) logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
self.vocab_size, logit_scale) self.vocab_size, logit_scale)
self.sampler = get_sampler()
def _validate_image_sizes( def _validate_image_sizes(
self, images: List[torch.Tensor]) -> List[torch.Tensor]: self, images: List[torch.Tensor]) -> List[torch.Tensor]:
...@@ -653,14 +651,6 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -653,14 +651,6 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
sampling_metadata) sampling_metadata)
return logits return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
loader = AutoWeightsLoader(self) loader = AutoWeightsLoader(self)
loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
# SPDX-License-Identifier: Apache-2.0 Adapted from # SPDX-License-Identifier: Apache-2.0 Adapted from
# https://github.com/huggingface/transformers/tree/main/src/transformers/models/aya_vision # https://github.com/huggingface/transformers/tree/main/src/transformers/models/aya_vision
from functools import cached_property
from typing import (Iterable, Literal, Mapping, Optional, Sequence, Set, Tuple, from typing import (Iterable, Literal, Mapping, Optional, Sequence, Set, Tuple,
TypedDict, Union, cast) TypedDict, Union, cast)
...@@ -17,7 +16,6 @@ from transformers.models.got_ocr2.image_processing_got_ocr2 import ( ...@@ -17,7 +16,6 @@ from transformers.models.got_ocr2.image_processing_got_ocr2 import (
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.jsontree import json_map_leaves from vllm.jsontree import json_map_leaves
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalDataDict, MultiModalKwargs from vllm.multimodal.inputs import MultiModalDataDict, MultiModalKwargs
...@@ -461,17 +459,3 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -461,17 +459,3 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal,
) -> Optional[torch.Tensor]: ) -> Optional[torch.Tensor]:
return self.language_model.compute_logits(hidden_states, return self.language_model.compute_logits(hidden_states,
sampling_metadata) sampling_metadata)
@cached_property
def sampler(self):
if hasattr(self.language_model, "sampler"):
return self.language_model.sampler
return get_sampler()
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
return self.language_model.sample(logits, sampling_metadata)
...@@ -43,7 +43,6 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, ...@@ -43,7 +43,6 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
...@@ -504,7 +503,6 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP, ...@@ -504,7 +503,6 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
if self.config.tie_word_embeddings: if self.config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight self.lm_head.weight = self.model.embed_tokens.weight
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors) self.model.make_empty_intermediate_tensors)
...@@ -532,14 +530,6 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP, ...@@ -532,14 +530,6 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
sampling_metadata) sampling_metadata)
return logits return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader(self) loader = AutoWeightsLoader(self)
......
...@@ -24,7 +24,6 @@ from vllm.model_executor.layers.mamba.mamba_mixer2 import ( ...@@ -24,7 +24,6 @@ from vllm.model_executor.layers.mamba.mamba_mixer2 import (
MambaMixer2, extra_groups_for_head_shards) MambaMixer2, extra_groups_for_head_shards)
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
...@@ -462,7 +461,6 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, ...@@ -462,7 +461,6 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size) config.vocab_size)
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors) self.model.make_empty_intermediate_tensors)
...@@ -538,14 +536,6 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, ...@@ -538,14 +536,6 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
sampling_metadata) sampling_metadata)
return logits return logits
def sample(
self,
logits: Optional[torch.Tensor],
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader(self) loader = AutoWeightsLoader(self)
......
...@@ -37,7 +37,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, ...@@ -37,7 +37,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
...@@ -791,7 +790,6 @@ class BartForConditionalGeneration(nn.Module, SupportsV0Only, SupportsQuant): ...@@ -791,7 +790,6 @@ class BartForConditionalGeneration(nn.Module, SupportsV0Only, SupportsQuant):
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size) config.vocab_size)
self.sampler = get_sampler()
def forward( def forward(
self, self,
...@@ -828,14 +826,6 @@ class BartForConditionalGeneration(nn.Module, SupportsV0Only, SupportsQuant): ...@@ -828,14 +826,6 @@ class BartForConditionalGeneration(nn.Module, SupportsV0Only, SupportsQuant):
sampling_metadata) sampling_metadata)
return logits return logits
def sample(
self,
logits: Optional[torch.Tensor],
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
stacked_params_mapping = { stacked_params_mapping = {
"q_proj": { "q_proj": {
"param_name": "qkv_proj", "param_name": "qkv_proj",
......
...@@ -11,8 +11,10 @@ from vllm.compilation.decorators import support_torch_compile ...@@ -11,8 +11,10 @@ from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, PoolerConfig, VllmConfig from vllm.config import CacheConfig, PoolerConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.forward_context import get_forward_context from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import (get_act_and_mul_fn,
get_act_fn)
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.pooler import (CrossEncodingPooler, Pooler, from vllm.model_executor.layers.pooler import (CrossEncodingPooler, Pooler,
...@@ -108,6 +110,7 @@ class BertEncoder(nn.Module): ...@@ -108,6 +110,7 @@ class BertEncoder(nn.Module):
def __init__(self, def __init__(self,
vllm_config: VllmConfig, vllm_config: VllmConfig,
bias: bool = True,
rotary_kwargs: Optional[dict] = None, rotary_kwargs: Optional[dict] = None,
prefix: str = ""): prefix: str = ""):
super().__init__() super().__init__()
...@@ -118,6 +121,7 @@ class BertEncoder(nn.Module): ...@@ -118,6 +121,7 @@ class BertEncoder(nn.Module):
BertLayer(config=config, BertLayer(config=config,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config, quant_config=quant_config,
bias=bias,
rotary_kwargs=rotary_kwargs, rotary_kwargs=rotary_kwargs,
prefix=f"{prefix}.layer.{layer_idx}") prefix=f"{prefix}.layer.{layer_idx}")
for layer_idx in range(config.num_hidden_layers) for layer_idx in range(config.num_hidden_layers)
...@@ -139,6 +143,7 @@ class BertLayer(nn.Module): ...@@ -139,6 +143,7 @@ class BertLayer(nn.Module):
config: BertConfig, config: BertConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
bias: bool = True,
rotary_kwargs: Optional[dict] = None, rotary_kwargs: Optional[dict] = None,
prefix: str = ""): prefix: str = ""):
super().__init__() super().__init__()
...@@ -149,19 +154,31 @@ class BertLayer(nn.Module): ...@@ -149,19 +154,31 @@ class BertLayer(nn.Module):
layer_norm_eps=config.layer_norm_eps, layer_norm_eps=config.layer_norm_eps,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config, quant_config=quant_config,
bias=bias,
rotary_kwargs=rotary_kwargs, rotary_kwargs=rotary_kwargs,
prefix=f"{prefix}.attention") prefix=f"{prefix}.attention")
self.intermediate = BertIntermediate( if config.hidden_act in ["silu", "gelu_and_mul"]:
hidden_size=config.hidden_size, self.intermediate = BertGatedIntermediate(
intermediate_size=config.intermediate_size, hidden_size=config.hidden_size,
hidden_act=config.hidden_act, intermediate_size=config.intermediate_size,
quant_config=quant_config, hidden_act=config.hidden_act,
prefix=f"{prefix}.intermediate") bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.intermediate")
else:
self.intermediate = BertIntermediate(
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.intermediate")
self.output = BertOutput(hidden_size=config.hidden_size, self.output = BertOutput(hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size, intermediate_size=config.intermediate_size,
layer_norm_eps=config.layer_norm_eps, layer_norm_eps=config.layer_norm_eps,
bias=bias,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.output") prefix=f"{prefix}.output")
...@@ -181,6 +198,7 @@ class BertAttention(nn.Module): ...@@ -181,6 +198,7 @@ class BertAttention(nn.Module):
layer_norm_eps: float, layer_norm_eps: float,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
bias: bool = True,
rotary_kwargs: Optional[dict] = None, rotary_kwargs: Optional[dict] = None,
prefix: str = "", prefix: str = "",
): ):
...@@ -190,11 +208,13 @@ class BertAttention(nn.Module): ...@@ -190,11 +208,13 @@ class BertAttention(nn.Module):
num_attention_heads=num_attention_heads, num_attention_heads=num_attention_heads,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config, quant_config=quant_config,
bias=bias,
rotary_kwargs=rotary_kwargs, rotary_kwargs=rotary_kwargs,
prefix=f"{prefix}.output") prefix=f"{prefix}.output")
self.output = BertSelfOutput(hidden_size=hidden_size, self.output = BertSelfOutput(hidden_size=hidden_size,
layer_norm_eps=layer_norm_eps, layer_norm_eps=layer_norm_eps,
bias=bias,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.output") prefix=f"{prefix}.output")
...@@ -215,6 +235,7 @@ class BertSelfAttention(nn.Module): ...@@ -215,6 +235,7 @@ class BertSelfAttention(nn.Module):
num_attention_heads: int, num_attention_heads: int,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
bias: bool = True,
rotary_kwargs: Optional[dict] = None, rotary_kwargs: Optional[dict] = None,
prefix: str = "", prefix: str = "",
): ):
...@@ -240,7 +261,7 @@ class BertSelfAttention(nn.Module): ...@@ -240,7 +261,7 @@ class BertSelfAttention(nn.Module):
head_size=self.head_dim, head_size=self.head_dim,
total_num_heads=self.total_num_heads, total_num_heads=self.total_num_heads,
total_num_kv_heads=self.total_num_kv_heads, total_num_kv_heads=self.total_num_kv_heads,
bias=True, bias=bias,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.qkv_proj") prefix=f"{prefix}.qkv_proj")
...@@ -278,12 +299,13 @@ class BertSelfOutput(nn.Module): ...@@ -278,12 +299,13 @@ class BertSelfOutput(nn.Module):
def __init__(self, def __init__(self,
hidden_size: int, hidden_size: int,
layer_norm_eps: float, layer_norm_eps: float,
bias: bool = True,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""): prefix: str = ""):
super().__init__() super().__init__()
self.dense = RowParallelLinear(input_size=hidden_size, self.dense = RowParallelLinear(input_size=hidden_size,
output_size=hidden_size, output_size=hidden_size,
bias=True, bias=bias,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.dense") prefix=f"{prefix}.dense")
self.LayerNorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps) self.LayerNorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
...@@ -301,12 +323,13 @@ class BertIntermediate(nn.Module): ...@@ -301,12 +323,13 @@ class BertIntermediate(nn.Module):
hidden_size: int, hidden_size: int,
intermediate_size: int, intermediate_size: int,
hidden_act: str, hidden_act: str,
bias: bool = True,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""): prefix: str = ""):
super().__init__() super().__init__()
self.dense = ColumnParallelLinear(input_size=hidden_size, self.dense = ColumnParallelLinear(input_size=hidden_size,
output_size=intermediate_size, output_size=intermediate_size,
bias=True, bias=bias,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.dense") prefix=f"{prefix}.dense")
self.intermediate_act_fn = get_act_fn(hidden_act) self.intermediate_act_fn = get_act_fn(hidden_act)
...@@ -317,19 +340,46 @@ class BertIntermediate(nn.Module): ...@@ -317,19 +340,46 @@ class BertIntermediate(nn.Module):
return hidden_states return hidden_states
class BertGatedIntermediate(nn.Module):
# for NomciBert and GteModel
def __init__(self,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
bias: bool = True,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__()
self.act_fn = get_act_and_mul_fn(hidden_act)
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size,
[intermediate_size] * 2,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj",
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
gate_up, _ = self.gate_up_proj(hidden_states)
hidden_states = self.act_fn(gate_up)
return hidden_states
class BertOutput(nn.Module): class BertOutput(nn.Module):
def __init__(self, def __init__(self,
hidden_size: int, hidden_size: int,
intermediate_size: int, intermediate_size: int,
layer_norm_eps: float, layer_norm_eps: float,
bias: bool = True,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""): prefix: str = ""):
super().__init__() super().__init__()
self.dense = RowParallelLinear(input_size=intermediate_size, self.dense = RowParallelLinear(input_size=intermediate_size,
output_size=hidden_size, output_size=hidden_size,
bias=True, bias=bias,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.dense") prefix=f"{prefix}.dense")
...@@ -343,19 +393,32 @@ class BertOutput(nn.Module): ...@@ -343,19 +393,32 @@ class BertOutput(nn.Module):
class BertModel(nn.Module, SupportsQuant): class BertModel(nn.Module, SupportsQuant):
packed_modules_mapping = {"qkv_proj": ["query", "key", "value"]} packed_modules_mapping = {
"qkv_proj": ["query", "key", "value"],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
def __init__(self, def __init__(self,
*, *,
vllm_config: VllmConfig, vllm_config: VllmConfig,
prefix: str = "", prefix: str = "",
embedding_class: type = BertEmbedding, embedding_class: type = BertEmbedding,
bias: bool = True,
rotary_kwargs: Optional[dict] = None, rotary_kwargs: Optional[dict] = None,
add_pooling_layer: bool = False): add_pooling_layer: bool = False):
super().__init__() super().__init__()
"""
For BertModel, all linear layers have bias.
For NomicBertModel, all linear layers do not have bias.
"""
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
self.embeddings = embedding_class(config) self.embeddings = embedding_class(config)
self.encoder = BertEncoder(vllm_config=vllm_config, self.encoder = BertEncoder(vllm_config=vllm_config,
bias=bias,
rotary_kwargs=rotary_kwargs, rotary_kwargs=rotary_kwargs,
prefix=f"{prefix}.encoder") prefix=f"{prefix}.encoder")
self.pooler = BertPooler(config) if add_pooling_layer else None self.pooler = BertPooler(config) if add_pooling_layer else None
...@@ -387,6 +450,8 @@ class BertModel(nn.Module, SupportsQuant): ...@@ -387,6 +450,8 @@ class BertModel(nn.Module, SupportsQuant):
("qkv_proj", "query", "q"), ("qkv_proj", "query", "q"),
("qkv_proj", "key", "k"), ("qkv_proj", "key", "k"),
("qkv_proj", "value", "v"), ("qkv_proj", "value", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
] ]
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
...@@ -546,3 +611,115 @@ class BertForSequenceClassification(nn.Module, SupportsCrossEncoding, ...@@ -546,3 +611,115 @@ class BertForSequenceClassification(nn.Module, SupportsCrossEncoding,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
intermediate_tensors=intermediate_tensors, intermediate_tensors=intermediate_tensors,
token_type_ids=token_type_ids) token_type_ids=token_type_ids)
class NomicBertEmbeddingModel(BertEmbeddingModel):
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_substr={
"emb_ln": "embeddings.LayerNorm",
"layers": "layer",
"attn.Wqkv": "attention.self.qkv_proj",
"attn.out_proj": "attention.output.dense",
'norm1': "attention.output.LayerNorm",
'mlp.fc11': "intermediate.up_proj",
'mlp.fc12': "intermediate.gate_proj",
'mlp.fc2': "output.dense",
'norm2': "output.LayerNorm",
})
def _build_model(self,
vllm_config: VllmConfig,
prefix: str = "") -> BertModel:
config = vllm_config.model_config.hf_config
assert config.__class__.__name__ == "NomicBertConfig"
assert config.activation_function == "swiglu"
# Assume NomicBertModel all linear layers do not have bias
assert not config.mlp_fc1_bias
assert not config.mlp_fc2_bias
assert not config.qkv_proj_bias
config.layer_norm_eps = config.layer_norm_epsilon
config.position_embedding_type = "rotary"
config.intermediate_size = config.n_inner
config.hidden_act = "silu"
config.hidden_size = config.n_embd
config.num_hidden_layers = config.n_layer
head_dim = config.hidden_size // config.num_attention_heads
rotary_kwargs = {
"head_size": head_dim,
"rotary_dim": getattr(config, "rotary_emb_dim", head_dim),
"max_position": config.max_trained_positions,
"base": config.rotary_emb_base,
"rope_scaling": {
"rope_type": "dynamic",
"factor": config.rotary_scaling_factor
}
}
return BertModel(vllm_config=vllm_config,
prefix=prefix,
bias=False,
rotary_kwargs=rotary_kwargs,
embedding_class=BertEmbedding)
class GteEmbeddingModel(BertEmbeddingModel):
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_substr={
"attention.qkv_proj": "attention.self.qkv_proj",
"attention.o_proj": "attention.output.dense",
'attn_ln': "attention.output.LayerNorm",
'mlp.down_proj': "output.dense",
'mlp_ln': "output.LayerNorm",
})
def _build_model(self,
vllm_config: VllmConfig,
prefix: str = "") -> BertModel:
config = vllm_config.model_config.hf_config
assert config.__class__.__name__ == "GteConfig"
assert config.position_embedding_type == "rope"
assert config.hidden_act == "gelu"
config.position_embedding_type = "rotary"
config.hidden_act = "gelu_and_mul"
head_dim = config.hidden_size // config.num_attention_heads
rotary_kwargs = {
"head_size": head_dim,
"rotary_dim": getattr(config, "rotary_emb_dim", head_dim),
"max_position": config.max_position_embeddings,
"base": config.rope_theta,
}
model = BertModel(vllm_config=vllm_config,
prefix=prefix,
rotary_kwargs=rotary_kwargs,
embedding_class=BertEmbedding)
# GteModel only gate_up_proj does not have bias.
# Hack method learned from vllm/model_executor/models/glm.py
for layer in model.encoder.layer:
layer.intermediate.gate_up_proj.bias = None
layer.intermediate.skip_bias_add = True
return model
def split_up_gate_proj(self, weights: Iterable[Tuple[str, torch.Tensor]]):
n = "mlp.up_gate_proj"
for name, weight in weights:
if n in name:
up, gate = weight.chunk(2, dim=0)
yield name.replace(n, "intermediate.up_proj"), up
yield name.replace(n, "intermediate.gate_proj"), gate
else:
yield name, weight
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
weights = self.hf_to_vllm_mapper.apply(weights)
weights = self.split_up_gate_proj(weights)
self.model.load_weights(weights)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment