Commit 38d80967 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.10.2rc2' into v0.10.2rc2-ori

parents 33650733 880c741b
......@@ -38,9 +38,7 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
) -> tuple[torch.Tensor, Optional[torch.Tensor],
Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor],
Optional[torch.Tensor]]:
) -> mk.PrepareResultType:
if apply_router_weight_on_input:
topk = topk_ids.size(1)
......
......@@ -420,9 +420,8 @@ def shuffle_weights(
Args:
*tensors: Variable number of torch.Tensor objects.
layout: A pair of integers specifying the
block sizes used to divide the tensors during shuffling.
Default is (16, 16).
layout: A pair of integers specifying the block sizes used to divide
the tensors during shuffling. Default is (16, 16).
Returns:
A Tuple of shuffled tensors.
......
......@@ -10,7 +10,7 @@ like uniform random routing.
"""
from abc import ABC, abstractmethod
from typing import Optional
from typing import Any, Optional
import torch
......@@ -50,7 +50,9 @@ class DistributionBasedRouting(RoutingStrategy):
distributions for testing different routing patterns.
"""
def __init__(self, distribution: str = "uniform", **distribution_params):
def __init__(self,
distribution: str = "uniform",
**distribution_params: Any):
"""
Initialize distribution-based routing.
......@@ -244,7 +246,7 @@ class RoutingSimulator:
cls._routing_strategies[name] = strategy
@classmethod
def get_available_strategies(cls):
def get_available_strategies(cls) -> list[str]:
"""
Get list of available routing strategy names.
......
......@@ -9,11 +9,11 @@ import torch.nn as nn
import vllm.envs as envs
from vllm.model_executor.custom_op import CustomOp
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op
def is_rocm_aiter_rmsnorm_enabled() -> bool:
return current_platform.is_rocm() \
and envs.VLLM_ROCM_USE_AITER_RMSNORM \
return envs.VLLM_ROCM_USE_AITER_RMSNORM \
and envs.VLLM_ROCM_USE_AITER
......@@ -43,7 +43,21 @@ def fused_add_rms_norm(
return x, residual
def rocm_aiter_rms_norm(x: torch.Tensor, weight: torch.Tensor,
def poly_norm(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor,
variance_epsilon: float) -> torch.Tensor:
from vllm import _custom_ops as ops
out = torch.empty_like(x)
ops.poly_norm(
out,
x,
weight,
bias,
variance_epsilon,
)
return out
def rocm_aiter_rms_norm_impl(x: torch.Tensor, weight: torch.Tensor,
variance_epsilon: float) -> torch.Tensor:
import aiter as rocm_aiter
if x.dim() > 2:
......@@ -55,7 +69,7 @@ def rocm_aiter_rms_norm(x: torch.Tensor, weight: torch.Tensor,
return rocm_aiter.rms_norm(x, weight, variance_epsilon)
def rocm_aiter_fused_add_rms_norm(
def rocm_aiter_rmsnorm2d_fwd_with_add_impl(
x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor,
variance_epsilon: float) -> tuple[torch.Tensor, torch.Tensor]:
......@@ -74,14 +88,48 @@ def rocm_aiter_fused_add_rms_norm(
return output, residual_out
def dispatch_cuda_rmsnorm_func(add_residual: bool):
if add_residual:
if is_rocm_aiter_rmsnorm_enabled():
return rocm_aiter_fused_add_rms_norm
return fused_add_rms_norm
def rocm_aiter_rms_norm_fake(x: torch.Tensor, weight: torch.Tensor,
variance_epsilon: float) -> torch.Tensor:
return torch.empty_like(x)
def rocm_aiter_rmsnorm2d_fwd_with_add_fake(
x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor,
variance_epsilon: float) -> tuple[torch.Tensor, torch.Tensor]:
return torch.empty_like(x), torch.empty_like(residual)
if current_platform.is_rocm():
direct_register_custom_op(
op_name="rocm_aiter_rms_norm",
op_func=rocm_aiter_rms_norm_impl,
mutates_args=[],
fake_impl=rocm_aiter_rms_norm_fake,
dispatch_key=current_platform.dispatch_key,
)
direct_register_custom_op(
op_name="rocm_aiter_rmsnorm2d_fwd_with_add",
op_func=rocm_aiter_rmsnorm2d_fwd_with_add_impl,
mutates_args=[],
fake_impl=rocm_aiter_rmsnorm2d_fwd_with_add_fake,
dispatch_key=current_platform.dispatch_key,
)
if is_rocm_aiter_rmsnorm_enabled():
return rocm_aiter_rms_norm
def dispatch_rocm_rmsnorm_func(with_fused_add: bool, dtype: torch.dtype):
use_aiter = is_rocm_aiter_rmsnorm_enabled() and dtype in [
torch.float16, torch.bfloat16
]
if use_aiter and with_fused_add:
return torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add
if use_aiter:
return torch.ops.vllm.rocm_aiter_rms_norm
# fall back to CUDA implementation
if with_fused_add:
return fused_add_rms_norm
return rms_norm
......@@ -114,6 +162,13 @@ class RMSNorm(CustomOp):
self.weight = torch.ones(hidden_size)
if self.has_weight:
self.weight = nn.Parameter(self.weight)
weight_dtype = self.weight.data.dtype
if current_platform.is_rocm():
self.rocm_norm_func = dispatch_rocm_rmsnorm_func(
with_fused_add=False, dtype=weight_dtype)
self.rocm_norm_func_with_add = dispatch_rocm_rmsnorm_func(
with_fused_add=True, dtype=weight_dtype)
def forward_native(
self,
......@@ -162,13 +217,27 @@ class RMSNorm(CustomOp):
return self.forward_native(x, residual)
add_residual = residual is not None
norm_func = dispatch_cuda_rmsnorm_func(add_residual)
if add_residual:
return fused_add_rms_norm(x, residual, self.weight.data,
self.variance_epsilon)
else:
return rms_norm(x, self.weight.data, self.variance_epsilon)
def forward_hip(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
if self.variance_size_override is not None:
return self.forward_native(x, residual)
add_residual = residual is not None
if add_residual:
return norm_func(x, residual, self.weight.data,
return self.rocm_norm_func_with_add(x, residual, self.weight.data,
self.variance_epsilon)
else:
return norm_func(x, self.weight.data, self.variance_epsilon)
return self.rocm_norm_func(x, self.weight.data,
self.variance_epsilon)
def forward_xpu(
self,
......@@ -265,3 +334,48 @@ class GemmaRMSNorm(CustomOp):
self.forward_static)
self._is_compiled = True
return self.forward_native(x, residual)
@CustomOp.register("poly_norm")
class PolyNorm(CustomOp):
"""Polynomial normalization.
Computes x -> w_0 * RMSNorm(x^3) + w_1 * RMSNorm(x^2) + w_2 * RMSNorm(x) + b
where w_n is the learned weight and b is the bias.
Refer to https://arxiv.org/html/2411.03884v1
"""
def __init__(
self,
eps: float = 1e-6,
) -> None:
super().__init__()
self.weight = torch.nn.Parameter(torch.ones(3) / 3)
self.bias = torch.nn.Parameter(torch.zeros(1))
self.variance_epsilon = eps
def _norm(self, x):
return x / torch.sqrt(
x.pow(2).mean(-1, keepdim=True) + self.variance_epsilon)
def forward_native(
self,
x: torch.Tensor,
) -> torch.Tensor:
"""PyTorch-native implementation equivalent to forward().
Refer to https://github.com/BryceZhuo/PolyCom?tab=readme-ov-file/README.md
"""
orig_dtype = x.dtype
x_float = x.to(torch.float32)
output = (self.weight[0] * self._norm(x_float**3) +
self.weight[1] * self._norm(x_float**2) +
self.weight[2] * self._norm(x_float) + self.bias)
return output.to(orig_dtype)
def forward_cuda(
self,
x: torch.Tensor,
) -> torch.Tensor:
return poly_norm(x, self.weight, self.bias, self.variance_epsilon)
......@@ -9,7 +9,6 @@ import torch
import torch.nn as nn
from torch.nn.parameter import Parameter, UninitializedParameter
from vllm import envs
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
split_tensor_along_last_dim,
......@@ -200,26 +199,10 @@ class UnquantizedLinearMethod(LinearMethodBase):
set_weight_attrs(weight, extra_weight_attrs)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# special postprocessing for CPU SGL
if current_platform.is_cpu() and envs.VLLM_CPU_SGL_KERNEL:
from vllm.model_executor.layers.utils import check_cpu_sgl_kernel
N, K = layer.weight.size()
dtype = layer.weight.dtype
if check_cpu_sgl_kernel(N, K, dtype):
packed_weight = torch.ops._C.convert_weight_packed(
layer.weight)
assert packed_weight.size() == layer.weight.size()
layer.weight.copy_(packed_weight)
if layer.bias is not None:
layer.bias = Parameter(layer.bias.to(torch.float32),
requires_grad=False)
layer.use_cpu_sgl = True
else:
logger.warning(
"CPU SGL kernels require Intel AMX support,"
" bf16/fp16/int8 weight, IC and OC are divisible by "
"32 and 16.")
layer.use_cpu_sgl = False
if current_platform.is_cpu():
from vllm.model_executor.layers.utils import (
dispatch_cpu_unquantized_gemm)
dispatch_cpu_unquantized_gemm(layer, remove_weight=True)
def apply(self,
layer: torch.nn.Module,
......@@ -240,6 +223,7 @@ class LinearBase(CustomOp):
quant_config: Quantization configure.
prefix: Prefix for parameter names.
return_bias: If true, return bias together with outputs in forward pass.
disable_tp: If true, tensor parallelism will be disabled for this layer.
"""
def __init__(
......@@ -252,6 +236,7 @@ class LinearBase(CustomOp):
prefix: str = "",
*,
return_bias: bool = True,
disable_tp: bool = False,
):
super().__init__()
......@@ -271,6 +256,17 @@ class LinearBase(CustomOp):
self.quant_method = quant_config.get_quant_method(self,
prefix=prefix)
self.return_bias = return_bias
self.disable_tp = disable_tp
self.tp_rank = (get_tensor_model_parallel_rank()
if not disable_tp else 0)
self.tp_size = (get_tensor_model_parallel_world_size()
if not disable_tp else 1)
def update_param_tp_status(self):
for param in self.parameters():
if isinstance(param, BasevLLMParameter):
param.tp_rank = self.tp_rank
param.tp_size = self.tp_size
@CustomOp.register("replicated_linear")
......@@ -287,6 +283,7 @@ class ReplicatedLinear(LinearBase):
prefix: The name of the layer in the state dict, including all parents
(e.g. model.layers.0.qkv_proj)
return_bias: If true, return bias together with outputs in forward pass.
disable_tp: Take no effect for replicated linear layers.
"""
def __init__(
......@@ -300,26 +297,21 @@ class ReplicatedLinear(LinearBase):
prefix: str = "",
*,
return_bias: bool = True,
disable_tp: bool = False,
):
# If MergedReplicatedLinear, use output size of each partition.
if hasattr(self, "output_sizes"):
self.output_partition_sizes = self.output_sizes
else:
self.output_partition_sizes = [output_size]
super().__init__(input_size,
output_size,
skip_bias_add,
params_dtype,
quant_config,
prefix=prefix,
return_bias=return_bias)
return_bias=return_bias,
disable_tp=disable_tp)
# All the linear layer supports quant method.
assert self.quant_method is not None
self.quant_method.create_weights(self,
self.input_size,
self.output_partition_sizes,
self.input_size, [self.output_size],
self.input_size,
self.output_size,
self.params_dtype,
......@@ -375,74 +367,6 @@ class ReplicatedLinear(LinearBase):
return s
class MergedReplicatedLinear(ReplicatedLinear):
"""Replicated linear layer.
Args:
input_size: input dimension of the linear layer.
output_sizes: list of output dimensions of the linear layer.
bias: If true, add bias.
skip_bias_add: If true, skip adding bias but instead return it.
params_dtype: Data type for the parameters.
quant_config: Quantization configure.
prefix: The name of the layer in the state dict, including all parents
(e.g. model.layers.0.qkv_proj)
return_bias: If true, return bias together with outputs in forward pass.
"""
def __init__(
self,
input_size: int,
output_sizes: list[int],
bias: bool = True,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
*,
return_bias: bool = True,
):
self.output_sizes = output_sizes
super().__init__(input_size,
sum(output_sizes),
bias,
skip_bias_add,
params_dtype,
quant_config,
prefix=prefix,
return_bias=return_bias)
def weight_loader(self,
param: Union[Parameter, BasevLLMParameter],
loaded_weight: torch.Tensor,
loaded_shard_id: Optional[int] = None):
assert loaded_shard_id is not None
assert loaded_shard_id < len(self.output_sizes)
if isinstance(param, BlockQuantScaleParameter):
from vllm.model_executor.layers.quantization.fp8 import (
Fp8LinearMethod, Fp8MoEMethod)
assert self.quant_method is not None
assert isinstance(self.quant_method,
(Fp8LinearMethod, Fp8MoEMethod))
weight_block_size = self.quant_method.quant_config.weight_block_size
assert weight_block_size is not None
block_n, _ = weight_block_size[0], weight_block_size[1]
shard_offset = (
(sum(self.output_sizes[:loaded_shard_id]) + block_n - 1) //
block_n)
shard_size = ((self.output_sizes[loaded_shard_id] + block_n - 1) //
block_n)
elif isinstance(param, PerTensorScaleParameter):
shard_offset = loaded_shard_id
shard_size = 1
else:
shard_offset = sum(self.output_sizes[:loaded_shard_id])
shard_size = self.output_sizes[loaded_shard_id]
param.data[shard_offset:shard_offset + shard_size] = loaded_weight
@CustomOp.register("column_parallel_linear")
class ColumnParallelLinear(LinearBase):
"""Linear layer with column parallelism.
......@@ -466,6 +390,8 @@ class ColumnParallelLinear(LinearBase):
the list would be size 3.
prefix: The name of the layer in the state dict, including all parents
(e.g. model.layers.0.qkv_proj)
return_bias: If true, return bias together with outputs in forward pass.
disable_tp: If true, weights matrix won't be sharded through tp rank.
"""
def __init__(
......@@ -481,9 +407,13 @@ class ColumnParallelLinear(LinearBase):
prefix: str = "",
*,
return_bias: bool = True,
disable_tp: bool = False,
):
# Divide the weight matrix along the last dimension.
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = (get_tensor_model_parallel_rank()
if not disable_tp else 0)
self.tp_size = (get_tensor_model_parallel_world_size()
if not disable_tp else 1)
self.input_size_per_partition = input_size
self.output_size_per_partition = divide(output_size, self.tp_size)
self.output_partition_sizes = [self.output_size_per_partition]
......@@ -500,7 +430,8 @@ class ColumnParallelLinear(LinearBase):
params_dtype,
quant_config,
prefix,
return_bias=return_bias)
return_bias=return_bias,
disable_tp=disable_tp)
self.gather_output = gather_output
......@@ -528,8 +459,7 @@ class ColumnParallelLinear(LinearBase):
})
else:
self.register_parameter("bias", None)
self.tp_rank = get_tensor_model_parallel_rank()
self.update_param_tp_status()
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
......@@ -571,7 +501,8 @@ class ColumnParallelLinear(LinearBase):
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
def weight_loader_v2(self, param: Parameter, loaded_weight: torch.Tensor):
def weight_loader_v2(self, param: BasevLLMParameter,
loaded_weight: torch.Tensor):
# Special case for loading scales off disk, which often do not
# have a shape (such as in the case of AutoFP8).
if len(loaded_weight.shape) == 0:
......@@ -587,7 +518,7 @@ class ColumnParallelLinear(LinearBase):
# Matrix multiply.
assert self.quant_method is not None
output_parallel = self.quant_method.apply(self, input_, bias)
if self.gather_output:
if self.gather_output and self.tp_size > 1:
# All-gather across the partitions.
output = tensor_model_parallel_all_gather(output_parallel)
else:
......@@ -601,7 +532,7 @@ class ColumnParallelLinear(LinearBase):
s = f"in_features={self.input_size}"
s += f", output_features={self.output_size_per_partition}"
s += f", bias={self.bias is not None}"
s += f", tp_size={get_tensor_model_parallel_world_size()}"
s += f", tp_size={self.tp_size}"
s += f", gather_output={self.gather_output}"
return s
......@@ -628,6 +559,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
prefix: The name of the layer in the state dict, including all parents
(e.g. model.layers.0.qkv_proj)
return_bias: If true, return bias together with outputs in forward pass.
disable_tp: If true, all weights matrix won't be sharded, this layer
will be treated as a "Replicated" MergedLinear.
"""
def __init__(
......@@ -642,10 +575,13 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
prefix: str = "",
*,
return_bias: bool = True,
disable_tp: bool = False,
):
self.output_sizes = output_sizes
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
self.tp_size = (get_tensor_model_parallel_world_size()
if not disable_tp else 1)
self.tp_rank = (get_tensor_model_parallel_rank()
if not disable_tp else 0)
assert all(output_size % self.tp_size == 0
for output_size in output_sizes)
......@@ -657,7 +593,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
params_dtype=params_dtype,
quant_config=quant_config,
prefix=prefix,
return_bias=return_bias)
return_bias=return_bias,
disable_tp=disable_tp)
def weight_loader(self,
param: Parameter,
......@@ -722,8 +659,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
# If quantized, we need to adjust the offset and size to account
# for the packing.
if packed_dim == output_dim:
shard_size = shard_size // param.pack_factor
shard_offset = shard_offset // param.pack_factor
shard_size = shard_size // param.packed_factor
shard_offset = shard_offset // param.packed_factor
# Special case for Marlin.
shard_size, shard_offset = adjust_marlin_shard(
param, shard_size, shard_offset)
......@@ -756,8 +693,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
# for the packing.
packed_dim = getattr(param, "packed_dim", None)
if packed_dim == output_dim:
shard_size = shard_size // param.pack_factor
shard_offset = shard_offset // param.pack_factor
shard_size = shard_size // param.packed_factor
shard_offset = shard_offset // param.packed_factor
# Special case for Marlin.
shard_size, shard_offset = adjust_marlin_shard(
param, shard_size, shard_offset)
......@@ -849,8 +786,6 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
assert loaded_shard_id < len(self.output_sizes)
tp_size = get_tensor_model_parallel_world_size()
if isinstance(param, BlockQuantScaleParameter):
from vllm.model_executor.layers.quantization.fp8 import (
Fp8LinearMethod, Fp8MoEMethod)
......@@ -862,17 +797,19 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
block_n, _ = weight_block_size[0], weight_block_size[1]
shard_offset = (
(sum(self.output_sizes[:loaded_shard_id]) + block_n - 1) //
block_n) // tp_size
block_n) // self.tp_size
shard_size = ((self.output_sizes[loaded_shard_id] + block_n - 1) //
block_n // tp_size)
block_n // self.tp_size)
else:
shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
shard_size = self.output_sizes[loaded_shard_id] // tp_size
shard_offset = sum(
self.output_sizes[:loaded_shard_id]) // self.tp_size
shard_size = self.output_sizes[loaded_shard_id] // self.tp_size
param.load_merged_column_weight(loaded_weight=loaded_weight,
shard_id=loaded_shard_id,
shard_offset=shard_offset,
shard_size=shard_size)
shard_size=shard_size,
tp_rank=self.tp_rank)
class QKVParallelLinear(ColumnParallelLinear):
......@@ -900,6 +837,7 @@ class QKVParallelLinear(ColumnParallelLinear):
prefix: The name of the layer in the state dict, including all parents
(e.g. model.layers.0.qkv_proj)
return_bias: If true, return bias together with outputs in forward pass.
disable_tp: If true, weights matrix won't be sharded through tp rank.
"""
def __init__(
......@@ -915,6 +853,7 @@ class QKVParallelLinear(ColumnParallelLinear):
prefix: str = "",
*,
return_bias: bool = True,
disable_tp: bool = False,
):
self.hidden_size = hidden_size
self.head_size = head_size
......@@ -923,7 +862,8 @@ class QKVParallelLinear(ColumnParallelLinear):
total_num_kv_heads = total_num_heads
self.total_num_kv_heads = total_num_kv_heads
# Divide the weight matrix along the last dimension.
tp_size = get_tensor_model_parallel_world_size()
tp_size = (get_tensor_model_parallel_world_size()
if not disable_tp else 1)
self.num_heads = divide(self.total_num_heads, tp_size)
if tp_size >= self.total_num_kv_heads:
self.num_kv_heads = 1
......@@ -949,7 +889,8 @@ class QKVParallelLinear(ColumnParallelLinear):
params_dtype=params_dtype,
quant_config=quant_config,
prefix=prefix,
return_bias=return_bias)
return_bias=return_bias,
disable_tp=disable_tp)
def _get_shard_offset_mapping(self, loaded_shard_id: str):
shard_offset_mapping = {
......@@ -1010,10 +951,13 @@ class QKVParallelLinear(ColumnParallelLinear):
loaded_shard_id: Optional[str] = None):
if loaded_shard_id is None: # special case for certain models
if isinstance(param, PerTensorScaleParameter):
param.load_qkv_weight(loaded_weight=loaded_weight, shard_id=0)
param.load_qkv_weight(loaded_weight=loaded_weight,
shard_id=0,
tp_rank=self.tp_rank)
return
elif type(param) in (RowvLLMParameter, BasevLLMParameter):
param.load_qkv_weight(loaded_weight=loaded_weight)
param.load_qkv_weight(loaded_weight=loaded_weight,
tp_rank=self.tp_rank)
return
# TODO: @dsikka - move to parameter.py
self._load_fused_module_from_checkpoint(param, loaded_weight)
......@@ -1037,7 +981,8 @@ class QKVParallelLinear(ColumnParallelLinear):
num_heads=self.num_kv_head_replicas,
shard_id=loaded_shard_id,
shard_offset=shard_offset,
shard_size=shard_size)
shard_size=shard_size,
tp_rank=self.tp_rank)
def weight_loader(self,
param: Parameter,
......@@ -1107,8 +1052,8 @@ class QKVParallelLinear(ColumnParallelLinear):
# If quantized, we need to adjust the offset and size to account
# for the packing.
if packed_dim == output_dim:
shard_size = shard_size // param.pack_factor
shard_offset = shard_offset // param.pack_factor
shard_size = shard_size // param.packed_factor
shard_offset = shard_offset // param.packed_factor
# Special case for Marlin.
shard_size, shard_offset = adjust_marlin_shard(
......@@ -1155,8 +1100,8 @@ class QKVParallelLinear(ColumnParallelLinear):
# for the packing.
packed_dim = getattr(param, "packed_dim", None)
if packed_dim == output_dim:
shard_size = shard_size // param.pack_factor
shard_offset = shard_offset // param.pack_factor
shard_size = shard_size // param.packed_factor
shard_offset = shard_offset // param.packed_factor
# Special case for Marlin.
shard_size, shard_offset = adjust_marlin_shard(
......@@ -1243,6 +1188,7 @@ class RowParallelLinear(LinearBase):
prefix: The name of the layer in the state dict, including all parents
(e.g. model.layers.0.down_proj)
return_bias: If true, return bias together with outputs in forward pass.
disable_tp: If true, weights matrix won't be sharded through tp rank.
"""
def __init__(
......@@ -1258,10 +1204,13 @@ class RowParallelLinear(LinearBase):
prefix: str = "",
*,
return_bias: bool = True,
disable_tp: bool = False,
):
# Divide the weight matrix along the first dimension.
self.tp_rank = get_tensor_model_parallel_rank()
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = (get_tensor_model_parallel_rank()
if not disable_tp else 0)
self.tp_size = (get_tensor_model_parallel_world_size()
if not disable_tp else 1)
self.input_size_per_partition = divide(input_size, self.tp_size)
self.output_size_per_partition = output_size
self.output_partition_sizes = [output_size]
......@@ -1272,7 +1221,8 @@ class RowParallelLinear(LinearBase):
params_dtype,
quant_config,
prefix,
return_bias=return_bias)
return_bias=return_bias,
disable_tp=disable_tp)
self.input_is_parallel = input_is_parallel
self.reduce_results = reduce_results
......@@ -1301,6 +1251,7 @@ class RowParallelLinear(LinearBase):
})
else:
self.register_parameter("bias", None)
self.update_param_tp_status()
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
input_dim = getattr(param, "input_dim", None)
......@@ -1356,10 +1307,9 @@ class RowParallelLinear(LinearBase):
if self.input_is_parallel:
input_parallel = input_
else:
tp_rank = get_tensor_model_parallel_rank()
splitted_input = split_tensor_along_last_dim(
input_, num_partitions=self.tp_size)
input_parallel = splitted_input[tp_rank].contiguous()
input_parallel = splitted_input[self.tp_rank].contiguous()
# Matrix multiply.
assert self.quant_method is not None
......
......@@ -6,11 +6,11 @@ from concurrent.futures import ThreadPoolExecutor
from typing import Optional
import torch
import torch.nn as nn
import vllm.envs as envs
from vllm.distributed import (tensor_model_parallel_all_gather,
tensor_model_parallel_gather)
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.sampling_metadata import SamplingMetadata
......@@ -22,7 +22,8 @@ if envs.VLLM_LOGITS_PROCESSOR_THREADS is not None:
envs.VLLM_LOGITS_PROCESSOR_THREADS)
class LogitsProcessor(nn.Module):
@CustomOp.register("logits_processor")
class LogitsProcessor(CustomOp):
"""Process logits and apply logits processors from sampling metadata.
This layer does the following:
......
......@@ -83,17 +83,7 @@ class MiniMaxText01RMSNormTP(CustomOp):
variance = tensor_model_parallel_all_reduce(
variance) / self.tp_world
x = x * torch.rsqrt(variance + self.variance_epsilon)
weight = self.weight
if x.size(-1) != self.weight.size(0):
if self.weight.size(0) < x.size(-1):
repeat_count = (x.size(-1) + self.weight.size(0)) // x.size(-1)
full_weight = self.weight.repeat(repeat_count)
weight = full_weight[:x.size(-1)]
else:
weight = self.weight[:x.size(-1)]
x = x.to(orig_dtype) * weight
x = x.to(orig_dtype) * self.weight
return x
def forward(
......
......@@ -291,6 +291,7 @@ class MambaMixer2(MambaBase, CustomOp):
output_size=self.conv_dim,
bias=use_conv_bias,
quant_config=None,
prefix=f"{prefix}.conv1d",
)
# unsqueeze to fit conv1d weights shape into the linear weights shape.
# Can't do this in `weight_loader` since it already exists in
......@@ -303,6 +304,7 @@ class MambaMixer2(MambaBase, CustomOp):
output_size=intermediate_size + self.conv_dim + self.num_heads,
bias=use_bias,
quant_config=quant_config,
prefix=f"{prefix}.in_proj",
)
# - because in_proj is a concatenation of 3 weights, we
......@@ -402,6 +404,7 @@ class MambaMixer2(MambaBase, CustomOp):
bias=use_bias,
input_is_parallel=True,
quant_config=quant_config,
prefix=f"{prefix}.out_proj",
)
self.norm = Mixer2RMSNormGated(intermediate_size,
......
......@@ -30,12 +30,8 @@ class MambaStateDtypeCalculator:
mamba_cache_dtype: MambaDType,
mamba_ssm_cache_dtype: MambaDType,
) -> tuple[torch.dtype, ...]:
# TODO (tdoublep) requires kernel changes
if mamba_cache_dtype == "float32" or mamba_ssm_cache_dtype == "float32":
raise ValueError("fp32 state for mamba1 is not yet supported")
else:
return MambaStateDtypeCalculator.mamba2_state_dtype(
model_dtype, mamba_cache_dtype, mamba_ssm_cache_dtype)
return cls._mamba_state_dtype(model_dtype, mamba_cache_dtype,
mamba_ssm_cache_dtype)
@classmethod
def mamba2_state_dtype(
......@@ -43,6 +39,16 @@ class MambaStateDtypeCalculator:
model_dtype: Union[ModelDType, torch.dtype],
mamba_cache_dtype: MambaDType,
mamba_ssm_cache_dtype: MambaDType,
) -> tuple[torch.dtype, ...]:
return cls._mamba_state_dtype(model_dtype, mamba_cache_dtype,
mamba_ssm_cache_dtype)
@classmethod
def _mamba_state_dtype(
cls,
model_dtype: Union[ModelDType, torch.dtype],
mamba_cache_dtype: MambaDType,
mamba_ssm_cache_dtype: MambaDType,
) -> tuple[torch.dtype, ...]:
conv_state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype,
model_dtype)
......@@ -64,6 +70,15 @@ class MambaStateDtypeCalculator:
model_dtype)
return (conv_state_dtype, )
@classmethod
def gated_delta_net_state_dtype(
cls,
model_dtype: Union[ModelDType, torch.dtype],
mamba_cache_dtype: MambaDType,
) -> tuple[torch.dtype, torch.dtype]:
state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype, model_dtype)
return (state_dtype, state_dtype)
class MambaStateShapeCalculator:
......@@ -157,3 +172,31 @@ class MambaStateShapeCalculator:
# for n_groups == 1, this is exactly tp_size - n_groups
return tp_size - ngroups
@classmethod
def gated_delta_net_state_shape(
cls,
tp_world_size: int,
num_k_heads: int,
num_v_heads: int,
head_k_dim: int,
head_v_dim: int,
conv_kernel_size: int,
num_spec: int = 0,
use_v1: bool = True,
):
conv_dim = (head_k_dim * num_k_heads * 2 + head_v_dim * num_v_heads)
conv_state_shape = (
divide(conv_dim, tp_world_size),
conv_kernel_size - 1 + num_spec,
)
# In V0, the conv_state shape was swapped during allocation in
# MambaCacheManager, but in V1 it needs to be determined here at the
# calculation level
if use_v1:
conv_state_shape = conv_state_shape[1], conv_state_shape[0]
temporal_state_shape = (divide(num_v_heads,
tp_world_size), head_k_dim, head_v_dim)
return conv_state_shape, temporal_state_shape
......@@ -464,7 +464,9 @@ def causal_conv1d_fn(
# 3. mapping from sequence x[idx] to a cache line at index as specified via cache_indices[idx]
# 4. computation can be skipped if cache_indices[idx] == pad_slot_id
num_cache_lines = conv_states.size(0)
assert (num_cache_lines, dim, width - 1) == conv_states.shape
assert (num_cache_lines == conv_states.shape[0]
and dim == conv_states.shape[1]
and width - 1 <= conv_states.shape[2])
stride_istate_seq = conv_states.stride(0)
stride_istate_dim = conv_states.stride(1)
stride_istate_token = conv_states.stride(2)
......@@ -623,6 +625,7 @@ def _causal_conv1d_update_kernel(
conv_state_ptr,
cache_seqlens_ptr, # circular buffer
conv_state_indices_ptr,
num_accepted_tokens_ptr,
o_ptr, # (batch, dim, seqlen)
# Matrix dimensions
batch: int,
......@@ -639,6 +642,7 @@ def _causal_conv1d_update_kernel(
stride_conv_state_seq: tl.constexpr,
stride_conv_state_dim: tl.constexpr,
stride_conv_state_tok: tl.constexpr,
stride_state_indices: tl.constexpr,
stride_o_seq: tl.constexpr,
stride_o_dim: tl.constexpr,
stride_o_token: tl.constexpr,
......@@ -649,6 +653,7 @@ def _causal_conv1d_update_kernel(
KERNEL_WIDTH: tl.constexpr,
SILU_ACTIVATION: tl.constexpr,
IS_CONTINUOUS_BATCHING: tl.constexpr,
IS_SPEC_DECODING: tl.constexpr,
NP2_STATELEN: tl.constexpr,
USE_PAD_SLOT: tl.constexpr,
BLOCK_N: tl.constexpr,
......@@ -663,7 +668,8 @@ def _causal_conv1d_update_kernel(
if IS_CONTINUOUS_BATCHING:
# mask = idx_seq < batch
conv_state_batch_coord = tl.load(conv_state_indices_ptr + idx_seq).to(
conv_state_batch_coord = tl.load(conv_state_indices_ptr +
idx_seq * stride_state_indices).to(
tl.int64)
else:
conv_state_batch_coord = idx_seq
......@@ -672,13 +678,32 @@ def _causal_conv1d_update_kernel(
# not processing as this is not the actual sequence
return
if IS_SPEC_DECODING:
# The rolling of conv state:
#
# Before forward, the conv_state is:
# [history1, history2, ..., historyM].
#
# After forward, the conv_state becomes:
# [history2, ..., historyM, draft1, draft2, ..., draftN].
#
# After acceptance, it becomes:
#
# - accept 1 tokens: [history2, ..., historyM, draft1]
# - accept 2 tokens: [history3, ..., historyM, draft1, draft2]
# - and so on.
conv_state_token_offset = (tl.load(num_accepted_tokens_ptr + idx_seq) -
1)
else:
conv_state_token_offset = 0
# STEP 1: READ init_state data
conv_states_base = (conv_state_ptr +
(conv_state_batch_coord * stride_conv_state_seq) +
(idx_feats * stride_conv_state_dim))
mask_w = idx_feats < dim
prior_tokens = conv_states_base
prior_tokens = conv_states_base + conv_state_token_offset * stride_conv_state_tok
if KERNEL_WIDTH >= 2:
conv_states_ptrs = prior_tokens # [BLOCK_N]
col0 = tl.load(conv_states_ptrs, mask_w, 0.0)
......@@ -695,11 +720,15 @@ def _causal_conv1d_update_kernel(
# STEP 2: assume state_len > seqlen
idx_tokens = tl.arange(0, NP2_STATELEN) # [BLOCK_M]
# With speculative decoding, the conv_state updates works in a sliding
# window manner, at each forward pass, the tokens are shift by 1, so we
# load since idx_tokens + 1.
conv_state_ptrs_source = (
conv_state_ptr + (conv_state_batch_coord * stride_conv_state_seq) +
conv_state_token_offset * stride_conv_state_tok +
(idx_feats * stride_conv_state_dim)[None, :] +
((idx_tokens + seqlen) * stride_conv_state_tok)[:, None]
) # [BLOCK_M, BLOCK_N]
((idx_tokens + (1 if IS_SPEC_DECODING else seqlen)) *
stride_conv_state_tok)[:, None]) # [BLOCK_M, BLOCK_N]
mask = ((conv_state_batch_coord < num_cache_lines)
& ((idx_tokens + seqlen) < state_len)[:, None]
& (idx_feats < dim)[None, :])
......@@ -820,6 +849,7 @@ def causal_conv1d_update(
activation: Union[bool, str, None] = None,
cache_seqlens: Optional[torch.Tensor] = None,
conv_state_indices: Optional[torch.Tensor] = None,
num_accepted_tokens: Optional[torch.Tensor] = None,
pad_slot_id: int = PAD_SLOT_ID,
metadata=None,
validate_data=False,
......@@ -890,9 +920,13 @@ def causal_conv1d_update(
) # X (batch, dim, seqlen)
stride_o_seq, stride_o_dim, stride_o_token = out.stride()
stride_istate_seq, stride_istate_dim, stride_istate_token = conv_state.stride(
)
stride_state_indices = conv_state_indices.stride(
0) if conv_state_indices is not None else 0
if num_accepted_tokens is not None:
state_len = width - 1 + (seqlen - 1) # effective state_len needed
else:
state_len = width - 1
np2_statelen = triton.next_power_of_2(state_len)
......@@ -910,6 +944,7 @@ def causal_conv1d_update(
conv_state,
cache_seqlens,
conv_state_indices,
num_accepted_tokens,
out,
# Matrix dimensions
batch,
......@@ -926,6 +961,7 @@ def causal_conv1d_update(
stride_istate_seq,
stride_istate_dim,
stride_istate_token,
stride_state_indices,
stride_o_seq,
stride_o_dim,
stride_o_token,
......@@ -936,6 +972,7 @@ def causal_conv1d_update(
KERNEL_WIDTH=width,
SILU_ACTIVATION=activation in ["silu", "swish"],
IS_CONTINUOUS_BATCHING=conv_state_indices is not None,
IS_SPEC_DECODING=num_accepted_tokens is not None,
NP2_STATELEN=np2_statelen,
USE_PAD_SLOT=pad_slot_id is not None,
BLOCK_N=256,
......
......@@ -289,6 +289,9 @@ def _chunk_scan_fwd_kernel(
# get the cs at the offset boundary
# - c_off == 0 is a passthrough
# - We need dA_cs at the boundary, defined by c_off - no need
# to increase pointer by pid_m (it is a constant offset,
# i.e. the same for all blocks)
dA_cs_m_boundary = tl.load(
dA_cumsum_ptr + (c_off - 1) * stride_dA_cs_csize,
mask=(((c_off - 1) > -1) and ((c_off) < chunk_size)),
......
......@@ -502,7 +502,7 @@ def _chunk_state_varlen_kernel(
dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize
# If the sequence starts after the last chunk idx, we don't need to add the contribution from the last chunk
# If HAS_INITSTATES==True need to consider two possiblties
# If HAS_INITSTATES==True need to consider two possibilities
# - if start_idx < pid_c * chunk_size, then we need to take the past_states_ptrs
# - if state_idx >= pid * chunk_size, then we need to insert initstates
if ((start_idx < pid_c * chunk_size) # first chunk
......
......@@ -106,21 +106,24 @@ def _mamba_chunk_scan_combined_fwd(x,
# 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries
# (middle term of factorization of off-diag blocks; A terms)
# - for handling chunked prefill, this requires i) initial_states
# ii) seq_idx and iii) is_cont_batched to be all specified.
# ii) seq_idx iii) is_cont_batched and (iv) chunk_offsets to be all specified.
# - When a new seq_idx is detected, we will stop passing the prev_state
# and switch accordingly to the init_state corresponding to the new seq_idx.
# - We will also make sure that the dA_cumsum is taken only from the start of the
# sequence (hence we need the full dA_cumsum tensor and not just the values at chunk boundaries)
# - this will ensure that states will be updated with the rightmost flushed seq_idx
# of the previous chunk. This implies that the first chunk of states is either 0
# or equal to init_states of the first example.
states, final_states = _state_passing_fwd(
rearrange(states, "... p n -> ... (p n)"),
dA_cumsum[:, :, :, -1],
dA_cumsum,
initial_states=rearrange(initial_states, "... p n -> ... (p n)")
if initial_states is not None else None,
seq_idx=seq_idx,
chunk_size=chunk_size,
out_dtype=state_dtype if state_dtype is not None else C.dtype,
is_cont_batched=cu_seqlens is not None)
is_cont_batched=cu_seqlens is not None,
chunk_offsets=chunk_offsets)
states, final_states = (rearrange(t, "... (p n) -> ... p n", n=dstate)
for t in [states, final_states])
......
......@@ -31,6 +31,8 @@ def _state_passing_fwd_kernel(
dA_cs_ptr,
initstates_ptr,
seq_idx_ptr,
chunk_offsets_ptr,
chunk_meta_num,
# Matrix dimensions
dim,
nchunks,
......@@ -51,6 +53,7 @@ def _state_passing_fwd_kernel(
stride_dA_cs_batch,
stride_dA_cs_chunk,
stride_dA_cs_head,
stride_dA_cs_csize,
stride_initstates_batch,
stride_initstates_head,
stride_initstates_dim,
......@@ -66,7 +69,8 @@ def _state_passing_fwd_kernel(
pid_h = tl.program_id(axis=2)
pid_m = tl.program_id(axis=0)
states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head
dA_cs_ptr += pid_b * stride_dA_cs_batch + pid_h * stride_dA_cs_head
dA_cs_ptr += pid_b * stride_dA_cs_batch + pid_h * stride_dA_cs_head + (
chunk_size - 1) * stride_dA_cs_csize
out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head
final_states_ptr += pid_b * stride_final_states_batch + pid_h * stride_final_states_head
if HAS_INITSTATES:
......@@ -95,35 +99,62 @@ def _state_passing_fwd_kernel(
tl.store(out_ptrs, states, mask=offs_m < dim)
out_ptrs += stride_out_chunk
seq_idx = 0
prev_seq_idx_chunk_end = 0
logical_chunk_idx = 0
for c in range(nchunks):
new_states = tl.load(states_ptrs, mask=offs_m < dim,
other=0.0).to(tl.float32)
dA_cs = tl.load(dA_cs_ptr).to(tl.float32)
scale = tl.exp(dA_cs)
scale_mask = True
if HAS_SEQ_IDX:
# - the seq to pass forward is the one that is flushed to the right
# boundary.
# - that is given by seq_idx_new below.
seq_idx_new = tl.load(seq_idx_ptr +
(min((c + 1) * chunk_size, seqlen) - 1) *
stride_seq_idx_seqlen)
# - that is given by seq_idx_chunk_end below: the sequence index at the end of the chunk.
seq_idx_chunk_end = tl.load(seq_idx_ptr + (min(
(c + 1) * chunk_size, seqlen) - 1) * stride_seq_idx_seqlen)
if HAS_INITSTATES:
if IS_CONT_BATCHED and seq_idx != seq_idx_new:
if IS_CONT_BATCHED and prev_seq_idx_chunk_end != seq_idx_chunk_end:
# this means in the current chunk the rightmost flushed seq
# has changed.
# - so we do not propagate the state from previous chunk
# - but rather we load that sequence's init state
initstates_ptrs = initstates_ptr + seq_idx_new * stride_initstates_batch
initstates_ptrs = initstates_ptr + seq_idx_chunk_end * stride_initstates_batch
# - update state with seq_idx_new's init state
states = tl.load(initstates_ptrs,
mask=offs_m < dim,
other=0.0).to(tl.float32)
# - we need to consider the cumsum only of the last sequence in the chunk
# - find its starting position (given by c_off of the logical chunk index)
# - and subtract the cumsum just before that position from the total cumsum
# - first, update the logical chunk index (add the number of sequences in the current physical chunk):
# sequence index at the start of the current chunk
seq_idx_chunk_start = tl.load(seq_idx_ptr +
min(c * chunk_size, seqlen) *
stride_seq_idx_seqlen)
logical_chunk_idx += seq_idx_chunk_end - seq_idx_chunk_start
# - load the chunk offset:
c_off = tl.load(chunk_offsets_ptr + logical_chunk_idx,
mask=logical_chunk_idx < chunk_meta_num,
other=0)
# - if offset is 0, then the sequence starts at the beginning of the chunk, and we don't need to subtract anything
if c_off > 0:
# - dA_cs_ptr currently points to the cumsum at the end of the chunk - subtract the chunk size and add the offset
dA_cs_boundary = tl.load(
dA_cs_ptr - (chunk_size - 1) * stride_dA_cs_csize +
(c_off - 1) * stride_dA_cs_csize,
mask=(c_off - 1) > -1 and c_off < chunk_size,
other=0.0)
dA_cs -= dA_cs_boundary
# - increment logical chunk index for every physical chunk
logical_chunk_idx += 1
else:
scale = tl.where(seq_idx_new == seq_idx, scale, 0.0)
scale_mask = seq_idx_chunk_end == prev_seq_idx_chunk_end
prev_seq_idx_chunk_end = seq_idx_chunk_end
seq_idx = seq_idx_new
scale = tl.where(scale_mask, tl.exp(dA_cs), 0.0)
states = scale * states + new_states
if c < nchunks - 1:
tl.store(out_ptrs, states, mask=offs_m < dim)
......@@ -136,28 +167,36 @@ def _state_passing_fwd_kernel(
def _state_passing_fwd(
states,
dA_chunk_cumsum,
dA_cumsum,
initial_states=None,
seq_idx=None,
chunk_size=None,
out_dtype=None,
is_cont_batched=False,
chunk_offsets=None,
):
batch, nchunks, nheads, dim = states.shape
assert dA_chunk_cumsum.shape == (batch, nheads, nchunks)
if chunk_size is None:
chunk_size = dA_cumsum.shape[-1]
else:
assert chunk_size == dA_cumsum.shape[-1]
assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)
if initial_states is not None:
if is_cont_batched:
# - if cu_seqlens is provided, then the initial states
# are used for continuous batching. In which case we
# require seq_idx to be provided
assert seq_idx is not None, ""
assert seq_idx is not None, "seq_idx must be provided for continuous batching"
# - we also need chunk_offsets to be provided, to account
# for computation of dA_cumsum from the start of the
# sequence
assert chunk_offsets is not None, "chunk_offsets must be provided for continuous batching"
else:
# - this is the regular batching case, where initial
# states are used are for each example of the batch.
assert initial_states.shape == (batch, nheads, dim)
if seq_idx is not None:
assert chunk_size is not None
seqlen = seq_idx.shape[-1]
assert seq_idx.shape == (batch, seqlen)
out_dtype = states.dtype if out_dtype is None else out_dtype
......@@ -173,13 +212,15 @@ def _state_passing_fwd(
states,
out,
final_states,
dA_chunk_cumsum,
dA_cumsum,
initial_states,
seq_idx,
chunk_offsets,
len(chunk_offsets) if chunk_offsets is not None else 0,
dim,
nchunks,
seqlen if seq_idx is not None else 0,
chunk_size if seq_idx is not None else 0,
chunk_size,
states.stride(0),
states.stride(1),
states.stride(2),
......@@ -191,9 +232,10 @@ def _state_passing_fwd(
final_states.stride(0),
final_states.stride(1),
final_states.stride(2),
dA_chunk_cumsum.stride(0),
dA_chunk_cumsum.stride(2),
dA_chunk_cumsum.stride(1),
dA_cumsum.stride(0),
dA_cumsum.stride(2),
dA_cumsum.stride(1),
dA_cumsum.stride(3),
*((initial_states.stride(0), initial_states.stride(1),
initial_states.stride(2)) if initial_states is not None else
(0, 0, 0)),
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import Optional
import torch
from vllm.attention import Attention
from vllm.config import CacheConfig
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.quantization import QuantizationConfig
@dataclass
class MLAModules:
"""Modules used in MLA.
"""
kv_a_layernorm: torch.nn.Module
kv_b_proj: torch.nn.Module
rotary_emb: torch.nn.Module
o_proj: torch.nn.Module
fused_qkv_a_proj: Optional[torch.nn.Module]
kv_a_proj_with_mqa: Optional[torch.nn.Module]
q_a_layernorm: Optional[torch.nn.Module]
q_b_proj: Optional[torch.nn.Module]
q_proj: Optional[torch.nn.Module]
@CustomOp.register("multi_head_latent_attention")
class MultiHeadLatentAttention(CustomOp):
"""MLA layer registered as CustomOp.
Note that currently MLA ignores the enable/disable mechanism of CustomOp
because there is only one in-tree implementation in forward_native.
TODO: implement this with a new PluggableLayer mechanism.
This class takes positions and hidden_states as input.
The input tensors can either contain prefill tokens or decode tokens.
The class does the following:
1. MLA Preprocess.
2. Perform multi-head attention to prefill tokens and
multi-query attention to decode tokens separately.
3. Return the output tensor.
"""
def __init__(
self,
hidden_size: int,
num_heads: int,
scale: float,
qk_nope_head_dim: int,
qk_rope_head_dim: int,
v_head_dim: int,
q_lora_rank: Optional[int],
kv_lora_rank: int,
mla_modules: MLAModules,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = hidden_size
self.qk_nope_head_dim = qk_nope_head_dim
self.qk_rope_head_dim = qk_rope_head_dim
self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
self.v_head_dim = v_head_dim
self.q_lora_rank = q_lora_rank
self.kv_lora_rank = kv_lora_rank
self.num_heads = num_heads
self.fused_qkv_a_proj = mla_modules.fused_qkv_a_proj
self.kv_a_proj_with_mqa = mla_modules.kv_a_proj_with_mqa
self.q_a_layernorm = mla_modules.q_a_layernorm
self.q_b_proj = mla_modules.q_b_proj
self.q_proj = mla_modules.q_proj
self.kv_a_layernorm = mla_modules.kv_a_layernorm
self.kv_b_proj = mla_modules.kv_b_proj
self.rotary_emb = mla_modules.rotary_emb
self.o_proj = mla_modules.o_proj
# In the MLA backend, kv_cache includes both k_c and
# pe (i.e. decoupled position embeddings). In particular,
# the concat_and_cache_mla op requires
# k_c.size(1) + k_pe.size(1) == kv_cache.size(2)
# i.e.
# kv_lora_rank + qk_rope_head_dim == head_size
self.mla_attn = Attention(
num_heads=self.num_heads,
head_size=self.kv_lora_rank + self.qk_rope_head_dim,
scale=scale,
num_kv_heads=1,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn",
use_mla=True,
# MLA Args
q_lora_rank=self.q_lora_rank,
kv_lora_rank=self.kv_lora_rank,
qk_nope_head_dim=self.qk_nope_head_dim,
qk_rope_head_dim=self.qk_rope_head_dim,
qk_head_dim=self.qk_head_dim,
v_head_dim=self.v_head_dim,
kv_b_proj=self.kv_b_proj,
)
self.prefix = prefix
self.debug_layer_idx = int(self.prefix.split(".")[-2])
def forward_native(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
) -> torch.Tensor:
q_c = None
kv_lora = None
if self.q_lora_rank is not None:
assert self.fused_qkv_a_proj is not None, \
"fused_qkv_a_proj is required when q_lora_rank is not None"
assert self.q_a_layernorm is not None, \
"q_a_layernorm is required when q_lora_rank is not None"
assert self.q_b_proj is not None, \
"q_b_proj is required when q_lora_rank is not None"
qkv_lora = self.fused_qkv_a_proj(hidden_states)[0]
q_c, kv_lora = qkv_lora.split(
[self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
dim=-1,
)
q_c = self.q_a_layernorm(q_c)
q = self.q_b_proj(q_c)[0]
else:
assert self.kv_a_proj_with_mqa is not None, \
"kv_a_proj_with_mqa is required when q_lora_rank is None"
assert self.q_proj is not None, \
"q_proj is required when q_lora_rank is None"
kv_lora = self.kv_a_proj_with_mqa(hidden_states)[0]
q = self.q_proj(hidden_states)[0]
kv_c, k_pe = kv_lora.split([self.kv_lora_rank, self.qk_rope_head_dim],
dim=-1)
kv_c_normed = self.kv_a_layernorm(kv_c)
q = q.view(-1, self.num_heads, self.qk_head_dim)
# Add head dim of 1 to k_pe
k_pe = k_pe.unsqueeze(1)
q[..., self.qk_nope_head_dim:], k_pe = self.rotary_emb(
positions, q[..., self.qk_nope_head_dim:], k_pe)
attn_out = self.mla_attn(
q,
kv_c_normed,
k_pe,
output_shape=(hidden_states.shape[0],
self.num_heads * self.v_head_dim))
return self.o_proj(attn_out)[0]
def forward_cuda(self, *args, **kwargs):
return self.forward_native(*args, **kwargs)
......@@ -5,7 +5,7 @@ from collections.abc import Mapping, Set
from dataclasses import dataclass
from enum import IntEnum
from itertools import groupby
from typing import Callable, Optional, TypeVar, Union, cast
from typing import Callable, Optional, TypeVar, Union
import torch
import torch.nn as nn
......@@ -362,14 +362,13 @@ class PoolerIdentity(PoolerActivation):
class PoolerNormalize(PoolerActivation):
def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
x = F.normalize(pooled_data.float(), p=2, dim=-1)
return x.to(pooled_data.dtype)
return F.normalize(pooled_data, p=2, dim=-1)
class PoolerMultiLabelClassify(PoolerActivation):
def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
return F.sigmoid(pooled_data.float()).to(pooled_data.dtype)
return F.sigmoid(pooled_data)
class PoolerClassify(PoolerActivation):
......@@ -394,9 +393,9 @@ class PoolerClassify(PoolerActivation):
pooled_data.shape[-1])
if num_labels < 2:
return F.sigmoid(pooled_data.float()).to(pooled_data.dtype)
return F.sigmoid(pooled_data)
return F.softmax(pooled_data.float(), dim=-1).to(pooled_data.dtype)
return F.softmax(pooled_data, dim=-1)
class LambdaPoolerActivation(PoolerActivation):
......@@ -432,8 +431,9 @@ class EmbeddingPoolerHead(PoolerHead):
from vllm.model_executor.models.adapters import _load_st_projector
vllm_config = get_current_vllm_config()
self.projector = _load_st_projector(
self.projector: Optional[nn.Module] = _load_st_projector(
vllm_config.model_config) if vllm_config else None
self.head_dtype = vllm_config.model_config.head_dtype
def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor],
pooling_metadata: PoolingMetadata):
......@@ -442,16 +442,11 @@ class EmbeddingPoolerHead(PoolerHead):
pooled_data = torch.stack(pooled_data)
# pooled_data shape: [batchsize, hidden_dimension]
pooled_data = pooled_data.to(self.head_dtype)
# Apply ST projector
if self.projector is not None:
projector = cast(nn.Module, self.projector)
def _proj(x: torch.Tensor) -> torch.Tensor:
orig_dtype = x.dtype
y = projector(x.to(torch.float32))
return y.to(orig_dtype)
pooled_data = _proj(pooled_data)
pooled_data = self.projector(pooled_data)
# pooled_data shape: [batchsize, embedding_dimension]
pooling_params = get_pooling_params(pooling_metadata)
......@@ -494,8 +489,18 @@ class RewardPoolerHead(PoolerHead):
def __init__(self) -> None:
super().__init__(activation=PoolerClassify(static_num_labels=False))
from vllm.config import get_current_vllm_config
vllm_config = get_current_vllm_config()
self.head_dtype = vllm_config.model_config.head_dtype
def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor],
pooling_metadata: PoolingMetadata):
if isinstance(pooled_data, list):
pooled_data = [p.to(self.head_dtype) for p in pooled_data]
else:
pooled_data = pooled_data.to(self.head_dtype)
pooling_params = get_pooling_params(pooling_metadata)
# for softmax
......@@ -633,9 +638,15 @@ class ClassifierPooler(Pooler):
) -> None:
super().__init__()
from vllm.config import get_current_vllm_config
vllm_config = get_current_vllm_config()
self.pooling = pooling
self.classifier = classifier
self.act_fn = act_fn or PoolerClassify()
self.logit_bias: Optional[
float] = vllm_config.model_config.pooler_config.logit_bias
self.head_dtype = vllm_config.model_config.head_dtype
def get_supported_tasks(self) -> Set[PoolingTask]:
return {"classify", "score"}
......@@ -650,10 +661,15 @@ class ClassifierPooler(Pooler):
pooled_data = torch.stack(pooled_data)
# pooled_data shape: [batchsize, hidden_size]
pooled_data = pooled_data.to(self.head_dtype)
if self.classifier is not None:
pooled_data = self.classifier(pooled_data)
# pooled_data shape: [batchsize, num_labels]
if self.logit_bias is not None:
pooled_data -= self.logit_bias
pooling_params = get_pooling_params(pooling_metadata)
flags = [p.activation for p in pooling_params]
......
......@@ -26,7 +26,6 @@ QuantizationMethods = Literal[
"bitsandbytes",
"hqq",
"experts_int8",
"neuron_quant",
"ipex",
"quark",
"moe_wna16",
......@@ -108,7 +107,6 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
from .modelopt import ModelOptFp8Config, ModelOptNvFp4Config
from .moe_wna16 import MoeWNA16Config
from .mxfp4 import Mxfp4Config
from .neuron_quant import NeuronQuantConfig
from .petit import PetitNvFp4Config
from .ptpc_fp8 import PTPCFp8Config
from .rtn import RTNConfig
......@@ -135,7 +133,6 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
"ptpc_fp8": PTPCFp8Config,
"hqq": HQQMarlinConfig,
"experts_int8": ExpertsInt8Config,
"neuron_quant": NeuronQuantConfig,
"ipex": IPEXConfig,
"quark": QuarkConfig,
"moe_wna16": MoeWNA16Config,
......
......@@ -327,6 +327,8 @@ class AutoRoundConfig(QuantizationConfig):
if isinstance(layer, FusedMoE):
if use_marlin:
return GPTQMarlinMoEMethod(quant_args_marlin, layer.moe)
else:
from vllm.model_executor.layers.quantization.moe_wna16 import (
MoeWNA16Config)
......@@ -339,7 +341,6 @@ class AutoRoundConfig(QuantizationConfig):
}
return MoeWNA16Config.from_config(config).get_quant_method(
layer, prefix)
return GPTQMarlinMoEMethod(quant_args_marlin, layer.moe)
if isinstance(layer, (LinearBase, ParallelLMHead)):
if use_marlin:
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Callable, Optional
from typing import Any, Callable, Optional, Union
import torch
from torch.nn import Parameter
......@@ -505,7 +505,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor:
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
assert self.fused_experts is None
if enable_eplb:
......
......@@ -19,7 +19,7 @@ def awq_dequantize_kernel(
num_rows, # input num rows in qweight
BLOCK_SIZE_X: tl.constexpr,
BLOCK_SIZE_Y: tl.constexpr):
# Setup the pids.
# Set up the pids.
pid_x = tl.program_id(axis=0)
pid_y = tl.program_id(axis=1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment