Commit dcb5624a authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.8.5' into v0.8.5-dev

parents 55880ca2 ba41cc90
# SPDX-License-Identifier: Apache-2.0
from typing import Dict, List, Optional, Tuple
import torch
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.bitblas_utils import (
BITBLAS_OPTIMIZE_FEATURES, BITBLAS_SUPPORTED_GROUP_SIZES,
MINIMUM_BITBLAS_VERSION, bitblas_make_empty_g_idx, bitblas_sort_g_idx,
check_bitblas_supports_shape, query_bitblas_supported_quant_types,
unpack_gptq_qweight, unpack_gptq_qzeros)
from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig
logger = init_logger(__name__)
class BitBLASLinearKernel(MPLinearKernel):
OPT_FEATURES: List[int] = BITBLAS_OPTIMIZE_FEATURES
ENABLE_TUNING: bool = True
MATMUL_LAYOUT: str = "nt"
BITBLAS_DTYPES: Dict[torch.dtype, str] = {
torch.float32: "float32",
torch.float16: "float16",
torch.bfloat16: "bfloat16",
torch.half: "float16",
torch.int8: "int8",
}
bitblas_matmul: object = None
def __init__(
self,
c: MPLinearLayerConfig,
w_q_param_name: str,
w_s_param_name: str,
w_zp_param_name: Optional[str] = None,
w_gidx_param_name: Optional[str] = None,
bitblas_quant_config: Optional[QuantizationConfig] = None,
):
self.quant_config = bitblas_quant_config
super().__init__(c, w_q_param_name, w_s_param_name, w_zp_param_name,
w_gidx_param_name)
def repack_bitblas_from_gptq(
self,
b_q_weight: torch.Tensor,
scales: torch.Tensor,
qzeros: Optional[torch.Tensor] = None,
):
from bitblas.quantization.utils import general_compress
assert self.bitblas_matmul is not None, "bitblas_matmul is None"
quant_config = self.quant_config
# qweight in gptq old quant linear stored with
# (outfeatures, infeatures), should be transposed.
qweight = b_q_weight.T.contiguous().view(
quant_config.torch_storage_dtype) # type: ignore[union-attr]
intweight = unpack_gptq_qweight(
qweight,
quant_config.weight_bits).contiguous() # type: ignore[union-attr]
if self.bitblas_matmul.weight_transform is not None: # type: ignore[attr-defined]
qweight = self.bitblas_matmul.weight_transform( # type: ignore[attr-defined]
intweight.cpu()).cuda()
# scales in gptq old quant linear stored with
# (infeatures // group_size, outfeatures), should be transposed.
scales = scales.T.contiguous()
if qzeros is None:
return qweight, scales, None
# qzeros should be de-quantized to int zeros.
weight_bits = quant_config.weight_bits # type: ignore[union-attr]
intzeros = unpack_gptq_qzeros(qzeros, weight_bits).T.contiguous()
zeros: Optional[torch.Tensor] = None
zeros_mode = self.bitblas_matmul.config.zeros_mode # type: ignore[attr-defined]
if zeros_mode == "original":
zeros = intzeros.to(torch.float16).contiguous()
elif zeros_mode == "rescale":
assert zeros is not None, "zeros should not be None"
zeros[:, :] = intzeros.to(torch.float16)[:, :] * scales[:, :]
elif zeros_mode == "quantized":
zeros = (
torch.Tensor(
general_compress(
intzeros.T.contiguous().cpu().numpy(),
weight_bits,
)).to(qweight.device).
to(quant_config.torch_storage_dtype # type: ignore[union-attr]
).contiguous())
else:
raise ValueError("Unsupported zeros type: {}".format(zeros_mode))
return qweight, scales, zeros
@classmethod
def get_min_capability(cls) -> int:
return 70
@classmethod
def can_implement(cls,
c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]:
is_bitblas_installed = True
try:
import bitblas
if bitblas.__version__ < MINIMUM_BITBLAS_VERSION:
raise ImportError(
"bitblas version is wrong. Please "
f"install bitblas>={MINIMUM_BITBLAS_VERSION}")
except ImportError:
is_bitblas_installed = False
if not is_bitblas_installed:
return False, "bitblas is not installed. Please install bitblas "\
"by running `pip install bitblas>="\
f"{MINIMUM_BITBLAS_VERSION}`"
quant_types = query_bitblas_supported_quant_types(c.zero_points)
if c.weight_type not in quant_types:
return False, (f"Quant type ({c.weight_type}) not supported by"
f" BitBLAS, supported types are: {quant_types}")
if c.group_size not in BITBLAS_SUPPORTED_GROUP_SIZES:
return False, (f"Group size ({c.group_size}) not supported by "
"BitBLAS, supported group sizes are: "
f"{BITBLAS_SUPPORTED_GROUP_SIZES}")
return check_bitblas_supports_shape(
c.partition_weight_shape[1], # out_features
c.partition_weight_shape[0], # in_features
c.full_weight_shape[0], # in_features
c.group_size)
# note assumes that
# `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
# `weight_scale` is: {input_dim = 0, output_dim = 1}
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
device = getattr(layer, self.w_q_name).device
c = self.config
quant_config = self.quant_config
# Default names since bitblas requires empty parameters for these,
# TODO: remove this requirement from bitblas (allow optional tensors)
if self.w_gidx_name is None:
self.w_gidx_name = "g_idx"
if self.w_zp_name is None:
self.w_zp_name = "qzeros"
if c.has_g_idx:
g_idx, g_idx_sort_indices = bitblas_sort_g_idx(
getattr(layer, self.w_gidx_name))
self._transform_param(layer, self.w_gidx_name, lambda _: g_idx)
layer.g_idx_sort_indices = g_idx_sort_indices
else:
setattr(layer, self.w_gidx_name, bitblas_make_empty_g_idx(device))
layer.g_idx_sort_indices = bitblas_make_empty_g_idx(device)
if c.zero_points:
raise NotImplementedError("Zero points not supported by BitBLAS")
else:
setattr(layer, self.w_zp_name, bitblas_make_empty_g_idx(device))
# Repack weights
bitblas_qweight, bitblas_scales, bitblas_qzeros = (
self.repack_bitblas_from_gptq(
layer.qweight,
layer.scales,
None if quant_config.is_sym else # type: ignore[union-attr]
layer.qzeros, # type: ignore[union-attr]
))
replace_parameter(layer, self.w_q_name, bitblas_qweight)
replace_parameter(layer, self.w_s_name, bitblas_scales)
if bitblas_qzeros is not None:
replace_parameter(layer, self.w_zp_name, bitblas_qzeros)
def configure_bitblas_matmul(
self,
infeatures: int,
outfeatures: int,
params_dtype: torch.dtype,
bias: bool,
) -> None:
enable_tuning = self.ENABLE_TUNING
layout = self.MATMUL_LAYOUT
bits = self.quant_config.weight_bits # type: ignore[union-attr]
self._configure_bitblas_matmul(
infeatures,
outfeatures,
params_dtype,
enable_tuning,
bias,
layout,
bits,
)
def _configure_bitblas_matmul(
self,
infeatures,
outfeatures,
params_dtype,
enable_tuning,
bias,
layout,
bits,
):
from bitblas import MatmulConfig
bitblas_dtype = self.BITBLAS_DTYPES[params_dtype]
quant_config = self.quant_config
with_scaling = False
with_zeros = False
group_size = quant_config.group_size # type: ignore[union-attr]
zeros_mode = quant_config.zeros_mode # type: ignore[union-attr]
if quant_config.quant_method == "gptq": # type: ignore[union-attr]
with_scaling = True
with_zeros = True
W_dtype = f"uint{bits}"
if quant_config.is_sym: # type: ignore[union-attr]
with_zeros = False
W_dtype = f"int{bits}"
else:
raise ValueError(
f"Unsupported quant_method {quant_config.quant_method}" # type: ignore[union-attr]
) # type: ignore[union-attr]
matmul_config = MatmulConfig(
M=self.OPT_FEATURES,
N=outfeatures,
K=infeatures,
A_dtype=bitblas_dtype,
W_dtype=W_dtype,
out_dtype=bitblas_dtype,
accum_dtype="int32" if bitblas_dtype == "int8" else bitblas_dtype,
storage_dtype=quant_config. # type: ignore[union-attr]
storage_dtype, # type: ignore[union-attr]
with_scaling=with_scaling,
with_zeros=with_zeros,
group_size=group_size,
with_bias=bias,
layout=layout,
zeros_mode=zeros_mode,
)
self.bitblas_matmul = self._get_or_create_bitblas_operator(
matmul_config, enable_tuning)
def _get_or_create_bitblas_operator(self, config, enable_tuning):
from bitblas import Matmul, auto_detect_nvidia_target
from bitblas.cache import get_database_path, global_operator_cache
BITBLAS_DATABASE_PATH = get_database_path()
BITBLAS_TARGET = auto_detect_nvidia_target()
if global_operator_cache.size() == 0:
global_operator_cache.load_from_database(BITBLAS_DATABASE_PATH,
BITBLAS_TARGET)
bitblas_matmul = global_operator_cache.get(config)
if bitblas_matmul is None:
bitblas_matmul = Matmul(config,
target=BITBLAS_TARGET,
enable_tuning=False)
if enable_tuning:
bitblas_matmul.hardware_aware_finetune(topk=20)
global_operator_cache.add(config, bitblas_matmul)
global_operator_cache.save_into_database(
BITBLAS_DATABASE_PATH, BITBLAS_TARGET)
TUNING_MESSAGE = (
f"BitBLAS Operator {config} tuned and saved to database.")
logger.info(TUNING_MESSAGE)
else:
_message = f"BitBLAS Operator {config} created without tuning. "
logger.info(_message)
else:
_message = f"BitBLAS Operator {config} retrieved from cache."
logger.info(_message)
return bitblas_matmul
def apply_gptq_bitblas_linear(
self,
layer: torch.nn.Module,
x: torch.Tensor,
) -> torch.Tensor:
output_size_per_partition = self.config.partition_weight_shape[1]
out_shape = x.shape[:-1] + (output_size_per_partition, )
args = [x, layer.qweight, layer.scales]
if self.bitblas_matmul.config.with_zeros: # type: ignore[attr-defined]
args.append(layer.qzeros)
output = self.bitblas_matmul(*args) # type: ignore[operator]
return output.view(out_shape)
def apply_weights(self, layer, x, bias=None):
NOT_IMPLEMENT_MESSAGE = (
f"{self.__class__.__name__}.apply_weights is not implemented. "
"Please use BitBLASLinearKernel.apply_gptq_bitblas_linear instead")
raise NotImplementedError(NOT_IMPLEMENT_MESSAGE)
......@@ -26,17 +26,14 @@ class MacheteLinearKernel(MPLinearKernel):
@classmethod
def can_implement(cls,
c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]:
if c.has_g_idx and\
c.partition_weight_shape[0] != c.full_weight_shape[0]:
return False, "Act reordering currently not supported by Machete, "\
"when the input features are partitioned across "\
"devices"
if c.zero_points:
return False, "Zero points currently not supported by "\
" Compressed Tensors + Machete. (Kernel supports it"\
" but CompressedTensorsWNA16 does not so support has"\
" not been added to MacheteWNA16Kernel yet"
return False, "Zero points currently not supported by Machete"
if c.weight_type not in query_machete_supported_quant_types(
c.zero_points):
......
......@@ -9,7 +9,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
MARLIN_SUPPORTED_GROUP_SIZES, apply_gptq_marlin_linear,
check_marlin_supports_shape, marlin_is_k_full, marlin_make_empty_g_idx,
marlin_make_workspace, marlin_permute_scales, marlin_sort_g_idx,
query_marlin_supported_quant_types)
marlin_zero_points, query_marlin_supported_quant_types, unpack_cols)
from vllm.model_executor.parameter import (BasevLLMParameter,
permute_param_layout_)
......@@ -25,10 +25,6 @@ class MarlinLinearKernel(MPLinearKernel):
@classmethod
def can_implement(cls,
c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]:
if c.zero_points:
return False, "Zero points currently not supported by "\
" MarlinLinearKernel. Will be added when AWQMarlin "\
"is migrated over to using MPLinearKernel backend"
quant_types = query_marlin_supported_quant_types(c.zero_points)
if c.weight_type not in quant_types:
......@@ -67,28 +63,6 @@ class MarlinLinearKernel(MPLinearKernel):
if self.w_zp_name is None:
self.w_zp_name = "w_zp"
if c.has_g_idx:
g_idx, g_idx_sort_indices = marlin_sort_g_idx(
getattr(layer, self.w_gidx_name))
self._transform_param(layer, self.w_gidx_name, lambda _: g_idx)
layer.g_idx_sort_indices = g_idx_sort_indices
else:
setattr(layer, self.w_gidx_name, marlin_make_empty_g_idx(device))
layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)
if c.zero_points:
pass
# TODO (lucas): add the following when AWQMarlin is migrated over to
# using MPLinearKernel backend
# self._transform_param(layer, self.w_zp_name, lambda x: \
# marlin_zero_points(
# x,
# size_k=c.partition_weight_shape[0],
# size_n=c.partition_weight_shape[1],
# num_bits=c.weight_type.size_bits))
else:
setattr(layer, self.w_zp_name, marlin_make_empty_g_idx(device))
def transform_w_q(x):
assert isinstance(x, BasevLLMParameter)
permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
......@@ -108,6 +82,28 @@ class MarlinLinearKernel(MPLinearKernel):
group_size=c.group_size)
return x
if c.has_g_idx:
g_idx, g_idx_sort_indices = marlin_sort_g_idx(
getattr(layer, self.w_gidx_name))
self._transform_param(layer, self.w_gidx_name, lambda _: g_idx)
layer.g_idx_sort_indices = g_idx_sort_indices
else:
setattr(layer, self.w_gidx_name, marlin_make_empty_g_idx(device))
layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)
if c.zero_points:
grouped_k = (c.partition_weight_shape[0] //
c.group_size if c.group_size != -1 else 1)
self._transform_param(layer, self.w_zp_name, lambda x: \
marlin_zero_points(
unpack_cols(x.t(), c.weight_type.size_bits,
grouped_k,
c.partition_weight_shape[1]),
size_k=grouped_k,
size_n=c.partition_weight_shape[1],
num_bits=c.weight_type.size_bits))
else:
setattr(layer, self.w_zp_name, marlin_make_empty_g_idx(device))
self._transform_param(layer, self.w_q_name, transform_w_q)
self._transform_param(layer, self.w_s_name, transform_w_s)
......@@ -131,5 +127,6 @@ class MarlinLinearKernel(MPLinearKernel):
wtype=c.weight_type,
input_size_per_partition=c.partition_weight_shape[0],
output_size_per_partition=c.partition_weight_shape[1],
has_zp=self.config.zero_points,
is_k_full=self.is_k_full,
bias=bias)
......@@ -38,6 +38,9 @@ class BaseKVCacheMethod(QuantizeMethodBase):
requires_grad=False)
layer.v_scale = torch.nn.Parameter(torch.tensor(-1.0),
requires_grad=False)
# Initialize P = softmax(QK^T) scales
layer.prob_scale = torch.nn.Parameter(torch.tensor(-1.0),
requires_grad=False)
def apply(self, layer: torch.nn.Module) -> torch.Tensor:
raise RuntimeError(
......@@ -97,5 +100,38 @@ class BaseKVCacheMethod(QuantizeMethodBase):
"may cause accuracy issues. Please make sure k/v_scale "
"scaling factors are available in the fp8 checkpoint.")
if layer.q_scale > 0.0:
q_scale = layer.q_scale
if current_platform.is_fp8_fnuz():
q_scale *= 2
layer.calculate_kv_scales = False
else:
q_scale = 1.0
if layer.prob_scale > 0.0:
prob_scale = layer.prob_scale
if current_platform.is_fp8_fnuz():
prob_scale *= 2
else:
prob_scale = 1.0
is_singleton_float = lambda x: isinstance(x, float) or isinstance(
x, torch.Tensor) and x.numel() == 1 and x.is_floating_point()
if not is_singleton_float(q_scale) or not is_singleton_float(
prob_scale):
raise ValueError("Only support per-tensor scaling factor"
"for fp8-quantized Q/prob")
# These are used in the final Attention.forward()
layer._q_scale.copy_(q_scale)
layer._prob_scale.copy_(prob_scale)
if q_scale == 1.0 or prob_scale == 1.0:
logger.warning_once(
f"Using Q scale {q_scale} and prob scale {prob_scale} "
"with fp8 attention. This may cause accuracy issues. "
"Please make sure Q/prob scaling factors are "
"available in the fp8 checkpoint.")
del layer.k_scale
del layer.v_scale
del layer.q_scale
del layer.prob_scale
# SPDX-License-Identifier: Apache-2.0
import fnmatch
import re
from typing import Any, Dict, List, Optional, cast
import torch
......@@ -125,6 +124,13 @@ class QuarkConfig(QuantizationConfig):
for q_config in q_configs:
q_config["output_tensors"] = None
# In case q_proj output is also quantized, remove the configuration
# to keep qkv consistency.
q_proj_q_config = cast(Dict[str, Any],
layer_quant_config.get("*q_proj"))
if q_proj_q_config is not None:
q_proj_q_config["output_tensors"] = None
return cls(quant_config=config,
kv_cache_group=kv_cache_group,
kv_cache_config=kv_cache_config,
......@@ -289,25 +295,14 @@ class QuarkConfig(QuantizationConfig):
:param name: param name
:return: matching param name for KV cache scale in vLLM
"""
if self.kv_cache_group is None or len(self.kv_cache_group) == 0:
return None
kv_proj_names = [
re.split(r"[*.]", kv_cache)[-1] for kv_cache in self.kv_cache_group
]
if name.endswith(".output_scale"):
if len(kv_proj_names) == 1 and kv_proj_names[0] in name:
kv_output_scale_name = "." + kv_proj_names[0] + ".output_scale"
return name.replace(kv_output_scale_name, ".attn.k_scale")
elif len(kv_proj_names) == 2:
for kv_proj_name in kv_proj_names:
if kv_proj_name in name and kv_proj_name == "k_proj":
return name.replace(".k_proj.output_scale",
".attn.k_scale")
elif kv_proj_name in name and kv_proj_name == "v_proj":
return name.replace(".v_proj.output_scale",
".attn.v_scale")
if name.endswith(".output_scale") and ".k_proj" in name:
return name.replace(".k_proj.output_scale", ".attn.k_scale")
if name.endswith(".output_scale") and ".v_proj" in name:
return name.replace(".v_proj.output_scale", ".attn.v_scale")
if name.endswith(".output_scale") and ".q_proj" in name:
return name.replace(".q_proj.output_scale", ".attn.q_scale")
if name.endswith("self_attn.prob_output_scale"):
return name.replace(".prob_output_scale", ".attn.prob_scale")
# If no matches, return None
return None
......
# SPDX-License-Identifier: Apache-2.0
from typing import Optional, Tuple
import torch
from vllm.platforms import current_platform
from vllm.scalar_type import ScalarType, scalar_types
MINIMUM_BITBLAS_VERSION = "0.1.0"
BITBLAS_MIN_WEIGHT_SIZE_N = 16
BITBLAS_MIN_WEIGHT_SIZE_K = 16
GPTQ_BITBLAS_MAX_PARALLEL = 16
BITBLAS_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
# For dynamic shape code generation
BITBLAS_OPTIMIZE_FEATURES = [1, 16, 32, 64, 128, 256, 512, 1024]
# If want to enable high performance for contiguous batching
# Please use the following values
BITBLAS_OPTIMIZE_FEATURES_CONTIGUOUS = [16, 32, 64, 128, 256, 512, 1024]
BITBLAS_SUPPORTED_NUM_BITS = [1, 2, 4, 8]
BITBLAS_SUPPORTED_SYM = [False, True]
# Determines the supported quantization types for BitBLAS based on the
# device's capability and whether zero-point (zp) is used.
def query_bitblas_supported_quant_types(has_zp: bool,
device_capability: Optional[int] = None
):
if device_capability is None:
capability_tuple = current_platform.get_device_capability()
device_capability = (-1 if capability_tuple is None else
capability_tuple.to_int())
if device_capability < 70:
return []
if has_zp:
# AWQ style, unsigned + runtime zero-point
return [scalar_types.uint4, scalar_types.uint8]
else:
# GPTQ style, unsigned + symmetric bias
# TODO: once fp8_bitblas is merged into "gptq_bitblas" we should be able
# to add `scalar_types.float8_e4m3fn` here
return [scalar_types.uint4b8, scalar_types.uint8b128]
def _check_bitblas_supported(
quant_type: ScalarType,
group_size: Optional[int],
has_zp: bool,
device_capability: Optional[int] = None) -> Tuple[bool, Optional[str]]:
if device_capability is None:
capability_tuple = current_platform.get_device_capability()
device_capability = (-1 if capability_tuple is None else
capability_tuple.to_int())
supported_types = query_bitblas_supported_quant_types(
has_zp, device_capability)
if quant_type not in supported_types:
return (False, f"BitBLAS does not support weight_bits = {quant_type}. "
f"Only types = {supported_types} "
f"are supported (for group_size = {group_size}, "
f"device_capability = {device_capability}, zp = {has_zp}).")
if (group_size is None or group_size not in BITBLAS_SUPPORTED_GROUP_SIZES):
return (False, f"BitBLAS does not support group_size = {group_size}. "
f"Only group_sizes = {BITBLAS_SUPPORTED_GROUP_SIZES} "
"are supported.")
return True, None
def check_bitblas_supported(quant_type: ScalarType,
group_size: int,
has_zp: bool = False,
device_capability: Optional[int] = None) -> bool:
cond, _ = _check_bitblas_supported(quant_type, group_size, has_zp,
device_capability)
return cond
def verify_bitblas_supported(quant_type: ScalarType,
group_size: int,
has_zp: bool = False) -> None:
cond, err_msg = _check_bitblas_supported(quant_type, group_size, has_zp)
if not cond:
assert err_msg is not None
raise ValueError(err_msg)
def verify_bitblas_supports_shape(output_size_per_partition: int,
input_size_per_partition: int,
input_size: int, group_size: int) -> None:
# Validate output_size_per_partition
if output_size_per_partition % BITBLAS_MIN_WEIGHT_SIZE_N != 0:
raise ValueError(f"Weight output_size_per_partition = "
f"{output_size_per_partition} is not divisible by "
f" min_thread_n = {BITBLAS_MIN_WEIGHT_SIZE_N}. "
"Consider reducing tensor_parallel_size or running "
"with --quantization gptq.")
# Validate input_size_per_partition
if input_size_per_partition % BITBLAS_MIN_WEIGHT_SIZE_K != 0:
raise ValueError(f"Weight input_size_per_partition = "
f"{input_size_per_partition} is not divisible "
f"by min_thread_k = {BITBLAS_MIN_WEIGHT_SIZE_K}. "
"Consider reducing tensor_parallel_size or running "
"with --quantization gptq.")
if (group_size < input_size
and input_size_per_partition % group_size != 0):
raise ValueError(
f"Weight input_size_per_partition = {input_size_per_partition}"
f" is not divisible by group_size = {group_size}."
"Consider reducing tensor_parallel_size or running "
"with --quantization gptq.")
def check_bitblas_supports_shape(output_size_per_partition: int,
input_size_per_partition: int,
input_size: int, group_size: int) \
-> Tuple[bool, Optional[str]]:
try:
verify_bitblas_supports_shape(output_size_per_partition,
input_size_per_partition, input_size,
group_size)
except ValueError as e:
return False, e.__str__()
return True, None
def bitblas_is_k_full(act_order: bool, is_row_parallel: bool) -> bool:
return (not act_order) or (act_order and not is_row_parallel)
def bitblas_repeat_scales_on_all_ranks(act_order: bool, group_size: int,
is_row_parallel: bool) -> bool:
# Need to repeat scales on every rank if act_ordering or
# channelwise and RowParallelLinear
is_channelwise = group_size == -1
return act_order or (is_channelwise and is_row_parallel)
def bitblas_make_empty_g_idx(device: torch.device) -> torch.Tensor:
return torch.nn.Parameter(torch.empty(0, dtype=torch.int, device=device),
requires_grad=False)
def bitblas_make_empty_zp(device: torch.device) -> torch.Tensor:
return torch.nn.Parameter(torch.empty(0, dtype=torch.int, device=device),
requires_grad=False)
def bitblas_sort_g_idx(
g_idx: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
g_idx_sort_indices = torch.argsort(g_idx).to(torch.int)
return g_idx[g_idx_sort_indices], g_idx_sort_indices
def unpack_gptq_qzeros(qzeros, bits, is_gptq_v2=False) -> torch.Tensor:
qzeros = qzeros.view(torch.int32)
elems_per_int32 = 32 // bits
unpacked_zeros = torch.zeros(
(qzeros.shape[0], qzeros.shape[1] * elems_per_int32),
dtype=torch.int8,
device=qzeros.device,
requires_grad=False,
)
for col in range(unpacked_zeros.shape[1]):
i = col % elems_per_int32
unpacked_zeros[:, col] = (qzeros[:, col // elems_per_int32] >>
(bits * i)) & 0xF
if not is_gptq_v2:
return unpacked_zeros + 1
return unpacked_zeros
def unpack_gptq_qweight(qweight, bits):
qweight = qweight.view(torch.int8)
elems_per_int8 = 8 // bits
unpacked_weight = torch.zeros(
(qweight.shape[0], qweight.shape[1] * elems_per_int8),
dtype=torch.int8,
device=qweight.device,
requires_grad=False,
)
for col in range(unpacked_weight.shape[1]):
i = col % elems_per_int8
unpacked_weight[:, col] = (qweight[:, col // elems_per_int8] >>
(bits * i))
return torch.bitwise_and(unpacked_weight, 2**bits - 1)
......@@ -151,6 +151,19 @@ def check_marlin_supports_layer(layer: LinearBase, group_size: int) \
group_size=group_size)[0]
def check_moe_marlin_supports_layer(layer: LinearBase, group_size: int) \
-> bool:
hidden_size = layer.hidden_size
intermediate_size_per_partition = layer.intermediate_size_per_partition
# gate-up: (n, k) = (intermediate_size_per_partition * 2, hidden_size)
# down: (n, k) = (hidden_size, intermediate_size_per_partition)
# moe marlin requires n % 128 == 0 and k % 64 == 0
return hidden_size % 128 == 0 and \
intermediate_size_per_partition % max(64, group_size) == 0 and \
group_size in [-1, 32, 64, 128]
def marlin_make_workspace(output_size_per_partition: int,
device: torch.device) -> torch.Tensor:
max_workspace_size = (output_size_per_partition //
......@@ -319,6 +332,7 @@ def apply_gptq_marlin_linear(
wtype: ScalarType,
output_size_per_partition: int,
input_size_per_partition: int,
has_zp: bool,
is_k_full: bool,
bias: Optional[torch.Tensor] = None,
use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor:
......@@ -343,8 +357,8 @@ def apply_gptq_marlin_linear(
size_n=output_size_per_partition,
size_k=input_size_per_partition,
is_k_full=is_k_full,
has_zp=False,
use_atomic_add=use_atomic_add,
has_zp=has_zp,
use_fp32_reduce=use_fp32_reduce,
is_zp_float=False)
......
# SPDX-License-Identifier: Apache-2.0
from typing import List, Optional, Tuple, Union
from typing import Callable, List, Optional, Tuple, Union
import torch
from vllm import _custom_ops as ops
from vllm import envs
from vllm.config import CompilationLevel, get_current_vllm_config
from vllm.platforms import current_platform
from vllm.utils import W8a8GetCacheJSON
......@@ -19,6 +20,7 @@ W8A8_TRITONJSON=W8a8GetCacheJSON()
# The condition is determined once as the operations
# are time consuming.
USE_ROWWISE_TORCH_SCALED_MM = (current_platform.is_rocm()
and torch.__version__[0:3] >= "2.7"
and current_platform.has_device_capability(94))
def sparse_cutlass_supported() -> bool:
......@@ -132,6 +134,160 @@ def maybe_create_device_identity():
TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32)
def cutlass_w8a8_scaled_mm(*, qinput: torch.Tensor, weight: torch.Tensor,
out_dtype: torch.dtype, scale_a: torch.Tensor,
scale_b: torch.Tensor, bias: torch.Tensor,
output_shape: List, **kwargs) -> torch.Tensor:
# Fused GEMM_DQ
output = ops.cutlass_scaled_mm(qinput,
weight,
out_dtype=out_dtype,
scale_a=scale_a,
scale_b=scale_b,
bias=bias)
return output.view(*output_shape)
def rocm_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor,
weight: torch.Tensor,
out_dtype: torch.dtype,
scale_a: torch.Tensor,
scale_b: torch.Tensor, bias: torch.Tensor,
input_2d: torch.Tensor,
output_shape: List) -> torch.Tensor:
from vllm.platforms.rocm import on_mi250_mi300
if envs.VLLM_ROCM_USE_SKINNY_GEMM and on_mi250_mi300(
) and qinput.shape[0] == 1 and qinput.shape[1] % 16 == 0:
output = ops.wvSplitKQ(weight.t(), qinput, out_dtype, scale_a, scale_b,
current_platform.get_cu_count())
else:
output = torch._scaled_mm(qinput,
weight,
out_dtype=out_dtype,
scale_a=scale_a,
scale_b=scale_b,
bias=bias)
return torch.narrow(output, 0, 0, input_2d.shape[0]).view(*output_shape)
def torch_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor,
weight: torch.Tensor,
out_dtype: torch.dtype,
scale_a: torch.Tensor,
scale_b: torch.Tensor, bias: torch.Tensor,
input_2d: torch.Tensor,
output_shape: List) -> torch.Tensor:
output = torch._scaled_mm(qinput,
weight,
out_dtype=out_dtype,
scale_a=scale_a,
scale_b=scale_b,
bias=bias)
# A fix for discrepancy in scaled_mm which returns tuple
# for torch < 2.5 and a single value in torch >= 2.5
if type(output) is tuple and len(output) == 2:
output = output[0]
return torch.narrow(output, 0, 0, input_2d.shape[0]).view(*output_shape)
def torch_per_token_w8a8_scaled_mm(*, qinput: torch.Tensor,
weight: torch.Tensor,
out_dtype: torch.dtype,
scale_a: torch.Tensor,
scale_b: torch.Tensor, bias: torch.Tensor,
input_2d: torch.Tensor,
output_shape: List) -> torch.Tensor:
# Note: Callers of this function should check USE_ROWWISE_TORCH_SCALED_MM
# when using it.
# For now it has only been validated on ROCm platform.
# fp8 rowwise scaling in torch._scaled_mm is introduced in
# https://github.com/pytorch/pytorch/pull/144432 using
# hipBLASLt and ROCm 6.3, which only exists in torch 2.7 and above.
#
# For CUDA platform please validate if the torch._scaled_mm supports
# rowwise scaled GEMM before using it
# Fused GEMM_DQ Rowwise GEMM
output = torch._scaled_mm(qinput,
weight,
out_dtype=out_dtype,
scale_a=scale_a,
scale_b=scale_b.t(),
bias=bias)
output = torch.narrow(output, 0, 0, input_2d.shape[0])
output = output.view(*output_shape)
return output
def torch_channelwise_w8a8_scaled_mm(*, qinput: torch.Tensor,
weight: torch.Tensor,
out_dtype: torch.dtype,
scale_a: torch.Tensor,
scale_b: torch.Tensor, bias: torch.Tensor,
input_2d: torch.Tensor,
output_shape: List,
**kwargs) -> torch.Tensor:
# Use unfused DQ due to limitations with scaled_mm
# Symmetric quantized GEMM by definition computes the following:
# C = (s_x * X) (s_w * W) + bias
# This is equivalent to dequantizing the weights and activations
# before applying a GEMM.
#
# In order to compute quantized operands, a quantized kernel
# will rewrite the above like so:
# C = s_w * s_x * (X * W) + bias
#
# For the scaled_mm fallback case, we break this down, since it
# does not support s_w being a vector.
# GEMM
# This computes C = (X * W).
# Output in fp32 to allow subsequent ops to happen in-place
output = torch._scaled_mm(qinput,
weight,
scale_a=TORCH_DEVICE_IDENTITY,
scale_b=TORCH_DEVICE_IDENTITY,
out_dtype=torch.float32)
# A fix for discrepancy in scaled_mm which returns tuple
# for torch < 2.5 and a single value in torch >= 2.5
if type(output) is tuple and len(output) == 2:
output = output[0]
# Unpad (undo num_token_padding)
output = torch.narrow(output, 0, 0, input_2d.shape[0])
x_scale = torch.narrow(scale_a, 0, 0, input_2d.shape[0])
# DQ
# C = sw * sx * (X * W) + bias
output = output * x_scale * scale_b.t()
if bias is not None:
output = output + bias
return output.to(out_dtype).view(*output_shape)
def dispatch_w8a8_scaled_mm(
cutlass_fp8_supported: bool, per_tensor_weights: bool,
per_tensor_activations: bool, use_per_token_if_dynamic: Optional[bool]
) -> Callable[..., torch.Tensor]:
if cutlass_fp8_supported:
return cutlass_w8a8_scaled_mm
if per_tensor_weights and per_tensor_activations:
if current_platform.is_rocm():
return rocm_per_tensor_w8a8_scaled_mm
return torch_per_tensor_w8a8_scaled_mm
# torch.scaled_mm supports per tensor weights + activations only
# so fallback to naive if per channel or per token
if (use_per_token_if_dynamic and not per_tensor_weights
and not per_tensor_activations and USE_ROWWISE_TORCH_SCALED_MM):
return torch_per_token_w8a8_scaled_mm
return torch_channelwise_w8a8_scaled_mm
# TODO(luka): follow similar pattern for marlin and block-fp8-linear
# https://github.com/vllm-project/vllm/issues/14397
class Fp8LinearOp:
......@@ -157,7 +313,8 @@ class Fp8LinearOp:
if pad_output is None:
config = get_current_vllm_config().compilation_config
pad_output = config.level < CompilationLevel.PIECEWISE
self.output_padding = 17 if pad_output else None
self.output_padding = 17 if (
pad_output and not current_platform.is_rocm()) else None
def apply(
self,
......@@ -196,18 +353,6 @@ class Fp8LinearOp:
input_scale,
scale_ub=input_scale_ub,
use_per_token_if_dynamic=use_per_token_if_dynamic)
# Fused GEMM_DQ
output = ops.cutlass_scaled_mm(qinput,
weight,
out_dtype=out_dtype,
scale_a=x_scale,
scale_b=weight_scale,
bias=bias)
return output.view(*output_shape)
# torch.scaled_mm supports per tensor weights + activations only
# so fallback to naive if per channel or per token
else:
if input.dtype != current_platform.fp8_dtype():
# Maybe apply padding to output, see comment in __init__
......@@ -219,84 +364,21 @@ class Fp8LinearOp:
else:
qinput, x_scale = input_2d, input_scale
per_tensor_weights = (weight_scale.numel() == 1)
per_tensor_activations = (x_scale.numel() == 1)
if per_tensor_weights and per_tensor_activations:
# Fused GEMM_DQ
output = torch._scaled_mm(qinput,
weight,
out_dtype=out_dtype,
scale_a=x_scale,
scale_b=weight_scale,
bias=bias)
# A fix for discrepancy in scaled_mm which returns tuple
# for torch < 2.5 and a single value in torch >= 2.5
if type(output) is tuple and len(output) == 2:
output = output[0]
return torch.narrow(output, 0, 0,
input_2d.shape[0]).view(*output_shape)
elif (use_per_token_if_dynamic and not per_tensor_weights
and not per_tensor_activations
and USE_ROWWISE_TORCH_SCALED_MM):
# For now validated on ROCm platform
# fp8 rowwise scaling in torch._scaled_mm is introduced in
# https://github.com/pytorch/pytorch/pull/144432 using hipBLASLt
# and ROCm 6.3, which only exists in torch 2.7 and above.
# For CUDA platform please validate if the
# torch._scaled_mm support rowwise scaled GEMM
# Fused GEMM_DQ Rowwise GEMM
output = torch._scaled_mm(qinput,
weight,
out_dtype=out_dtype,
scale_a=x_scale,
scale_b=weight_scale.t(),
bias=bias)
output = torch.narrow(output, 0, 0, input_2d.shape[0])
output = output.view(*output_shape)
return output
else:
# Fallback for channelwise case, where we use unfused DQ
# due to limitations with scaled_mm
# Symmetric quantized GEMM by definition computes the following:
# C = (s_x * X) (s_w * W) + bias
# This is equivalent to dequantizing the weights and activations
# before applying a GEMM.
#
# In order to compute quantized operands, a quantized kernel
# will rewrite the above like so:
# C = s_w * s_x * (X * W) + bias
#
# For the scaled_mm fallback case, we break this down, since it
# does not support s_w being a vector.
# GEMM
# This computes C = (X * W).
# Output in fp32 to allow subsequent ops to happen in-place
output = torch._scaled_mm(qinput,
weight,
scale_a=TORCH_DEVICE_IDENTITY,
scale_b=TORCH_DEVICE_IDENTITY,
out_dtype=torch.float32)
# A fix for discrepancy in scaled_mm which returns tuple
# for torch < 2.5 and a single value in torch >= 2.5
if type(output) is tuple and len(output) == 2:
output = output[0]
# Unpad (undo num_token_padding)
output = torch.narrow(output, 0, 0, input_2d.shape[0])
x_scale = torch.narrow(x_scale, 0, 0, input_2d.shape[0])
# DQ
# C = sw * sx * (X * W) + bias
output = output * x_scale * weight_scale.t()
if bias is not None:
output = output + bias
return output.to(dtype=input.dtype).view(*output_shape)
per_tensor_weights = (weight_scale.numel() == 1)
per_tensor_activations = (x_scale.numel() == 1)
w8a8_scaled_mm_func = dispatch_w8a8_scaled_mm(
self.cutlass_fp8_supported, per_tensor_weights,
per_tensor_activations, use_per_token_if_dynamic)
return w8a8_scaled_mm_func(qinput=qinput,
weight=weight,
out_dtype=out_dtype,
scale_a=x_scale,
scale_b=weight_scale,
bias=bias,
input_2d=input_2d,
output_shape=output_shape)
def apply_int8_linear(
......
......@@ -46,20 +46,12 @@ def _rotate_gptj(x: torch.Tensor) -> torch.Tensor:
return x.flatten(-2)
def _apply_rotary_emb(
def _apply_rotary_emb_torch(
x: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
is_neox_style: bool,
) -> torch.Tensor:
"""
Args:
x: [num_tokens, num_heads, head_size]
cos: [num_tokens, head_size // 2]
sin: [num_tokens, head_size // 2]
is_neox_style: Whether to use the Neox-style or GPT-J-style rotary
positional embeddings.
"""
cos = cos.unsqueeze(-2).to(x.dtype)
sin = sin.unsqueeze(-2).to(x.dtype)
if is_neox_style:
......@@ -75,6 +67,24 @@ def _apply_rotary_emb(
return torch.stack((o1, o2), dim=-1).flatten(-2)
def _apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor,
is_neox_style: bool) -> torch.Tensor:
"""
Args:
x: [num_tokens, num_heads, head_size]
cos: [num_tokens, head_size // 2]
sin: [num_tokens, head_size // 2]
is_neox_style: Whether to use the Neox-style or GPT-J-style rotary
positional embeddings.
"""
if current_platform.is_cuda_alike():
from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
return apply_rotary_emb(x.unsqueeze(0), cos, sin,
not is_neox_style).squeeze(0)
else:
return _apply_rotary_emb_torch(x, cos, sin, is_neox_style)
@CustomOp.register("rotary_embedding")
class RotaryEmbedding(CustomOp):
"""Original rotary positional embedding."""
......@@ -141,14 +151,16 @@ class RotaryEmbedding(CustomOp):
query = query.view(num_tokens, -1, self.head_size)
query_rot = query[..., :self.rotary_dim]
query_pass = query[..., self.rotary_dim:]
query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style)
query_rot = _apply_rotary_emb_torch(query_rot, cos, sin,
self.is_neox_style)
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
key_shape = key.shape
key = key.view(num_tokens, -1, self.head_size)
key_rot = key[..., :self.rotary_dim]
key_pass = key[..., self.rotary_dim:]
key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style)
key_rot = _apply_rotary_emb_torch(key_rot, cos, sin,
self.is_neox_style)
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
return query, key
......@@ -988,8 +1000,9 @@ class MRotaryEmbedding(RotaryEmbedding):
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
return query, key
@staticmethod
@classmethod
def get_input_positions(
cls,
input_tokens: List[int],
hf_config: PretrainedConfig,
image_grid_thw: Optional[Union[List[List[int]], torch.Tensor]],
......@@ -997,6 +1010,8 @@ class MRotaryEmbedding(RotaryEmbedding):
second_per_grid_ts: Optional[List[float]],
context_len: int = 0,
seq_len: Optional[int] = None,
audio_feature_lengths: Optional[torch.Tensor] = None,
use_audio_in_video: bool = False,
) -> Tuple[List[List[int]], int]:
"""Get mrope input positions and delta value."""
......@@ -1006,7 +1021,7 @@ class MRotaryEmbedding(RotaryEmbedding):
second_per_grid_ts
llm_positions, mrope_position_delta = \
MRotaryEmbedding.get_input_positions_tensor(
cls.get_input_positions_tensor(
input_tokens=input_tokens,
hf_config=hf_config,
image_grid_thw=image_grid_thw,
......@@ -1014,12 +1029,52 @@ class MRotaryEmbedding(RotaryEmbedding):
second_per_grid_ts=second_per_grid_ts,
context_len=context_len,
seq_len=seq_len,
audio_feature_lengths=audio_feature_lengths,
use_audio_in_video=use_audio_in_video,
)
return llm_positions.tolist(), mrope_position_delta
@staticmethod
@classmethod
def get_input_positions_tensor(
cls,
input_tokens: List[int],
hf_config: PretrainedConfig,
image_grid_thw: Union[List[List[int]], torch.Tensor],
video_grid_thw: Union[List[List[int]], torch.Tensor],
second_per_grid_ts: List[float],
context_len: int = 0,
seq_len: Optional[int] = None,
audio_feature_lengths: Optional[torch.Tensor] = None,
use_audio_in_video: bool = False,
) -> Tuple[torch.Tensor, int]:
from vllm.transformers_utils.config import thinker_uses_mrope
if thinker_uses_mrope(hf_config):
return cls._omni_get_input_positions_tensor(
input_tokens=input_tokens,
hf_config=hf_config,
image_grid_thw=image_grid_thw,
video_grid_thw=video_grid_thw,
second_per_grid_ts=second_per_grid_ts,
context_len=context_len,
seq_len=seq_len,
audio_feature_lengths=audio_feature_lengths,
use_audio_in_video=use_audio_in_video,
)
else:
return cls._vl_get_input_positions_tensor(
input_tokens=input_tokens,
hf_config=hf_config,
image_grid_thw=image_grid_thw,
video_grid_thw=video_grid_thw,
second_per_grid_ts=second_per_grid_ts,
context_len=context_len,
seq_len=seq_len,
)
@classmethod
def _vl_get_input_positions_tensor(
cls,
input_tokens: List[int],
hf_config: PretrainedConfig,
image_grid_thw: Union[List[List[int]], torch.Tensor],
......@@ -1037,11 +1092,6 @@ class MRotaryEmbedding(RotaryEmbedding):
tokens_per_second = getattr(hf_config.vision_config,
"tokens_per_second", 1.0)
if isinstance(image_grid_thw, torch.Tensor):
image_grid_thw = image_grid_thw.tolist()
if isinstance(video_grid_thw, torch.Tensor):
video_grid_thw = video_grid_thw.tolist()
input_tokens_tensor = torch.tensor(input_tokens)
vision_start_indices = torch.argwhere(
input_tokens_tensor == vision_start_token_id).squeeze(1)
......@@ -1121,6 +1171,224 @@ class MRotaryEmbedding(RotaryEmbedding):
return llm_positions, mrope_position_delta
@classmethod
def _omni_get_input_positions_tensor(
cls,
input_tokens: List[int],
hf_config: PretrainedConfig,
image_grid_thw: Union[List[List[int]], torch.Tensor],
video_grid_thw: Union[List[List[int]], torch.Tensor],
second_per_grid_ts: Optional[List[float]] = None,
context_len: int = 0,
seq_len: Optional[int] = None,
audio_feature_lengths: Optional[torch.Tensor] = None,
use_audio_in_video: bool = False,
) -> Tuple[torch.Tensor, int]:
"""Get mrope input positions and delta value (Qwen2.5-Omni version).
Differences from MRotaryEmbedding:
1. Add audio support (and related `audio_feature_lengths`).
2. Add `use_audio_in_video` option to read audio from video inputs.
In this case, audio and vision position ids will be split into
chunks and interleaved.
Example:
(V_i are vision position ids, A_i are audio position ids)
|V_1 ... V_n|A_1 ... A_n|V_n+1 ... V_2n|A_n+1 ... A_2n|...
|vision chunk 1|audio chunk 1|vision chunk 2|audio chunk 2 |...
"""
# TODO(fyabc): refactor and share more code with
# _vl_get_input_positions_tensor.
thinker_config = hf_config.thinker_config
audio_token_id = thinker_config.audio_token_index
image_token_id = thinker_config.image_token_index
video_token_id = thinker_config.video_token_index
audio_start_token_id = thinker_config.audio_start_token_id
audio_end_token_id = thinker_config.audio_end_token_id
vision_start_token_id = thinker_config.vision_start_token_id
vision_end_token_id = thinker_config.vision_end_token_id
seconds_per_chunk = thinker_config.seconds_per_chunk
spatial_merge_size = thinker_config.vision_config.spatial_merge_size
tokens_per_second = getattr(thinker_config.vision_config,
"tokens_per_second", 25)
if isinstance(image_grid_thw, list):
image_grid_thw = torch.tensor(image_grid_thw)
if isinstance(video_grid_thw, list):
video_grid_thw = torch.tensor(video_grid_thw)
src_item = input_tokens
audio_seqlens = audio_feature_lengths
if not second_per_grid_ts:
second_per_grid_ts = [1] * video_grid_thw.shape[0]
audio_idx = 0
video_idx = 0
image_idx = 0
new_src_item: list[int] = []
llm_pos_ids_list: list[torch.Tensor] = []
idx = 0
while idx < len(src_item):
new_src_item_len = len(new_src_item)
start_idx = llm_pos_ids_list[-1].max() + 1 if len(
llm_pos_ids_list) > 0 else 0
if src_item[idx] not in [
audio_token_id, video_token_id, image_token_id
]:
if use_audio_in_video and idx > 0:
if src_item[idx] == vision_end_token_id and \
src_item[idx - 1] == audio_end_token_id:
# processing the <|audio_eos|> before <|vision_eos|>
start_idx -= 1
elif src_item[idx] == audio_start_token_id and \
src_item[idx - 1] == vision_start_token_id:
# processing the <|audio_bos|> after <|vision_eos|>
start_idx -= 1
new_src_item.append(src_item[idx])
llm_pos_ids = torch.tensor([start_idx],
dtype=torch.long).expand(3, -1)
llm_pos_ids_list.append(llm_pos_ids)
elif src_item[idx] == audio_token_id:
assert audio_seqlens is not None
audio_seqlen = audio_seqlens[audio_idx]
place_num = (((audio_seqlen - 1) // 2 + 1 - 2) // 2 + 1)
new_src_item.extend([audio_token_id] * place_num)
llm_pos_ids = torch.arange(place_num).expand(3, -1) + start_idx
llm_pos_ids_list.append(llm_pos_ids)
audio_idx += 1
elif src_item[idx] == image_token_id:
grid_t = image_grid_thw[image_idx][0]
grid_hs = image_grid_thw[:, 1]
grid_ws = image_grid_thw[:, 2]
t_index = (torch.arange(grid_t) * 1 * tokens_per_second).long()
llm_pos_ids = cls._get_llm_pos_ids_for_vision(
start_idx, image_idx, spatial_merge_size, t_index, grid_hs,
grid_ws)
llm_pos_ids_list.append(llm_pos_ids)
vision_seqlen = image_grid_thw[image_idx].prod() // (
spatial_merge_size**2)
new_src_item.extend([image_token_id] * vision_seqlen)
image_idx += 1
elif src_item[idx] == video_token_id and not use_audio_in_video:
grid_t = video_grid_thw[video_idx][0]
grid_hs = video_grid_thw[:, 1]
grid_ws = video_grid_thw[:, 2]
t_index = (torch.arange(grid_t) *
second_per_grid_ts[video_idx] *
tokens_per_second).long()
llm_pos_ids = cls._get_llm_pos_ids_for_vision(
start_idx, video_idx, spatial_merge_size, t_index, grid_hs,
grid_ws)
llm_pos_ids_list.append(llm_pos_ids)
vision_seqlen = video_grid_thw[video_idx].prod() // (
spatial_merge_size**2)
new_src_item.extend([video_token_id] * vision_seqlen)
video_idx += 1
else:
# read audio from video
assert audio_seqlens is not None
audio_seqlen = audio_seqlens[audio_idx]
vision_seqlen = video_grid_thw[video_idx].prod() // (
spatial_merge_size**2)
grid_t = video_grid_thw[video_idx][0]
grid_h = video_grid_thw[video_idx][1]
grid_w = video_grid_thw[video_idx][2]
grid_hs = video_grid_thw[:, 1]
grid_ws = video_grid_thw[:, 2]
t_ntoken_per_chunk = int(tokens_per_second * seconds_per_chunk)
t_index = (torch.arange(grid_t) *
second_per_grid_ts[video_idx] *
tokens_per_second).long()
t_index_split_chunk = cls._split_list_into_ranges(
t_index, t_ntoken_per_chunk)
place_num = (((audio_seqlen - 1) // 2 + 1 - 2) // 2 + 1) + 2
pure_audio_len = place_num - 2
added_audio_len = 0
audio_llm_pos_ids_list: List[torch.Tensor] = []
for t_chunk in t_index_split_chunk:
vision_ntoken_per_chunk = len(
t_chunk) * grid_h * grid_w // (spatial_merge_size**2)
new_src_item.extend([video_token_id] *
vision_ntoken_per_chunk)
vision_llm_pos_ids_list = cls._get_llm_pos_ids_for_vision(
start_idx, video_idx, spatial_merge_size, t_chunk,
grid_hs, grid_ws).split(1, dim=1)
llm_pos_ids_list.extend(vision_llm_pos_ids_list)
new_src_item.extend(
min(t_ntoken_per_chunk, pure_audio_len -
added_audio_len) * [audio_token_id])
audio_start_idx = start_idx if len(
audio_llm_pos_ids_list
) == 0 else audio_llm_pos_ids_list[-1][0].item() + 1
if min(t_ntoken_per_chunk,
pure_audio_len - added_audio_len) > 0:
audio_llm_pos_ids_list = (torch.arange(
min(t_ntoken_per_chunk, pure_audio_len -
added_audio_len)).expand(3, -1) +
audio_start_idx).split(1,
dim=1)
else:
audio_llm_pos_ids_list = []
added_audio_len += min(t_ntoken_per_chunk,
pure_audio_len - added_audio_len)
llm_pos_ids_list.extend(audio_llm_pos_ids_list)
if added_audio_len < pure_audio_len:
new_src_item.extend(
(pure_audio_len - added_audio_len) * [audio_token_id])
audio_llm_pos_ids_list = (
torch.arange(pure_audio_len - added_audio_len).expand(
3, -1) + llm_pos_ids_list[-1].max() + 1).split(
1, dim=1)
llm_pos_ids_list.extend(audio_llm_pos_ids_list)
audio_idx += 1
video_idx += 1
# move to the next token
idx += len(new_src_item) - new_src_item_len
llm_positions = torch.cat(llm_pos_ids_list, dim=1)
mrope_position_delta = torch.cat(llm_pos_ids_list,
dim=1).max() + 1 - len(src_item)
llm_positions = llm_positions[:, context_len:seq_len]
return llm_positions, mrope_position_delta
@staticmethod
def _get_llm_pos_ids_for_vision(
start_idx: int,
vision_idx: int,
spatial_merge_size: int,
t_index: List[int],
grid_hs: torch.Tensor,
grid_ws: torch.Tensor,
) -> torch.Tensor:
llm_pos_ids_list = []
llm_grid_h = grid_hs[vision_idx] // spatial_merge_size
llm_grid_w = grid_ws[vision_idx] // spatial_merge_size
h_index = (torch.arange(llm_grid_h).view(1, -1, 1).expand(
len(t_index), -1, llm_grid_w).flatten())
w_index = (torch.arange(llm_grid_w).view(1, 1, -1).expand(
len(t_index), llm_grid_h, -1).flatten())
t_index_tensor = torch.Tensor(t_index).to(llm_grid_h.device).view(
-1, 1).expand(-1, llm_grid_h * llm_grid_w).long().flatten()
_llm_pos_ids = torch.stack([t_index_tensor, h_index, w_index])
llm_pos_ids_list.append(_llm_pos_ids + start_idx)
llm_pos_ids = torch.cat(llm_pos_ids_list, dim=1)
return llm_pos_ids
@staticmethod
def _split_list_into_ranges(lst: torch.Tensor,
interval: int) -> List[List[int]]:
ranges: List[List[int]] = [[]
for _ in range((max(lst) // interval) + 1)]
for num in lst:
index = num // interval
ranges[index].append(num)
return ranges
@staticmethod
def get_next_input_positions(
mrope_position_delta: int,
......@@ -1144,6 +1412,58 @@ class MRotaryEmbedding(RotaryEmbedding):
mrope_position_delta + seq_len,
).expand(3, -1)
@classmethod
def omni_get_updates_use_audio_in_video(
cls,
thinker_config: PretrainedConfig,
audio_len: int,
video_grid_thw: Union[List[int], torch.Tensor],
video_second_per_grid_t: float,
) -> List[int]:
"""Get video prompt updates when `use_audio_in_video` is True.
In this case, audio and vision update ids will be split into
chunks and interleaved (details in `_omni_get_input_positions_tensor`).
<|video_bos|><|VIDEO|><|video_eos|> =>
<|video_bos|><|audio_bos|>(... chunks ...)<|audio_eos|><|video_eos|>
"""
audio_token_id = thinker_config.audio_token_index
video_token_id = thinker_config.video_token_index
audio_start_token_id = thinker_config.audio_start_token_id
audio_end_token_id = thinker_config.audio_end_token_id
seconds_per_chunk = thinker_config.seconds_per_chunk
spatial_merge_size = thinker_config.vision_config.spatial_merge_size
tokens_per_second = getattr(thinker_config.vision_config,
"tokens_per_second", 25)
grid_t = video_grid_thw[0]
grid_h = video_grid_thw[1]
grid_w = video_grid_thw[2]
t_ntoken_per_chunk = int(tokens_per_second * seconds_per_chunk)
t_index = (torch.arange(grid_t) * video_second_per_grid_t *
tokens_per_second).long()
t_index_split_chunk = cls._split_list_into_ranges(
t_index, t_ntoken_per_chunk)
updates = [audio_start_token_id]
added_audio_len = 0
for t_chunk in t_index_split_chunk:
vision_ntoken_per_chunk = len(t_chunk) * grid_h * grid_w // (
spatial_merge_size**2)
updates.extend([video_token_id] * vision_ntoken_per_chunk)
audio_chunk_size = min(t_ntoken_per_chunk,
audio_len - added_audio_len)
updates.extend(audio_chunk_size * [audio_token_id])
added_audio_len += audio_chunk_size
if added_audio_len < audio_len:
updates.extend((audio_len - added_audio_len) * [audio_token_id])
updates.extend([audio_end_token_id])
return updates
_ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {}
......
# SPDX-License-Identifier: Apache-2.0
"""Utility methods for model layers."""
from typing import Tuple
from typing import Callable, Optional, Tuple
import torch
from vllm import _custom_ops as ops
from vllm import envs
from vllm.platforms import current_platform
def get_token_bin_counts_and_mask(
tokens: torch.Tensor,
......@@ -47,12 +51,49 @@ def apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
output_tokens_tensor, vocab_size, num_seqs)
repetition_penalties = repetition_penalties.unsqueeze(dim=1).repeat(
1, vocab_size)
logits[logits > 0] /= torch.where(prompt_mask | output_mask,
repetition_penalties, 1.0)[logits > 0]
logits[logits <= 0] *= torch.where(prompt_mask | output_mask,
repetition_penalties, 1.0)[logits <= 0]
# If token appears in prompt or output, apply, otherwise use 1.0 for no-op.
penalties = torch.where(prompt_mask | output_mask, repetition_penalties,
1.0)
# If logits are positive, divide by penalty, otherwise multiply by penalty.
scaling = torch.where(logits > 0, 1.0 / penalties, penalties)
logits *= scaling
# We follow the definition in OpenAI API.
# Refer to https://platform.openai.com/docs/api-reference/parameter-details
logits -= frequency_penalties.unsqueeze(dim=1) * output_bin_counts
logits -= presence_penalties.unsqueeze(dim=1) * output_mask
return logits
def rocm_unquantized_gemm(x: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor] = None):
from vllm.platforms.rocm import on_mi250_mi300
k = weight.shape[1]
use_skinny = (envs.VLLM_ROCM_USE_SKINNY_GEMM and on_mi250_mi300() and \
x.dtype in [torch.float16, torch.bfloat16] \
and k % 8 == 0 and bias is None)
if use_skinny is not True:
return torch.nn.functional.linear(x, weight, bias)
x_view = x.view(-1, x.size(-1))
n = x_view.shape[0]
m = weight.shape[0]
cu_count = current_platform.get_cu_count()
if m > 8 and 0 < n < 4:
out = ops.wvSplitK(weight, x_view, cu_count)
return out.view(*x.shape[:-1], weight.shape[0])
elif m % 4 == 0 and n == 1 and k <= 8192:
out = ops.LLMM1(weight, x_view, 4)
return out.view(*x.shape[:-1], weight.shape[0])
return torch.nn.functional.linear(x, weight, bias)
def dispatch_unquantized_gemm() -> Callable[..., torch.Tensor]:
# if current_platform.is_rocm():
# return rocm_unquantized_gemm
return torch.nn.functional.linear
......@@ -13,6 +13,7 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase, method_has_implemented_embedding)
from vllm.model_executor.layers.utils import dispatch_unquantized_gemm
from vllm.model_executor.parameter import BasevLLMParameter
from vllm.model_executor.utils import set_weight_attrs
......@@ -55,8 +56,8 @@ class UnquantizedEmbeddingMethod(QuantizeMethodBase):
else:
return torch.matmul(x, layer.weight)
else:
return F.linear(x, layer.weight, bias)
return dispatch_unquantized_gemm()(x, layer.weight, bias)
def embedding(self, layer: torch.nn.Module,
input_: torch.Tensor) -> torch.Tensor:
return F.embedding(input_, layer.weight)
......
......@@ -613,8 +613,12 @@ class ShardedStateLoader(BaseModelLoader):
DEFAULT_PATTERN = "model-rank-{rank}-part-{part}.safetensors"
def __init__(self, load_config: LoadConfig):
def __init__(self,
load_config: LoadConfig,
runai_model_streamer: bool = False):
super().__init__(load_config)
self.runai_model_streamer = runai_model_streamer
extra_config = ({} if load_config.model_loader_extra_config is None
else load_config.model_loader_extra_config.copy())
self.pattern = extra_config.pop("pattern", self.DEFAULT_PATTERN)
......@@ -661,7 +665,7 @@ class ShardedStateLoader(BaseModelLoader):
def _prepare_weights(self, model_name_or_path: str,
revision: Optional[str]):
if os.path.isdir(model_name_or_path):
if is_s3(model_name_or_path) or os.path.isdir(model_name_or_path):
return model_name_or_path
else:
allow_patterns = ["*.safetensors"]
......@@ -680,12 +684,13 @@ class ShardedStateLoader(BaseModelLoader):
device_config = vllm_config.device_config
model_config = vllm_config.model_config
target_device = torch.device(device_config.device)
from safetensors.torch import safe_open
from vllm.distributed import get_tensor_model_parallel_rank
local_model_path = self._prepare_weights(model_config.model,
model_config.revision)
model_weights = model_config.model
if hasattr(model_config, "model_weights"):
model_weights = model_config.model_weights
local_model_path = model_weights
with set_default_torch_dtype(model_config.dtype):
with target_device:
......@@ -697,40 +702,56 @@ class ShardedStateLoader(BaseModelLoader):
local_model_path,
self.pattern.format(rank=rank, part="*"),
)
filepaths = glob.glob(pattern)
filepaths = []
if is_s3(local_model_path):
file_pattern = f"*{self.pattern.format(rank=rank, part=' * ')}"
filepaths = s3_glob(path=local_model_path,
allow_pattern=[file_pattern])
else:
filepaths = glob.glob(pattern)
if not filepaths:
# TODO: support un-sharded checkpoints too
raise ValueError(
f"Could not find checkpoint files '{pattern}', only "
f"pre-sharded checkpoints are currently supported!")
state_dict = self._filter_subtensors(model.state_dict())
for path in filepaths:
with safe_open(path, framework="pt") as f:
for key in f.keys(): # noqa: SIM118
tensor = f.get_tensor(key)
# If loading with LoRA enabled, additional padding may
# be added to certain parameters. We only load into a
# narrowed view of the parameter data.
param_data = state_dict[key].data
param_shape = state_dict[key].shape
for dim, size in enumerate(tensor.shape):
if size < param_shape[dim]:
param_data = param_data.narrow(dim, 0, size)
if tensor.shape != param_shape:
logger.warning(
"loading tensor of shape %s into "
"parameter '%s' of shape %s",
tensor.shape,
key,
param_shape,
)
param_data.copy_(tensor)
state_dict.pop(key)
for key, tensor in self.iterate_over_files(filepaths):
# If loading with LoRA enabled, additional padding may
# be added to certain parameters. We only load into a
# narrowed view of the parameter data.
param_data = state_dict[key].data
param_shape = state_dict[key].shape
for dim, size in enumerate(tensor.shape):
if size < param_shape[dim]:
param_data = param_data.narrow(dim, 0, size)
if tensor.shape != param_shape:
logger.warning(
"loading tensor of shape %s into "
"parameter '%s' of shape %s",
tensor.shape,
key,
param_shape,
)
param_data.copy_(tensor)
state_dict.pop(key)
if state_dict:
raise ValueError(
f"Missing keys {tuple(state_dict)} in loaded state!")
return model.eval()
def iterate_over_files(
self, paths) -> Generator[Tuple[str, torch.Tensor], None, None]:
if self.runai_model_streamer:
yield from runai_safetensors_weights_iterator(paths, True)
else:
from safetensors.torch import safe_open
for path in paths:
with safe_open(path, framework="pt") as f:
for key in f.keys(): # noqa: SIM118
tensor = f.get_tensor(key)
yield key, tensor
@staticmethod
def save_model(
model: torch.nn.Module,
......@@ -1517,4 +1538,7 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
if load_config.load_format == LoadFormat.RUNAI_STREAMER:
return RunaiModelStreamerLoader(load_config)
if load_config.load_format == LoadFormat.RUNAI_STREAMER_SHARDED:
return ShardedStateLoader(load_config, runai_model_streamer=True)
return DefaultModelLoader(load_config)
......@@ -31,15 +31,6 @@ def set_default_torch_dtype(dtype: torch.dtype):
torch.set_default_dtype(old_dtype)
def is_transformers_impl_compatible(
arch: str,
module: Optional["transformers.PreTrainedModel"] = None) -> bool:
mod = module or getattr(transformers, arch, None)
if mod is None:
return False
return mod.is_backend_compatible()
def resolve_transformers_arch(model_config: ModelConfig,
architectures: list[str]):
for i, arch in enumerate(architectures):
......@@ -56,20 +47,32 @@ def resolve_transformers_arch(model_config: ModelConfig,
# "AutoModelFor<Task>": "<your-repo-name>--<config-name>",
# },
auto_modules = {
name: get_class_from_dynamic_module(module, model_config.model)
name:
get_class_from_dynamic_module(module,
model_config.model,
revision=model_config.revision)
for name, module in sorted(auto_map.items(), key=lambda x: x[0])
}
custom_model_module = auto_modules.get("AutoModel")
model_module = getattr(transformers, arch, None)
if model_module is None:
if "AutoModel" not in auto_map:
raise ValueError(
f"Cannot find model module. '{arch}' is not a registered "
"model in the Transformers library (only relevant if the "
"model is meant to be in Transformers) and 'AutoModel' is "
"not present in the model config's 'auto_map' (relevant "
"if the model is custom).")
model_module = auto_modules["AutoModel"]
# TODO(Isotr0py): Further clean up these raises.
# perhaps handled them in _ModelRegistry._raise_for_unsupported?
if model_config.model_impl == ModelImpl.TRANSFORMERS:
if not is_transformers_impl_compatible(arch, custom_model_module):
if not model_module.is_backend_compatible():
raise ValueError(
f"The Transformers implementation of {arch} is not "
"compatible with vLLM.")
architectures[i] = "TransformersForCausalLM"
if model_config.model_impl == ModelImpl.AUTO:
if not is_transformers_impl_compatible(arch, custom_model_module):
if not model_module.is_backend_compatible():
raise ValueError(
f"{arch} has no vLLM implementation and the Transformers "
"implementation is not compatible with vLLM. Try setting "
......@@ -132,10 +135,10 @@ def get_model_architecture(
architectures = ["QuantMixtralForCausalLM"]
vllm_supported_archs = ModelRegistry.get_supported_archs()
is_vllm_supported = any(arch in vllm_supported_archs
for arch in architectures)
if (not is_vllm_supported
or model_config.model_impl == ModelImpl.TRANSFORMERS):
vllm_not_supported = not any(arch in vllm_supported_archs
for arch in architectures)
if (model_config.model_impl == ModelImpl.TRANSFORMERS or
model_config.model_impl != ModelImpl.VLLM and vllm_not_supported):
architectures = resolve_transformers_arch(model_config, architectures)
model_cls, arch = ModelRegistry.resolve_model_cls(architectures)
......
......@@ -24,7 +24,6 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.deepspeedfp import (
DeepSpeedFPConfig, DeepSpeedFPParameter)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
......@@ -435,7 +434,6 @@ class ArcticForCausalLM(nn.Module, SupportsPP, SupportsQuant):
self.unpadded_vocab_size = config.vocab_size
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size)
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
......@@ -462,14 +460,6 @@ class ArcticForCausalLM(nn.Module, SupportsPP, SupportsQuant):
sampling_metadata)
return logits
def sample(
self,
logits: Optional[torch.Tensor],
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
......
......@@ -15,11 +15,10 @@ from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import (SamplerOutput,
SamplingMetadata, get_sampler)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargs)
......@@ -527,7 +526,6 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
self.vocab_size, logit_scale)
self.sampler = get_sampler()
def _validate_image_sizes(
self, images: List[torch.Tensor]) -> List[torch.Tensor]:
......@@ -653,14 +651,6 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
loader = AutoWeightsLoader(self)
loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
# SPDX-License-Identifier: Apache-2.0 Adapted from
# https://github.com/huggingface/transformers/tree/main/src/transformers/models/aya_vision
from functools import cached_property
from typing import (Iterable, Literal, Mapping, Optional, Sequence, Set, Tuple,
TypedDict, Union, cast)
......@@ -17,7 +16,6 @@ from transformers.models.got_ocr2.image_processing_got_ocr2 import (
from vllm.config import VllmConfig
from vllm.jsontree import json_map_leaves
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalDataDict, MultiModalKwargs
......@@ -461,17 +459,3 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal,
) -> Optional[torch.Tensor]:
return self.language_model.compute_logits(hidden_states,
sampling_metadata)
@cached_property
def sampler(self):
if hasattr(self.language_model, "sampler"):
return self.language_model.sampler
return get_sampler()
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
return self.language_model.sample(logits, sampling_metadata)
......@@ -43,7 +43,6 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
......@@ -504,7 +503,6 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
if self.config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
......@@ -532,14 +530,6 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader(self)
......
......@@ -24,7 +24,6 @@ from vllm.model_executor.layers.mamba.mamba_mixer2 import (
MambaMixer2, extra_groups_for_head_shards)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
......@@ -462,7 +461,6 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size)
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
......@@ -538,14 +536,6 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
sampling_metadata)
return logits
def sample(
self,
logits: Optional[torch.Tensor],
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader(self)
......
......@@ -37,7 +37,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
......@@ -791,7 +790,6 @@ class BartForConditionalGeneration(nn.Module, SupportsV0Only, SupportsQuant):
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size)
self.sampler = get_sampler()
def forward(
self,
......@@ -828,14 +826,6 @@ class BartForConditionalGeneration(nn.Module, SupportsV0Only, SupportsQuant):
sampling_metadata)
return logits
def sample(
self,
logits: Optional[torch.Tensor],
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
stacked_params_mapping = {
"q_proj": {
"param_name": "qkv_proj",
......
......@@ -11,8 +11,10 @@ from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, PoolerConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.activation import (get_act_and_mul_fn,
get_act_fn)
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.pooler import (CrossEncodingPooler, Pooler,
......@@ -108,6 +110,7 @@ class BertEncoder(nn.Module):
def __init__(self,
vllm_config: VllmConfig,
bias: bool = True,
rotary_kwargs: Optional[dict] = None,
prefix: str = ""):
super().__init__()
......@@ -118,6 +121,7 @@ class BertEncoder(nn.Module):
BertLayer(config=config,
cache_config=cache_config,
quant_config=quant_config,
bias=bias,
rotary_kwargs=rotary_kwargs,
prefix=f"{prefix}.layer.{layer_idx}")
for layer_idx in range(config.num_hidden_layers)
......@@ -139,6 +143,7 @@ class BertLayer(nn.Module):
config: BertConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
bias: bool = True,
rotary_kwargs: Optional[dict] = None,
prefix: str = ""):
super().__init__()
......@@ -149,19 +154,31 @@ class BertLayer(nn.Module):
layer_norm_eps=config.layer_norm_eps,
cache_config=cache_config,
quant_config=quant_config,
bias=bias,
rotary_kwargs=rotary_kwargs,
prefix=f"{prefix}.attention")
self.intermediate = BertIntermediate(
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
prefix=f"{prefix}.intermediate")
if config.hidden_act in ["silu", "gelu_and_mul"]:
self.intermediate = BertGatedIntermediate(
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.intermediate")
else:
self.intermediate = BertIntermediate(
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.intermediate")
self.output = BertOutput(hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
layer_norm_eps=config.layer_norm_eps,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.output")
......@@ -181,6 +198,7 @@ class BertAttention(nn.Module):
layer_norm_eps: float,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
bias: bool = True,
rotary_kwargs: Optional[dict] = None,
prefix: str = "",
):
......@@ -190,11 +208,13 @@ class BertAttention(nn.Module):
num_attention_heads=num_attention_heads,
cache_config=cache_config,
quant_config=quant_config,
bias=bias,
rotary_kwargs=rotary_kwargs,
prefix=f"{prefix}.output")
self.output = BertSelfOutput(hidden_size=hidden_size,
layer_norm_eps=layer_norm_eps,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.output")
......@@ -215,6 +235,7 @@ class BertSelfAttention(nn.Module):
num_attention_heads: int,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
bias: bool = True,
rotary_kwargs: Optional[dict] = None,
prefix: str = "",
):
......@@ -240,7 +261,7 @@ class BertSelfAttention(nn.Module):
head_size=self.head_dim,
total_num_heads=self.total_num_heads,
total_num_kv_heads=self.total_num_kv_heads,
bias=True,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj")
......@@ -278,12 +299,13 @@ class BertSelfOutput(nn.Module):
def __init__(self,
hidden_size: int,
layer_norm_eps: float,
bias: bool = True,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__()
self.dense = RowParallelLinear(input_size=hidden_size,
output_size=hidden_size,
bias=True,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.dense")
self.LayerNorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
......@@ -301,12 +323,13 @@ class BertIntermediate(nn.Module):
hidden_size: int,
intermediate_size: int,
hidden_act: str,
bias: bool = True,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__()
self.dense = ColumnParallelLinear(input_size=hidden_size,
output_size=intermediate_size,
bias=True,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.dense")
self.intermediate_act_fn = get_act_fn(hidden_act)
......@@ -317,19 +340,46 @@ class BertIntermediate(nn.Module):
return hidden_states
class BertGatedIntermediate(nn.Module):
# for NomciBert and GteModel
def __init__(self,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
bias: bool = True,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__()
self.act_fn = get_act_and_mul_fn(hidden_act)
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size,
[intermediate_size] * 2,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj",
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
gate_up, _ = self.gate_up_proj(hidden_states)
hidden_states = self.act_fn(gate_up)
return hidden_states
class BertOutput(nn.Module):
def __init__(self,
hidden_size: int,
intermediate_size: int,
layer_norm_eps: float,
bias: bool = True,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__()
self.dense = RowParallelLinear(input_size=intermediate_size,
output_size=hidden_size,
bias=True,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.dense")
......@@ -343,19 +393,32 @@ class BertOutput(nn.Module):
class BertModel(nn.Module, SupportsQuant):
packed_modules_mapping = {"qkv_proj": ["query", "key", "value"]}
packed_modules_mapping = {
"qkv_proj": ["query", "key", "value"],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
def __init__(self,
*,
vllm_config: VllmConfig,
prefix: str = "",
embedding_class: type = BertEmbedding,
bias: bool = True,
rotary_kwargs: Optional[dict] = None,
add_pooling_layer: bool = False):
super().__init__()
"""
For BertModel, all linear layers have bias.
For NomicBertModel, all linear layers do not have bias.
"""
config = vllm_config.model_config.hf_config
self.embeddings = embedding_class(config)
self.encoder = BertEncoder(vllm_config=vllm_config,
bias=bias,
rotary_kwargs=rotary_kwargs,
prefix=f"{prefix}.encoder")
self.pooler = BertPooler(config) if add_pooling_layer else None
......@@ -387,6 +450,8 @@ class BertModel(nn.Module, SupportsQuant):
("qkv_proj", "query", "q"),
("qkv_proj", "key", "k"),
("qkv_proj", "value", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
......@@ -546,3 +611,115 @@ class BertForSequenceClassification(nn.Module, SupportsCrossEncoding,
inputs_embeds=inputs_embeds,
intermediate_tensors=intermediate_tensors,
token_type_ids=token_type_ids)
class NomicBertEmbeddingModel(BertEmbeddingModel):
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_substr={
"emb_ln": "embeddings.LayerNorm",
"layers": "layer",
"attn.Wqkv": "attention.self.qkv_proj",
"attn.out_proj": "attention.output.dense",
'norm1': "attention.output.LayerNorm",
'mlp.fc11': "intermediate.up_proj",
'mlp.fc12': "intermediate.gate_proj",
'mlp.fc2': "output.dense",
'norm2': "output.LayerNorm",
})
def _build_model(self,
vllm_config: VllmConfig,
prefix: str = "") -> BertModel:
config = vllm_config.model_config.hf_config
assert config.__class__.__name__ == "NomicBertConfig"
assert config.activation_function == "swiglu"
# Assume NomicBertModel all linear layers do not have bias
assert not config.mlp_fc1_bias
assert not config.mlp_fc2_bias
assert not config.qkv_proj_bias
config.layer_norm_eps = config.layer_norm_epsilon
config.position_embedding_type = "rotary"
config.intermediate_size = config.n_inner
config.hidden_act = "silu"
config.hidden_size = config.n_embd
config.num_hidden_layers = config.n_layer
head_dim = config.hidden_size // config.num_attention_heads
rotary_kwargs = {
"head_size": head_dim,
"rotary_dim": getattr(config, "rotary_emb_dim", head_dim),
"max_position": config.max_trained_positions,
"base": config.rotary_emb_base,
"rope_scaling": {
"rope_type": "dynamic",
"factor": config.rotary_scaling_factor
}
}
return BertModel(vllm_config=vllm_config,
prefix=prefix,
bias=False,
rotary_kwargs=rotary_kwargs,
embedding_class=BertEmbedding)
class GteEmbeddingModel(BertEmbeddingModel):
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_substr={
"attention.qkv_proj": "attention.self.qkv_proj",
"attention.o_proj": "attention.output.dense",
'attn_ln': "attention.output.LayerNorm",
'mlp.down_proj': "output.dense",
'mlp_ln': "output.LayerNorm",
})
def _build_model(self,
vllm_config: VllmConfig,
prefix: str = "") -> BertModel:
config = vllm_config.model_config.hf_config
assert config.__class__.__name__ == "GteConfig"
assert config.position_embedding_type == "rope"
assert config.hidden_act == "gelu"
config.position_embedding_type = "rotary"
config.hidden_act = "gelu_and_mul"
head_dim = config.hidden_size // config.num_attention_heads
rotary_kwargs = {
"head_size": head_dim,
"rotary_dim": getattr(config, "rotary_emb_dim", head_dim),
"max_position": config.max_position_embeddings,
"base": config.rope_theta,
}
model = BertModel(vllm_config=vllm_config,
prefix=prefix,
rotary_kwargs=rotary_kwargs,
embedding_class=BertEmbedding)
# GteModel only gate_up_proj does not have bias.
# Hack method learned from vllm/model_executor/models/glm.py
for layer in model.encoder.layer:
layer.intermediate.gate_up_proj.bias = None
layer.intermediate.skip_bias_add = True
return model
def split_up_gate_proj(self, weights: Iterable[Tuple[str, torch.Tensor]]):
n = "mlp.up_gate_proj"
for name, weight in weights:
if n in name:
up, gate = weight.chunk(2, dim=0)
yield name.replace(n, "intermediate.up_proj"), up
yield name.replace(n, "intermediate.gate_proj"), gate
else:
yield name, weight
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
weights = self.hf_to_vllm_mapper.apply(weights)
weights = self.split_up_gate_proj(weights)
self.model.load_weights(weights)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment