Commit 539aa992 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.6.2' into v0.6.2-dev

parents 93872128 7193774b
......@@ -27,7 +27,7 @@ if TYPE_CHECKING:
def compute_meta(
token_lora_tensor: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, bool]:
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int, bool]:
"""
Get the information required for the sgmv kernel. With the features:
1. If consecutive requests in the batch use the same LoRA, this function
......@@ -43,7 +43,7 @@ def compute_meta(
b_seq_start_tensor = torch.zeros_like(seq_length_tensor)
b_seq_start_tensor[1:].copy_(cum_result[:-1])
max_length = seq_length_tensor.max().item()
token_nums = seq_length_tensor.sum().item()
batch_size = lora_indices_tensor.size(0)
no_lora = False
# -1 means no lora should be applied. Use `no_lora` to determine whether
......@@ -52,7 +52,7 @@ def compute_meta(
if batch_size == 1 and lora_indices_tensor == -1:
no_lora = True
return (b_seq_start_tensor, seq_length_tensor, lora_indices_tensor,
batch_size, max_length, no_lora)
batch_size, max_length, token_nums, no_lora)
# TODO see if this can be vectorized
......@@ -216,6 +216,7 @@ class PunicaWrapper:
dtype=torch.long,
device=device)
self.max_length: int = 0
self.token_nums: int = 0
self.batch_size: int = -1
self.is_prefill = False
self.no_lora = False
......@@ -276,13 +277,13 @@ class PunicaWrapper:
long_lora_offsets_tensor)
else:
self._long_lora_indices.zero_()
self.indices_len[:] = indices_len
def _update_prefill_metada(self, token_lora_tensor: torch.Tensor) -> None:
(b_seq_start_tensor, seq_length_tensor, lora_indices_tensor,
batch_size, max_length, no_lora) = compute_meta(token_lora_tensor)
batch_size, max_length, token_nums,
no_lora) = compute_meta(token_lora_tensor)
self._seq_start_locs[:b_seq_start_tensor.shape[0]].copy_(
b_seq_start_tensor)
......@@ -291,25 +292,28 @@ class PunicaWrapper:
lora_indices_tensor)
self.batch_size = batch_size
self.max_length = max_length
self.token_nums = token_nums
self.no_lora = no_lora
@property
def prefill_metadata(
self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int]:
self
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int]:
"""
This property provides a convenient way to access the necessary
metadata for prefill-related kernel computations.
1. seq_start_locs: Tensor of sequence start positions
2. seq_lengths: Tensor of sequence lengths
1. seq_start_locs: Tensor of sequence start positions.
2. seq_lengths: Tensor of sequence lengths.
3. lora_indices_per_batch: Tensor of lora indices, and an index of
-1 means no lora should be applied.
4. batch_size: batch size after clustering identical lora indices
5. max_length: The maximum sequence length in the batch
4. batch_size: Batch size after clustering identical lora indices.
5. max_length: The maximum sequence length in the batch.
6. token_nums: The token numbers in the batch.
"""
return (self._seq_start_locs[:self.batch_size],
self._seq_lengths[:self.batch_size],
self._lora_indices_per_batch[:self.batch_size],
self.batch_size, self.max_length)
self.batch_size, self.max_length, self.token_nums)
@property
def token_lora_indices(self) -> torch.Tensor:
......@@ -324,7 +328,7 @@ class PunicaWrapper:
def sampler_indices(self) -> torch.Tensor:
"""
This property is used to access the lora indices specifically for
LogitsProcessorWithLoRA
LogitsProcessorWithLoRA.
"""
sampler_indices_len = self.indices_len[1]
return self._sampler_indices[:sampler_indices_len]
......@@ -332,7 +336,7 @@ class PunicaWrapper:
@property
def sampler_indices_padded(self) -> torch.Tensor:
"""
This property provides access to padded sampler indices
This property provides access to padded sampler indices.
"""
indices_padded_len = self.indices_len[2]
return self._sampler_indices_padded[:indices_padded_len]
......@@ -341,7 +345,7 @@ class PunicaWrapper:
def embeddings_indices(self) -> torch.Tensor:
"""
This property provides access to the indices used for lora embeddings,
specifically for VocabParallelEmbeddingWithLoRA
specifically for VocabParallelEmbeddingWithLoRA.
"""
embeddings_indices_len = self.indices_len[3]
return self._embeddings_indices[:, :embeddings_indices_len]
......@@ -350,7 +354,7 @@ class PunicaWrapper:
def long_lora_indices(self) -> torch.Tensor:
"""
This property provides access to the indices used for long context
lora, specifically for LinearScalingRotaryEmbeddingWithLora
lora, specifically for LinearScalingRotaryEmbeddingWithLora.
"""
long_lora_len = self.indices_len[4]
return self._long_lora_indices[:long_lora_len]
......@@ -524,7 +528,7 @@ class PunicaWrapper:
scale (float): Scaling factor.
y_offset (Optional[int], optional): Offset to apply to the starting
column of y.
y_slice_size (Optional[int], optional): Size of the y column slice..
y_slice_size (Optional[int], optional): Size of the y column slice.
buffer (Optional[torch.Tensor], optional): Defaults to None.
"""
y_org = y
......
......@@ -28,6 +28,7 @@ class LoRARequest(
lora_path: str = ""
lora_local_path: Optional[str] = msgspec.field(default=None)
long_lora_max_len: Optional[int] = None
base_model_name: Optional[str] = msgspec.field(default=None)
def __post_init__(self):
if 'lora_local_path' in self.__struct_fields__:
......
import torch.nn as nn
import vllm.envs as envs
from vllm.platforms import current_platform
from vllm.utils import is_cpu, is_hip, is_xpu
......@@ -53,6 +54,10 @@ class CustomOp(nn.Module):
def dispatch_forward(self):
# NOTE(woosuk): Here we assume that vLLM was built for only one
# specific backend. Currently, we do not support dynamic dispatching.
if envs.VLLM_TEST_COMPILE_NO_CUSTOM_OPS:
return self.forward_native
if is_hip():
return self.forward_hip
elif is_cpu():
......
......@@ -67,9 +67,9 @@ class BaseLogitsProcessor:
instruction = self._guide.get_next_instruction(
state=self._fsm_state[seq_id])
if type(instruction) == Generate:
if type(instruction) == Generate: # noqa: E721
allowed_tokens = instruction.tokens
elif type(instruction) == Write:
elif type(instruction) == Write: # noqa: E721
# TODO: support fast forward tokens
allowed_tokens = [instruction.tokens[0]]
else:
......
......@@ -124,9 +124,7 @@ class NewGELU(CustomOp):
def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
from vllm._ipex_ops import ipex_ops as ops
out = torch.empty_like(x)
ops.gelu_new(out, x)
return out
return ops.gelu_new(x)
class FastGELU(CustomOp):
......@@ -146,9 +144,7 @@ class FastGELU(CustomOp):
def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
from vllm._ipex_ops import ipex_ops as ops
out = torch.empty_like(x)
ops.gelu_fast(out, x)
return out
return ops.gelu_fast(x)
class QuickGELU(CustomOp):
......@@ -165,6 +161,13 @@ class QuickGELU(CustomOp):
ops.gelu_quick(out, x)
return out
def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
from vllm._ipex_ops import ipex_ops as ops
out = torch.empty_like(x)
ops.gelu_quick(out, x)
return out
# TODO implement forward_xpu for QuickGELU
# def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
......
......@@ -7,6 +7,7 @@ import torch
from vllm import _custom_ops as ops
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_topk, moe_align_block_size, try_get_optimal_moe_config)
from vllm.scalar_type import scalar_types
def single_marlin_moe(
......@@ -18,7 +19,9 @@ def single_marlin_moe(
perm: torch.Tensor,
topk: int,
renormalize: bool,
override_config: Optional[Dict[str, Any]] = None) -> torch.Tensor:
override_config: Optional[Dict[str, Any]] = None,
num_bits: int = 8,
) -> torch.Tensor:
"""
This function computes the multiplication of hidden_states with expert
weights used in Marlin MoE, using weights w and top-k gating mechanism.
......@@ -36,6 +39,7 @@ def single_marlin_moe(
- renormalize (bool): If True, renormalize the top-k weights to sum to 1.
- override_config (Optional[Dict[str, Any]]): Optional override
for the kernel configuration.
- num_bits (bool): The number of bits in expert weights quantization.
Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
......@@ -48,10 +52,11 @@ def single_marlin_moe(
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
assert w.is_contiguous(), "Expert weights must be contiguous"
assert hidden_states.dtype == torch.float16
assert num_bits in [4, 8]
M, K = hidden_states.shape
E = w.shape[0]
N = w.shape[2] // 2
N = w.shape[2] // (num_bits // 2)
topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk,
renormalize)
......@@ -76,10 +81,13 @@ def single_marlin_moe(
device="cuda",
requires_grad=False)
scalar_type = (scalar_types.uint4b8
if num_bits == 4 else scalar_types.uint8b128)
intermediate_cache = torch.ops._moe_C.marlin_gemm_moe(
hidden_states, w, sorted_token_ids, topk_weights, topk_ids, scales,
g_idx, perm, workspace, M, N, K, True, E, topk, block_size_m, True,
False)
g_idx, perm, workspace, scalar_type, M, N, K, True, E, topk,
block_size_m, True, False)
return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1)
......@@ -98,6 +106,7 @@ def fused_marlin_moe(
override_config: Optional[Dict[str, Any]] = None,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
num_bits: int = 8,
) -> torch.Tensor:
"""
This function computes a Mixture of Experts (MoE) layer using two sets of
......@@ -122,6 +131,7 @@ def fused_marlin_moe(
w1.
- w2_scale (Optional[torch.Tensor]): Optional scale to be used for
w2.
- num_bits (bool): The number of bits in expert weights quantization.
Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
......@@ -131,13 +141,14 @@ def fused_marlin_moe(
0], "Number of tokens mismatch"
assert hidden_states.shape[
1] == w1.shape[1] * 16, "Hidden size mismatch w1"
assert hidden_states.shape[
1] == w2.shape[2] // 2, "Hidden size mismatch w2"
assert hidden_states.shape[1] == w2.shape[2] // (
num_bits // 2), "Hidden size mismatch w2"
assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
assert w1.is_contiguous(), "Expert weights1 must be contiguous"
assert w2.is_contiguous(), "Expert weights2 must be contiguous"
assert hidden_states.dtype == torch.float16
assert num_bits in [4, 8]
M, K = hidden_states.shape
E = w1.shape[0]
......@@ -165,6 +176,9 @@ def fused_marlin_moe(
device="cuda",
requires_grad=False)
scalar_type = (scalar_types.uint4b8
if num_bits == 4 else scalar_types.uint8b128)
intermediate_cache2 = torch.empty(
(M * topk_ids.shape[1], N),
device=hidden_states.device,
......@@ -181,6 +195,7 @@ def fused_marlin_moe(
g_idx1,
perm1,
workspace,
scalar_type,
M,
2 * N,
K,
......@@ -204,6 +219,7 @@ def fused_marlin_moe(
g_idx2,
perm2,
workspace,
scalar_type,
M,
K,
N,
......
......@@ -445,7 +445,7 @@ def grouped_topk(hidden_states: torch.Tensor,
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
return topk_weights, topk_ids.to(torch.int32)
return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
def get_config_dtype_str(dtype: torch.dtype,
......
......@@ -323,10 +323,12 @@ class FusedMoE(torch.nn.Module):
loaded_weight: torch.Tensor, weight_name: str,
shard_id: str, expert_id: int) -> None:
# compressed-tensors represents weights on disk which are flipped
# compressed-tensors checkpoints with packed weights are stored flipped
# TODO (mgoin): check self.quant_method.quant_config.quant_format
# against known CompressionFormat enum values that have this quality
loaded_weight = loaded_weight.t().contiguous() if (
self.quant_method.__class__.__name__
== "CompressedTensorsMoEMethod") else loaded_weight
== "CompressedTensorsWNA16MoEMethod") else loaded_weight
if shard_id not in ("w1", "w2", "w3"):
raise ValueError(f"shard_id must be ['w1','w2','w3'] but "
......@@ -353,6 +355,9 @@ class FusedMoE(torch.nn.Module):
# Case input scale: input_scale loading is only supported for fp8
if "input_scale" in weight_name:
# this is needed for compressed-tensors only
loaded_weight = loaded_weight.to(param.data.device)
if param.data[expert_id] != 1 and (param.data[expert_id] -
loaded_weight).abs() > 1e-5:
raise ValueError(
......
......@@ -99,14 +99,11 @@ class RMSNorm(CustomOp):
self.variance_epsilon,
)
return x, residual
out = torch.empty_like(x)
ops.rms_norm(
out,
return ops.rms_norm(
x,
self.weight.data,
self.variance_epsilon,
)
return out
def extra_repr(self) -> str:
s = f"hidden_size={self.weight.data.size(0)}"
......
......@@ -549,6 +549,9 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
param_data = param_data.narrow(output_dim, shard_offset,
shard_size)
start_idx = tp_rank * shard_size
# bitsandbytes loads the weights of the specific portion
# no need to narrow here
if not use_bitsandbytes_4bit:
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
shard_size)
# Special case for AQLM codebooks.
......@@ -918,8 +921,13 @@ class QKVParallelLinear(ColumnParallelLinear):
else:
shard_id = tp_rank // self.num_kv_head_replicas
start_idx = shard_id * shard_size
# bitsandbytes loads the weights of the specific portion
# no need to narrow here
if not use_bitsandbytes_4bit:
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
shard_size)
# Special case for for AQLM codebooks.
elif is_metadata:
# metadata indicates fixed size concatenated along dim 0
......@@ -1019,6 +1027,7 @@ class RowParallelLinear(LinearBase):
tp_rank = get_tensor_model_parallel_rank()
tp_size = get_tensor_model_parallel_world_size()
input_dim = getattr(param, "input_dim", None)
use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
# Special case for GGUF
is_gguf_weight = getattr(param, "is_gguf_weight", False)
......@@ -1034,7 +1043,9 @@ class RowParallelLinear(LinearBase):
param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype)
param_data = param.data
if input_dim is not None:
# bitsandbytes loads the weights of the specific portion
# no need to narrow here
if input_dim is not None and not use_bitsandbytes_4bit:
shard_size = param_data.shape[input_dim]
start_idx = tp_rank * shard_size
loaded_weight = loaded_weight.narrow(input_dim, start_idx,
......
# Copyright (c) 2024, Tri Dao.
# Adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/causal_conv1d/causal_conv1d_interface.py
from typing import Optional
......@@ -70,12 +71,17 @@ def causal_conv1d_update(x: torch.Tensor,
conv_state: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor] = None,
activation: Optional[str] = None):
activation: Optional[str] = None,
conv_state_indices: Optional[torch.Tensor] = None):
"""
x: (batch, dim)
conv_state: (batch, dim, width)
weight: (dim, width)
bias: (dim,)
conv_state_indices: (batch,), dtype int32
If not None, the conv_state is a larger tensor along the batch dim,
and we are selecting the batch coords specified by conv_state_indices.
Useful for a continuous batching scenario.
out: (batch, dim)
"""
......@@ -83,4 +89,4 @@ def causal_conv1d_update(x: torch.Tensor,
raise NotImplementedError("activation must be None, silu, or swish")
activation_bool = activation in ["silu", "swish"]
return ops.causal_conv1d_update(x, conv_state, weight, bias,
activation_bool)
activation_bool, conv_state_indices)
# Copyright (c) 2024, Tri Dao, Albert Gu.
# Adapted from https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/triton/selective_state_update.py
import torch
import triton
......@@ -27,6 +28,10 @@ else:
{"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None})
@triton.heuristics({"HAS_D": lambda args: args["D_ptr"] is not None})
@triton.heuristics({"HAS_Z": lambda args: args["z_ptr"] is not None})
@triton.heuristics({
"HAS_STATE_BATCH_INDICES":
lambda args: args["state_batch_indices_ptr"] is not None
})
@triton.heuristics(
{"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])})
@triton.jit
......@@ -42,6 +47,7 @@ def _selective_scan_update_kernel(
D_ptr,
z_ptr,
out_ptr,
state_batch_indices_ptr,
# Matrix dimensions
batch,
nheads,
......@@ -85,12 +91,24 @@ def _selective_scan_update_kernel(
HAS_DT_BIAS: tl.constexpr,
HAS_D: tl.constexpr,
HAS_Z: tl.constexpr,
HAS_STATE_BATCH_INDICES: tl.constexpr,
BLOCK_SIZE_DSTATE: tl.constexpr,
):
pid_m = tl.program_id(axis=0)
pid_b = tl.program_id(axis=1)
pid_h = tl.program_id(axis=2)
# If HAS_STATE_BATCH_INDICES is true, then the ssm state's batch coordinate
# is taken from the state_batch_indices_ptr Otherwise, the state coordinate
# is the same as the batch id.
if HAS_STATE_BATCH_INDICES:
state_batch_indices_ptr += pid_b
state_batch_idx = tl.load(state_batch_indices_ptr)
state_ptr += (state_batch_idx * stride_state_batch +
pid_h * stride_state_head)
else:
state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head
x_ptr += pid_b * stride_x_batch + pid_h * stride_x_head
dt_ptr += pid_b * stride_dt_batch + pid_h * stride_dt_head
if HAS_DT_BIAS:
......@@ -177,7 +195,8 @@ def selective_state_update(state,
D=None,
z=None,
dt_bias=None,
dt_softplus=False):
dt_softplus=False,
state_batch_indices=None):
"""
Argument:
state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
......@@ -211,7 +230,10 @@ def selective_state_update(state,
z = z.unsqueeze(1)
if dt_bias is not None and dt_bias.dim() == 1:
dt_bias = dt_bias.unsqueeze(0)
batch, nheads, dim, dstate = state.shape
_, nheads, dim, dstate = state.shape
batch = x.shape[0]
assert x.shape == (batch, nheads, dim)
assert dt.shape == x.shape
assert A.shape == (nheads, dim, dstate)
......@@ -225,6 +247,8 @@ def selective_state_update(state,
assert z.shape == x.shape
if dt_bias is not None:
assert dt_bias.shape == (nheads, dim)
if state_batch_indices is not None:
assert state_batch_indices.shape == (batch, )
out = torch.empty_like(x)
grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE_M']), batch, nheads)
z_strides = ((z.stride(0), z.stride(1), z.stride(2)) if z is not None else
......@@ -249,6 +273,7 @@ def selective_state_update(state,
D,
z,
out,
state_batch_indices,
batch,
nheads,
dim,
......@@ -336,7 +361,7 @@ def selective_scan_fn(u,
x[:, :, 0, 0::2] = 1
if prev_state is not None:
x[:, :, 0, 1::2].copy_(prev_state)
out, x, *rest = ops.selective_scan_fwd(u, delta, A, B, C, D, z, delta_bias,
out, *rest = ops.selective_scan_fwd(u, delta, A, B, C, D, z, delta_bias,
delta_softplus, position_indices, x)
last_state = x[:, :, -1, 1::2] # (batch, dim, dstate)
if z is None:
......
......@@ -7,10 +7,11 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
apply_awq_marlin_linear, awq_to_marlin_zero_points, check_marlin_supported,
marlin_make_empty_g_idx, marlin_make_workspace, marlin_permute_scales,
replace_tensor, verify_marlin_supported, verify_marlin_supports_shape)
verify_marlin_supported, verify_marlin_supports_shape)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.parameter import (GroupQuantScaleParameter,
PackedvLLMParameter)
......@@ -110,9 +111,9 @@ class AWQMarlinConfig(QuantizationConfig):
def is_awq_marlin_compatible(cls, quant_config: Dict[str, Any]):
# Extract data from quant config.
quant_method = quant_config.get("quant_method", "").lower()
num_bits = quant_config.get("bits", None)
group_size = quant_config.get("group_size", None)
has_zp = quant_config.get("zero_point", None)
num_bits = quant_config.get("bits")
group_size = quant_config.get("group_size")
has_zp = quant_config.get("zero_point")
if quant_method != "awq":
return False
......@@ -231,7 +232,7 @@ class AWQMarlinLinearMethod(LinearMethodBase):
size_k=layer.input_size_per_partition,
size_n=layer.output_size_per_partition,
num_bits=self.quant_config.quant_type.size_bits)
replace_tensor(layer, "qweight", marlin_qweight)
replace_parameter(layer, "qweight", marlin_qweight)
# Permute scales from AWQ format to marlin format.
marlin_scales = marlin_permute_scales(
......@@ -239,7 +240,7 @@ class AWQMarlinLinearMethod(LinearMethodBase):
size_k=layer.input_size_per_partition,
size_n=layer.output_size_per_partition,
group_size=self.quant_config.group_size)
replace_tensor(layer, "scales", marlin_scales)
replace_parameter(layer, "scales", marlin_scales)
# Permute zero-points from AWQ format to marlin format.
marlin_zp = awq_to_marlin_zero_points(
......@@ -247,7 +248,7 @@ class AWQMarlinLinearMethod(LinearMethodBase):
size_k=layer.num_groups,
size_n=layer.output_size_per_partition,
num_bits=self.quant_config.quant_type.size_bits)
replace_tensor(layer, "qzeros", marlin_zp)
replace_parameter(layer, "qzeros", marlin_zp)
# Not-used
layer.g_idx = marlin_make_empty_g_idx(device)
......
......@@ -209,12 +209,9 @@ def awq_gemm_kernel(a_ptr, b_ptr, c_ptr, zeros_ptr, scales_ptr, M, N, K,
c = accumulator.to(c_ptr.type.element_ty)
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = c_ptr + N * offs_cm[:, None] + offs_cn[None, :]
c_ptrs = c_ptr + pid_z * N * M + N * offs_cm[:, None] + offs_cn[None, :]
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
if SPLIT_K == 1:
tl.store(c_ptrs, c, mask=c_mask)
else:
tl.atomic_add(c_ptrs, c, mask=c_mask)
# qweights - [K , M // 8], int32
......@@ -295,7 +292,9 @@ def awq_gemm_triton(input: torch.Tensor,
split_k_iters,
)
result = torch.zeros((M, N), dtype=scales.dtype, device=input.device)
result = torch.zeros((split_k_iters, M, N),
dtype=scales.dtype,
device=input.device)
# A = input, B = qweight, C = result
# A = M x K, B = K x N, C = M x N
......@@ -313,4 +312,6 @@ def awq_gemm_triton(input: torch.Tensor,
BLOCK_SIZE_K=block_size_k,
SPLIT_K=split_k_iters)
result = result.sum(0)
return result
......@@ -121,12 +121,12 @@ class BitsAndBytesLinearMethod(LinearMethodBase):
def __init__(self, quant_config: BitsAndBytesConfig):
try:
import bitsandbytes
if bitsandbytes.__version__ < "0.42.0":
if bitsandbytes.__version__ < "0.44.0":
raise ImportError("bitsandbytes version is wrong. Please "
"install bitsandbytes>=0.42.0.")
"install bitsandbytes>=0.44.0.")
except ImportError as err:
raise ImportError("Please install bitsandbytes>=0.42.0 via "
"`pip install bitsandbytes>=0.42.0` to use "
raise ImportError("Please install bitsandbytes>=0.44.0 via "
"`pip install bitsandbytes>=0.44.0` to use "
"bitsandbytes quantizer.") from err
self.quant_config = quant_config
......
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, cast
import torch
from pydantic import BaseModel
......@@ -73,14 +73,14 @@ class CompressedTensorsConfig(QuantizationConfig):
if isinstance(layer, Attention):
return CompressedTensorsKVCacheMethod(self)
if isinstance(layer, FusedMoE):
return CompressedTensorsMoEMethod(self)
return CompressedTensorsMoEMethod.get_moe_method(self)
return None
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig":
target_scheme_map: Dict[str, Any] = dict()
ignore: List[str] = config.get("ignore", None)
quant_format: str = config.get("format", None)
ignore = cast(List[str], config.get("ignore"))
quant_format = cast(str, config.get("format"))
# The quant_config has multiple config_groups, each containing
# an input_activations key with details about how the activations are
......@@ -116,10 +116,10 @@ class CompressedTensorsConfig(QuantizationConfig):
def _check_scheme_supported(self,
min_capability: int,
error: bool = True) -> bool:
capability = current_platform.get_device_capability() # type: ignore
capability_tuple = current_platform.get_device_capability()
if capability is not None:
capability = capability[0] * 10 + capability[1]
if capability_tuple is not None:
capability = capability_tuple.to_int()
supported = capability >= min_capability
if error and not supported:
raise RuntimeError(
......@@ -200,7 +200,7 @@ class CompressedTensorsConfig(QuantizationConfig):
is_per_tensor_or_channel_weight = (weight_quant.strategy in [
QuantizationStrategy.TENSOR, QuantizationStrategy.CHANNEL
])
if not (is_symmetric_weight and is_static_weight
if not (is_symmetric_weight and is_static_weight # noqa: SIM103
and is_per_tensor_or_channel_weight):
return False
......
......@@ -5,10 +5,16 @@ from typing import Callable, List, Optional
import torch
from vllm import _custom_ops as ops
from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEMethodBase
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
WNA16_SUPPORTED_BITS)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
CompressionFormat)
CompressionFormat, QuantizationStrategy)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
all_close_1d, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize)
from vllm.model_executor.utils import set_weight_attrs
from vllm.utils import is_hip, print_warning_once
class GPTQMarlinState(Enum):
......@@ -16,11 +22,219 @@ class GPTQMarlinState(Enum):
READY = enum.auto()
__all__ = ["CompressedTensorsMoEMethod"]
__all__ = [
"CompressedTensorsMoEMethod", "CompressedTensorsW8A8Fp8MoEMethod",
"CompressedTensorsWNA16MoEMethod"
]
class CompressedTensorsMoEMethod(FusedMoEMethodBase):
@staticmethod
def get_moe_method(
quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501
) -> "CompressedTensorsMoEMethod":
# TODO: @dsikka: refactor this to use schemes as other kernels
# are supported + check if the layer is being ignored.
weight_quant = quant_config.target_scheme_map["Linear"].get("weights")
input_quant = quant_config.target_scheme_map["Linear"].get(
"input_activations")
if quant_config._is_wNa16_group_channel(weight_quant, input_quant):
return CompressedTensorsWNA16MoEMethod(quant_config)
elif quant_config._is_fp8_w8a8(weight_quant, input_quant):
return CompressedTensorsW8A8Fp8MoEMethod(quant_config)
else:
raise RuntimeError(
f"Unsupported FusedMoe scheme: {weight_quant}, {input_quant}")
class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
def __init__(
self,
quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501
):
self.quant_config = quant_config
self.weight_quant = self.quant_config.target_scheme_map["Linear"].get(
"weights")
self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
"input_activations")
if not (self.weight_quant.strategy == QuantizationStrategy.TENSOR
and self.input_quant.strategy == QuantizationStrategy.TENSOR):
raise ValueError(
"For FP8 Fused MoE layers, only per-tensor scales"
"for weights and activations are supported. Found "
f"{self.weight_quant}, {self.input_quant}")
self.static_input_scales = not self.input_quant.dynamic
def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size: int,
params_dtype: torch.dtype, **extra_weight_attrs):
params_dtype = torch.float8_e4m3fn
# WEIGHTS
w13_weight = torch.nn.Parameter(torch.empty(num_experts,
2 * intermediate_size,
hidden_size,
dtype=params_dtype),
requires_grad=False)
layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)
w2_weight = torch.nn.Parameter(torch.empty(num_experts,
hidden_size,
intermediate_size,
dtype=params_dtype),
requires_grad=False)
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
# WEIGHT_SCALES
# Allocate 2 scales for w1 and w3 respectively.
# They will be combined to a single scale after weight loading.
w13_weight_scale = torch.nn.Parameter(torch.ones(num_experts,
2,
dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
w2_weight_scale = torch.nn.Parameter(torch.ones(num_experts,
dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
# Add the quantization method used (per tensor/grouped/channel)
# to ensure the weight scales are loaded in properly
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
# INPUT_SCALES
if self.static_input_scales:
w13_input_scale = torch.nn.Parameter(torch.ones(
num_experts, dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w13_input_scale", w13_input_scale)
set_weight_attrs(w13_input_scale, extra_weight_attrs)
w2_input_scale = torch.nn.Parameter(torch.ones(
num_experts, dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w2_input_scale", w2_input_scale)
set_weight_attrs(w2_input_scale, extra_weight_attrs)
else:
layer.w13_input_scale = None
layer.w2_input_scale = None
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# Fp8 moe kernels require a single activation scale.
# We take the max of all the scales in case they differ.
if self.static_input_scales:
if (layer.w13_input_scale is None or layer.w2_input_scale is None):
raise ValueError(
"QuantConfig has static quantization, but found "
"activation scales are None.")
if (not all_close_1d(layer.w13_input_scale)
or not all_close_1d(layer.w2_input_scale)):
print_warning_once(
"Found input_scales that are not equal for "
"fp8 MoE layer. Using the maximum across experts "
"for each layer. ")
layer.w13_input_scale = torch.nn.Parameter(
layer.w13_input_scale.max(), requires_grad=False)
layer.w2_input_scale = torch.nn.Parameter(
layer.w2_input_scale.max(), requires_grad=False)
# If rocm, normalize the weights and scales to e4m3fnuz
if is_hip():
# Normalize the weights and scales
w13_weight, w13_weight_scale, w13_input_scale = \
normalize_e4m3fn_to_e4m3fnuz(
layer.w13_weight, layer.w13_weight_scale,
layer.w13_input_scale)
w2_weight, w2_weight_scale, w2_input_scale = \
normalize_e4m3fn_to_e4m3fnuz(
layer.w2_weight, layer.w2_weight_scale,
layer.w2_input_scale)
# Reset the parameter
layer.w13_weight = torch.nn.Parameter(w13_weight,
requires_grad=False)
layer.w13_weight_scale = torch.nn.Parameter(w13_weight_scale,
requires_grad=False)
if w13_input_scale is not None:
layer.w13_input_scale = torch.nn.Parameter(w13_input_scale,
requires_grad=False)
layer.w2_weight = torch.nn.Parameter(w2_weight,
requires_grad=False)
layer.w2_weight_scale = torch.nn.Parameter(w2_weight_scale,
requires_grad=False)
if w2_input_scale is not None:
layer.w2_input_scale = torch.nn.Parameter(w2_input_scale,
requires_grad=False)
# Fp8 moe kernel needs single weight scale for w13 per expert.
# We take the max then dequant and requant each expert.
assert layer.w13_weight_scale is not None
shard_size = layer.intermediate_size_per_partition
max_w13_scales = layer.w13_weight_scale.max(dim=1).values
for expert_id in range(layer.num_experts):
start = 0
for shard_id in range(2):
dq_weight = per_tensor_dequantize(
layer.w13_weight[expert_id][start:start + shard_size, :],
layer.w13_weight_scale[expert_id][shard_id])
layer.w13_weight[expert_id][
start:start + shard_size, :], _ = ops.scaled_fp8_quant(
dq_weight, max_w13_scales[expert_id])
start += shard_size
layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
requires_grad=False)
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool = True,
use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None,
topk_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import fused_experts
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function)
return fused_experts(x,
layer.w13_weight,
layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
use_fp8_w8a8=True,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale)
class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
def __init__(
self,
quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501
......@@ -38,10 +252,11 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
if not (self.quant_config.quant_format
== CompressionFormat.pack_quantized.value
and self.num_bits == 4):
and self.num_bits in WNA16_SUPPORTED_BITS):
raise ValueError("For Fused MoE layers, only ",
f"{CompressionFormat.pack_quantized.value} ",
"is supported for 4 bits")
"is supported for the following bits: ",
f"{WNA16_SUPPORTED_BITS}")
def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size: int,
......@@ -292,4 +507,5 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
topk_ids,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
num_bits=self.num_bits,
)
......@@ -8,10 +8,12 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
QuantizationStrategy)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
apply_fp8_linear, cutlass_fp8_supported, requantize_with_max_scale)
apply_fp8_linear, cutlass_fp8_supported, normalize_e4m3fn_to_e4m3fnuz,
requantize_with_max_scale)
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
ModelWeightParameter,
PerTensorScaleParameter)
from vllm.utils import is_hip
__all__ = ["CompressedTensorsW8A8Fp8"]
......@@ -39,16 +41,37 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
logical_widths=layer.logical_widths,
)
if is_hip():
weight, max_w_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
weight=weight,
weight_scale=max_w_scale,
input_scale=layer.input_scale)
if input_scale is not None:
layer.input_scale = Parameter(input_scale,
requires_grad=False)
layer.weight = Parameter(weight.t(), requires_grad=False)
layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
# If channelwise, scales are already lined up, so just transpose.
elif self.strategy == QuantizationStrategy.CHANNEL:
weight = layer.weight
if is_hip():
weight, weight_scale, input_scale = \
normalize_e4m3fn_to_e4m3fnuz(
weight=weight,
weight_scale=layer.weight_scale,
input_scale=layer.input_scale)
if input_scale is not None:
layer.input_scale = Parameter(input_scale,
requires_grad=False)
else:
weight_scale = layer.weight_scale.data
layer.weight = Parameter(weight.t(), requires_grad=False)
# required by torch.compile to be torch.nn.Parameter
layer.weight_scale = Parameter(layer.weight_scale.data,
requires_grad=False)
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
else:
raise ValueError(f"Unknown quantization strategy {self.strategy}")
......
from typing import Callable, List, Optional
from typing import Callable, List, Optional, Set
import torch
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
ActivationOrdering)
from vllm.model_executor.layers.quantization.kernels import (
MPLinearLayerConfig, choose_mp_linear_kernel)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
apply_gptq_marlin_linear, marlin_make_empty_g_idx, marlin_make_workspace,
marlin_permute_scales, marlin_repeat_scales_on_all_ranks,
marlin_sort_g_idx, replace_tensor, verify_marlin_supported,
verify_marlin_supports_shape)
marlin_repeat_scales_on_all_ranks)
from vllm.model_executor.parameter import (BasevLLMParameter,
ChannelQuantScaleParameter,
GroupQuantScaleParameter,
......@@ -19,6 +18,8 @@ from vllm.model_executor.parameter import (BasevLLMParameter,
RowvLLMParameter)
from vllm.scalar_type import scalar_types
logger = init_logger(__name__)
__all__ = ["CompressedTensorsWNA16"]
WNA16_SUPPORTED_TYPES_MAP = {
4: scalar_types.uint4b8,
......@@ -28,6 +29,7 @@ WNA16_SUPPORTED_BITS = list(WNA16_SUPPORTED_TYPES_MAP.keys())
class CompressedTensorsWNA16(CompressedTensorsScheme):
_kernel_backends_being_used: Set[str] = set()
def __init__(self,
strategy: str,
......@@ -52,35 +54,43 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
self.quant_type = WNA16_SUPPORTED_TYPES_MAP[num_bits]
# Verify supported on platform.
verify_marlin_supported(quant_type=self.quant_type,
group_size=self.group_size)
@classmethod
def get_min_capability(cls) -> int:
# ampere and up
return 80
def create_weights(self, layer: torch.nn.Module, input_size: int,
output_partition_sizes: List[int],
def create_weights(self, layer: torch.nn.Module, output_size: int,
input_size: int, output_partition_sizes: List[int],
input_size_per_partition: int,
params_dtype: torch.dtype, weight_loader: Callable,
**kwargs):
output_size_per_partition = sum(output_partition_sizes)
mp_linear_kernel_config = MPLinearLayerConfig(
full_weight_shape=(input_size, output_size),
partition_weight_shape=\
(input_size_per_partition, output_size_per_partition),
weight_type=self.quant_type,
act_type=params_dtype,
group_size=self.group_size,
zero_points=False,
has_g_idx=self.has_g_idx
)
kernel_type = choose_mp_linear_kernel(mp_linear_kernel_config)
if kernel_type.__name__ not in self._kernel_backends_being_used:
logger.info("Using %s for CompressedTensorsWNA16",
kernel_type.__name__)
self._kernel_backends_being_used.add(kernel_type.__name__)
# If group_size is -1, we are in channelwise case.
group_size = self.group_size if self.group_size != -1 else input_size
row_parallel = (input_size != input_size_per_partition)
partition_scales = not marlin_repeat_scales_on_all_ranks(
self.has_g_idx, self.group_size, row_parallel)
verify_marlin_supports_shape(
output_size_per_partition=output_size_per_partition,
input_size_per_partition=input_size_per_partition,
input_size=input_size,
group_size=group_size)
scales_and_zp_size = input_size // group_size
if partition_scales:
......@@ -137,69 +147,17 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
weight_loader=weight_loader)
layer.register_parameter("weight_g_idx", weight_g_idx)
layer.input_size_per_partition = input_size_per_partition
layer.output_size_per_partition = output_size_per_partition
layer.input_size = input_size
layer.group_size = group_size
self.kernel = kernel_type(mp_linear_kernel_config,
w_q_param_name="weight_packed",
w_s_param_name="weight_scale",
w_zp_param_name=None,
w_gidx_param_name="weight_g_idx")
# Checkpoints are serialized in compressed-tensors format, which is
# different from marlin format. Handle repacking here.
# different from the format the kernel may want. Handle repacking here.
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
device = layer.weight_packed.device
# Allocate marlin workspace.
layer.workspace = marlin_make_workspace(
layer.output_size_per_partition, device)
# Handle sorting for activation reordering if needed.
if self.has_g_idx:
g_idx, g_idx_sort_indices = marlin_sort_g_idx(layer.weight_g_idx)
layer.g_idx_sort_indices = g_idx_sort_indices
replace_tensor(layer, "weight_g_idx", g_idx)
else:
layer.weight_g_idx = marlin_make_empty_g_idx(device)
layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)
# No zero-point
layer.weight_zp = marlin_make_empty_g_idx(device)
# Update for kernel
layer.weight_packed = torch.nn.Parameter(
layer.weight_packed.t().contiguous(), requires_grad=False)
layer.weight_scale = torch.nn.Parameter(
layer.weight_scale.squeeze().t().contiguous(), requires_grad=False)
# Repack weights from compressed-tensors format to marlin format.
marlin_qweight = ops.gptq_marlin_repack(
layer.weight_packed,
perm=layer.g_idx_sort_indices,
size_k=layer.input_size_per_partition,
size_n=layer.output_size_per_partition,
num_bits=self.quant_type.size_bits)
replace_tensor(layer, "weight_packed", marlin_qweight)
# Permute scales from compressed-tensors format to marlin format.
# scale is required on all partitions if activation reordering
marlin_scales = marlin_permute_scales(
layer.weight_scale,
size_k=(layer.input_size
if self.has_g_idx else layer.input_size_per_partition),
size_n=layer.output_size_per_partition,
group_size=layer.group_size)
replace_tensor(layer, "weight_scale", marlin_scales)
self.kernel.process_weights_after_loading(layer)
def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
return apply_gptq_marlin_linear(
input=x,
weight=layer.weight_packed,
weight_scale=layer.weight_scale,
weight_zp=layer.weight_zp,
g_idx=layer.weight_g_idx,
g_idx_sort_indices=layer.g_idx_sort_indices,
workspace=layer.workspace,
wtype=self.quant_type,
output_size_per_partition=layer.output_size_per_partition,
input_size_per_partition=layer.input_size_per_partition,
is_k_full=True,
bias=bias)
return self.kernel.apply_weights(layer, x, bias)
......@@ -15,10 +15,11 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
from vllm.model_executor.layers.quantization.utils.quant_utils import (
is_layer_skipped)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
apply_fp8_linear)
apply_fp8_linear, normalize_e4m3fn_to_e4m3fnuz)
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
ModelWeightParameter)
from vllm.platforms import current_platform
from vllm.utils import is_hip
logger = init_logger(__name__)
......@@ -32,9 +33,7 @@ class FBGEMMFp8Config(QuantizationConfig):
# For GPUs that lack FP8 hardware support, we can leverage the Marlin
# kernel for fast weight-only FP8 quantization
capability = current_platform.get_device_capability()
capability = capability[0] * 10 + capability[1]
self.use_marlin = capability < 89
self.use_marlin = not current_platform.has_device_capability(89)
@classmethod
def get_name(cls) -> str:
......@@ -127,8 +126,18 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
layer.weight = Parameter(layer.weight.data, requires_grad=False)
weight = layer.weight
layer.weight = Parameter(weight.t(), requires_grad=False)
if is_hip():
weight, weight_scale, input_scale = \
normalize_e4m3fn_to_e4m3fnuz(
weight=weight,
weight_scale=layer.weight_scale,
input_scale=None)
if input_scale is not None:
layer.input_scale = Parameter(input_scale, requires_grad=False)
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
layer.weight = Parameter(weight.t(), requires_grad=False)
if self.quant_config.use_marlin:
prepare_fp8_layer_for_marlin(layer)
# Activations not quantized for marlin.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment