Commit af7f4372 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.5.5' into v0.5.5-dtk24.04.1

parents 5e19cdef 09c77926
...@@ -11,48 +11,51 @@ import triton.language as tl ...@@ -11,48 +11,51 @@ import triton.language as tl
import vllm.envs as envs import vllm.envs as envs
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform
logger = init_logger(__name__) logger = init_logger(__name__)
@triton.jit @triton.jit
def fused_moe_kernel( def fused_moe_kernel(
# Pointers to matrices # Pointers to matrices
a_ptr, a_ptr,
b_ptr, b_ptr,
c_ptr, c_ptr,
a_scale_ptr, a_scale_ptr,
b_scale_ptr, b_scale_ptr,
topk_weights_ptr, topk_weights_ptr,
sorted_token_ids_ptr, sorted_token_ids_ptr,
expert_ids_ptr, expert_ids_ptr,
num_tokens_post_padded_ptr, num_tokens_post_padded_ptr,
# Matrix dimensions # Matrix dimensions
N, N,
K, K,
EM, EM,
num_valid_tokens, num_valid_tokens,
# The stride variables represent how much to increase the ptr by when # The stride variables represent how much to increase the ptr by when
# moving by 1 element in a particular dimension. E.g. `stride_am` is # moving by 1 element in a particular dimension. E.g. `stride_am` is
# how much to increase `a_ptr` by to get the element one row down # how much to increase `a_ptr` by to get the element one row down
# (A has M rows). # (A has M rows).
stride_am, stride_am,
stride_ak, stride_ak,
stride_be, stride_be,
stride_bk, stride_bk,
stride_bn, stride_bn,
stride_cm, stride_cm,
stride_cn, stride_cn,
# Meta-parameters stride_bse,
BLOCK_SIZE_M: tl.constexpr, stride_bsn,
BLOCK_SIZE_N: tl.constexpr, # Meta-parameters
BLOCK_SIZE_K: tl.constexpr, BLOCK_SIZE_M: tl.constexpr,
GROUP_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
MUL_ROUTED_WEIGHT: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
top_k: tl.constexpr, GROUP_SIZE_M: tl.constexpr,
compute_type: tl.constexpr, MUL_ROUTED_WEIGHT: tl.constexpr,
use_fp8: tl.constexpr, top_k: tl.constexpr,
): compute_type: tl.constexpr,
use_fp8_w8a8: tl.constexpr,
use_int8_w8a16: tl.constexpr):
""" """
Implements the fused computation for a Mixture of Experts (MOE) using Implements the fused computation for a Mixture of Experts (MOE) using
token and expert matrices. token and expert matrices.
...@@ -113,8 +116,12 @@ def fused_moe_kernel( ...@@ -113,8 +116,12 @@ def fused_moe_kernel(
off_experts = tl.load(expert_ids_ptr + pid_m) off_experts = tl.load(expert_ids_ptr + pid_m)
b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk + b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk +
offs_bn[None, :] * stride_bn) offs_bn[None, :] * stride_bn)
if use_int8_w8a16:
b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + offs_bn[
None, :] * stride_bsn
b_scale = tl.load(b_scale_ptrs)
if use_fp8: if use_fp8_w8a8:
a_scale = tl.load(a_scale_ptr) a_scale = tl.load(a_scale_ptr)
b_scale = tl.load(b_scale_ptr + off_experts) b_scale = tl.load(b_scale_ptr + off_experts)
...@@ -136,7 +143,9 @@ def fused_moe_kernel( ...@@ -136,7 +143,9 @@ def fused_moe_kernel(
mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,
other=0.0) other=0.0)
# We accumulate along the K dimension. # We accumulate along the K dimension.
if use_fp8: if use_int8_w8a16:
accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
elif use_fp8_w8a8:
accumulator = tl.dot(a, b, acc=accumulator) accumulator = tl.dot(a, b, acc=accumulator)
else: else:
accumulator += tl.dot(a, b) accumulator += tl.dot(a, b)
...@@ -149,8 +158,9 @@ def fused_moe_kernel( ...@@ -149,8 +158,9 @@ def fused_moe_kernel(
mask=token_mask, mask=token_mask,
other=0) other=0)
accumulator = accumulator * moe_weight[:, None] accumulator = accumulator * moe_weight[:, None]
if use_int8_w8a16:
if use_fp8: accumulator = (accumulator * b_scale).to(compute_type)
elif use_fp8_w8a8:
accumulator = (accumulator * a_scale * b_scale).to(compute_type) accumulator = (accumulator * a_scale * b_scale).to(compute_type)
else: else:
accumulator = accumulator.to(compute_type) accumulator = accumulator.to(compute_type)
...@@ -229,16 +239,18 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, ...@@ -229,16 +239,18 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
num_tokens_post_padded: torch.Tensor, num_tokens_post_padded: torch.Tensor,
mul_routed_weight: bool, top_k: int, mul_routed_weight: bool, top_k: int,
config: Dict[str, Any], compute_type: tl.dtype, config: Dict[str, Any], compute_type: tl.dtype,
use_fp8: bool) -> None: use_fp8_w8a8: bool, use_int8_w8a16: bool) -> None:
assert topk_weights.stride(1) == 1 assert topk_weights.stride(1) == 1
assert sorted_token_ids.stride(0) == 1 assert sorted_token_ids.stride(0) == 1
if not use_fp8: if use_fp8_w8a8:
assert A_scale is None
assert B_scale is None
else:
A, A_scale = ops.scaled_fp8_quant(A, A_scale) A, A_scale = ops.scaled_fp8_quant(A, A_scale)
assert B_scale is not None assert B_scale is not None
elif use_int8_w8a16:
assert B_scale is not None
else:
assert A_scale is None
assert B_scale is None
grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[ grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[
'BLOCK_SIZE_M']) * triton.cdiv(B.shape[1], META['BLOCK_SIZE_N']), ) 'BLOCK_SIZE_M']) * triton.cdiv(B.shape[1], META['BLOCK_SIZE_N']), )
...@@ -264,16 +276,19 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, ...@@ -264,16 +276,19 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
B.stride(1), B.stride(1),
C.stride(1), C.stride(1),
C.stride(2), C.stride(2),
B_scale.stride(0) if B_scale is not None and use_int8_w8a16 else 0,
B_scale.stride(1) if B_scale is not None and use_int8_w8a16 else 0,
MUL_ROUTED_WEIGHT=mul_routed_weight, MUL_ROUTED_WEIGHT=mul_routed_weight,
top_k=top_k, top_k=top_k,
compute_type=compute_type, compute_type=compute_type,
use_fp8=use_fp8, use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a16=use_int8_w8a16,
**config, **config,
) )
def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str: def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str:
device_name = torch.cuda.get_device_name().replace(" ", "_") device_name = current_platform.get_device_name().replace(" ", "_")
dtype_selector = "" if not dtype else f",dtype={dtype}" dtype_selector = "" if not dtype else f",dtype={dtype}"
return f"E={E},N={N},device_name={device_name}{dtype_selector}.json" return f"E={E},N={N},device_name={device_name}{dtype_selector}.json"
...@@ -426,6 +441,20 @@ def grouped_topk(hidden_states: torch.Tensor, ...@@ -426,6 +441,20 @@ def grouped_topk(hidden_states: torch.Tensor,
return topk_weights, topk_ids return topk_weights, topk_ids
def get_config_dtype_str(dtype: torch.dtype,
use_int8_w8a16: Optional[bool] = False,
use_fp8_w8a8: Optional[bool] = False):
if use_fp8_w8a8:
return "fp8_w8a8"
elif use_int8_w8a16:
return "int8_w8a16"
elif dtype == torch.float:
# avoiding cases where kernel fails when float32 MoE
# use fp16/bfloat16 configs
return "float32"
return None
def fused_experts(hidden_states: torch.Tensor, def fused_experts(hidden_states: torch.Tensor,
w1: torch.Tensor, w1: torch.Tensor,
w2: torch.Tensor, w2: torch.Tensor,
...@@ -433,7 +462,8 @@ def fused_experts(hidden_states: torch.Tensor, ...@@ -433,7 +462,8 @@ def fused_experts(hidden_states: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
inplace: bool = False, inplace: bool = False,
override_config: Optional[Dict[str, Any]] = None, override_config: Optional[Dict[str, Any]] = None,
use_fp8: bool = False, use_fp8_w8a8: bool = False,
use_int8_w8a16: bool = False,
w1_scale: Optional[torch.Tensor] = None, w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None,
...@@ -454,13 +484,16 @@ def fused_experts(hidden_states: torch.Tensor, ...@@ -454,13 +484,16 @@ def fused_experts(hidden_states: torch.Tensor,
# https://github.com/vllm-project/vllm/issues/5938 # https://github.com/vllm-project/vllm/issues/5938
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
M = min(num_tokens, CHUNK_SIZE) M = min(num_tokens, CHUNK_SIZE)
config_dtype = get_config_dtype_str(use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a16=use_int8_w8a16,
dtype=hidden_states.dtype)
get_config_func = functools.partial( get_config_func = functools.partial(
try_get_optimal_moe_config, try_get_optimal_moe_config,
w1.shape, w1.shape,
w2.shape, w2.shape,
topk_ids.shape[1], topk_ids.shape[1],
"float8" if use_fp8 else None, config_dtype,
override_config=override_config, override_config=override_config,
) )
...@@ -524,7 +557,8 @@ def fused_experts(hidden_states: torch.Tensor, ...@@ -524,7 +557,8 @@ def fused_experts(hidden_states: torch.Tensor,
topk_ids.shape[1], topk_ids.shape[1],
config, config,
compute_type=compute_type, compute_type=compute_type,
use_fp8=use_fp8) use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a16=use_int8_w8a16)
ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
...@@ -542,7 +576,8 @@ def fused_experts(hidden_states: torch.Tensor, ...@@ -542,7 +576,8 @@ def fused_experts(hidden_states: torch.Tensor,
1, 1,
config, config,
compute_type=compute_type, compute_type=compute_type,
use_fp8=use_fp8) use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a16=use_int8_w8a16)
torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
dim=1, dim=1,
...@@ -562,7 +597,8 @@ def fused_moe( ...@@ -562,7 +597,8 @@ def fused_moe(
use_grouped_topk: bool = False, use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None, num_expert_group: Optional[int] = None,
topk_group: Optional[int] = None, topk_group: Optional[int] = None,
use_fp8: bool = False, use_fp8_w8a8: bool = False,
use_int8_w8a16: bool = False,
w1_scale: Optional[torch.Tensor] = None, w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None,
...@@ -588,7 +624,9 @@ def fused_moe( ...@@ -588,7 +624,9 @@ def fused_moe(
- topk_group: Optional[int]: additional parameter for grouped_topk - topk_group: Optional[int]: additional parameter for grouped_topk
- use_grouped_topk: If True, use grouped_topk instead of fused_topk - use_grouped_topk: If True, use grouped_topk instead of fused_topk
note: Deepseekv2 model uses grouped_topk note: Deepseekv2 model uses grouped_topk
- use_fp8 (bool): If True, use fp8 arithmetic to compute the inner - use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner
products for w1 and w2. Defaults to False.
- use_int8_w8a16 (bool): If True, use fp8 arithmetic to compute the inner
products for w1 and w2. Defaults to False. products for w1 and w2. Defaults to False.
- w1_scale (Optional[torch.Tensor]): Optional scale to be used for - w1_scale (Optional[torch.Tensor]): Optional scale to be used for
w1. w1.
...@@ -617,7 +655,8 @@ def fused_moe( ...@@ -617,7 +655,8 @@ def fused_moe(
topk_ids, topk_ids,
inplace=inplace, inplace=inplace,
override_config=override_config, override_config=override_config,
use_fp8=use_fp8, use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a16=use_int8_w8a16,
w1_scale=w1_scale, w1_scale=w1_scale,
w2_scale=w2_scale, w2_scale=w2_scale,
a1_scale=a1_scale, a1_scale=a1_scale,
......
...@@ -24,15 +24,9 @@ class FusedMoEMethodBase(QuantizeMethodBase): ...@@ -24,15 +24,9 @@ class FusedMoEMethodBase(QuantizeMethodBase):
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
def apply(self, def apply(self, layer: torch.nn.Module, x: torch.Tensor,
layer: torch.nn.Module, router_logits: torch.Tensor, top_k: int, renormalize: bool,
x: torch.Tensor, use_grouped_topk: bool) -> torch.Tensor:
router_logits: torch.Tensor,
top_k: int,
renormalize: bool = True,
use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None,
topk_group: Optional[int] = None) -> torch.Tensor:
raise NotImplementedError raise NotImplementedError
...@@ -61,66 +55,78 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -61,66 +55,78 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
layer.register_parameter("w2_weight", w2_weight) layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs) set_weight_attrs(w2_weight, extra_weight_attrs)
def apply( def apply(self,
self, layer: torch.nn.Module,
layer: torch.nn.Module, x: torch.Tensor,
x: torch.Tensor, router_logits: torch.Tensor,
router_logits: torch.Tensor, top_k: int,
top_k: int, renormalize: bool,
renormalize: bool = True, use_grouped_topk: bool,
use_grouped_topk: bool = False, topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None, num_expert_group: Optional[int] = None) -> torch.Tensor:
topk_group: Optional[int] = None,
) -> torch.Tensor: return self.forward(x=x,
return self.forward(x, layer.w13_weight, layer.w2_weight, layer=layer,
router_logits, top_k, renormalize, router_logits=router_logits,
use_grouped_topk, num_expert_group, topk_group) top_k=top_k,
renormalize=renormalize,
def forward_cuda( use_grouped_topk=use_grouped_topk,
self, topk_group=topk_group,
x: torch.Tensor, num_expert_group=num_expert_group)
w1: torch.Tensor,
w2: torch.Tensor, def forward_cuda(self,
router_logits: torch.Tensor, layer: torch.nn.Module,
top_k: int, x: torch.Tensor,
renormalize: bool, use_grouped_topk: bool,
use_grouped_topk: bool, top_k: int,
num_expert_group: Optional[int], router_logits: torch.Tensor,
topk_group: Optional[int], renormalize: bool,
) -> torch.Tensor: topk_group: Optional[int] = None,
from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe num_expert_group: Optional[int] = None) -> torch.Tensor:
return fused_moe(x,
w1, from vllm.model_executor.layers.fused_moe.fused_moe import (
w2, fused_experts)
router_logits,
top_k, topk_weights, topk_ids = FusedMoE.select_experts(
renormalize=renormalize, hidden_states=x,
inplace=True, router_logits=router_logits,
use_grouped_topk=use_grouped_topk, use_grouped_topk=use_grouped_topk,
num_expert_group=num_expert_group, top_k=top_k,
topk_group=topk_group) renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group)
return fused_experts(hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True)
def forward_cpu(self, *args, **kwargs): def forward_cpu(self, *args, **kwargs):
raise NotImplementedError( raise NotImplementedError(
"The CPU backend currently does not support MoE.") "The CPU backend currently does not support MoE.")
def forward_tpu( def forward_tpu(self,
self, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
w1: torch.Tensor, use_grouped_topk: bool,
w2: torch.Tensor, top_k: int,
router_logits: torch.Tensor, router_logits: torch.Tensor,
top_k: int, renormalize: bool,
renormalize: bool, topk_group: Optional[int] = None,
use_grouped_topk: bool, num_expert_group: Optional[int] = None) -> torch.Tensor:
num_expert_group: Optional[int],
topk_group: Optional[int],
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe.moe_pallas import fused_moe from vllm.model_executor.layers.fused_moe.moe_pallas import fused_moe
assert not use_grouped_topk assert not use_grouped_topk
assert num_expert_group is None assert num_expert_group is None
assert topk_group is None assert topk_group is None
return fused_moe(x, w1, w2, router_logits, top_k, renormalize) return fused_moe(hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk=top_k,
gating_output=router_logits,
renormalize=renormalize)
class FusedMoE(torch.nn.Module): class FusedMoE(torch.nn.Module):
...@@ -195,52 +201,83 @@ class FusedMoE(torch.nn.Module): ...@@ -195,52 +201,83 @@ class FusedMoE(torch.nn.Module):
def weight_loader(self, param: torch.nn.Parameter, def weight_loader(self, param: torch.nn.Parameter,
loaded_weight: torch.Tensor, weight_name: str, loaded_weight: torch.Tensor, weight_name: str,
shard_id: int, expert_id: int): shard_id: str, expert_id: int) -> None:
param_data = param.data if shard_id not in ("w1", "w2", "w3"):
raise ValueError(f"shard_id must be ['w1','w2','w3'] but "
# Input scales can be loaded directly and should be equal. f"got {shard_id}.")
if "input_scale" in weight_name:
if param_data[expert_id] != 1 and (param_data[expert_id] - # Special case for fp8 scales.
loaded_weight).abs() > 1e-5: if getattr(param, "is_fp8_scale", False):
raise ValueError( self._load_fp8_scale(param.data, loaded_weight, weight_name,
"input_scales of w1 and w3 of a layer " shard_id, expert_id)
f"must be equal. But got {param_data[expert_id]} " return
f"vs. {loaded_weight}")
param_data[expert_id] = loaded_weight expert_data = param.data[expert_id]
# Weight scales tp_rank = get_tensor_model_parallel_rank()
elif "weight_scale" in weight_name:
# If we are in merged column case (gate_up_proj) # If transposed, weight is saved as [input_dim, output_dim]
# shard_id 0 == gate_proj / w1 # Otherwise, weight is saved as [output_dim, input_dim]
# shard_id 2 == up_proj / w3 # Default is not transposed/input dim is dim 1
if shard_id == 0 or shard_id == 2: input_dim = getattr(param, "input_dim", 1)
# We have to keep the weight scales of w1 and w3 because output_dim = getattr(param, "output_dim", 0)
# we need to re-quantize w1/w3 weights after weight loading.
idx = 0 if shard_id == 0 else 1 # Index the loaded weight for tp sharding.
param_data[expert_id][idx] = loaded_weight # down_proj: "RowParallel" so tp sharding on input_dim
# If we are in the row parallel case (down_proj) if shard_id == "w2":
# shard_id 1 == down_proj / w2 shard_dim = input_dim
else: shard_size = expert_data.shape[shard_dim]
param_data[expert_id] = loaded_weight # gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
# Weights elif shard_id in ("w1", "w3"):
shard_dim = output_dim
shard_size = expert_data.shape[output_dim] // 2
offset = shard_size * tp_rank
loaded_weight = loaded_weight.narrow(shard_dim, offset, shard_size)
# Narrow parameter and load.
# w1, gate_proj: Load into first logical weight of w13.
if shard_id == "w1":
expert_data = expert_data.narrow(shard_dim, 0, shard_size)
expert_data.copy_(loaded_weight)
# w3, up_proj: Load into second logical weight of w13.
elif shard_id == "w3":
expert_data = expert_data.narrow(shard_dim, shard_size, shard_size)
expert_data.copy_(loaded_weight)
# w2, down_proj: Load into only logical weight of w2.
elif shard_id == "w2":
expert_data.copy_(loaded_weight)
else: else:
tp_rank = get_tensor_model_parallel_rank() raise ValueError(
shard_size = self.intermediate_size_per_partition f"Expected shard_id w1,w2 or w3 but got {shard_id}")
shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
@staticmethod
# w1, gate_proj case: Load into first shard of w13. def select_experts(hidden_states: torch.Tensor,
if shard_id == 0: router_logits: torch.Tensor,
param_data[expert_id, top_k: int,
0:shard_size, :] = loaded_weight[shard, :] use_grouped_topk: bool,
# w3, up_proj case: Load into second shard of w13. renormalize: bool,
elif shard_id == 2: topk_group: Optional[int] = None,
param_data[expert_id, shard_size:2 * num_expert_group: Optional[int] = None):
shard_size, :] = loaded_weight[shard, :] from vllm.model_executor.layers.fused_moe.fused_moe import (
# w2, down_proj case: Load into only shard of w2. fused_topk, grouped_topk)
elif shard_id == 1:
param_data[expert_id, :, :] = loaded_weight[:, shard] # DeekSeekv2 uses grouped_top_k
else: if use_grouped_topk:
raise ValueError( assert topk_group is not None
f"Shard id must be in [0,1,2] but got {shard_id}") assert num_expert_group is not None
topk_weights, topk_ids = grouped_topk(
hidden_states=hidden_states,
gating_output=router_logits,
topk=top_k,
renormalize=renormalize,
num_expert_group=num_expert_group,
topk_group=topk_group)
else:
topk_weights, topk_ids = fused_topk(hidden_states=hidden_states,
gating_output=router_logits,
topk=top_k,
renormalize=renormalize)
return topk_weights, topk_ids
def forward(self, hidden_states: torch.Tensor, def forward(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor): router_logits: torch.Tensor):
...@@ -248,14 +285,14 @@ class FusedMoE(torch.nn.Module): ...@@ -248,14 +285,14 @@ class FusedMoE(torch.nn.Module):
# Matrix multiply. # Matrix multiply.
final_hidden_states = self.quant_method.apply( final_hidden_states = self.quant_method.apply(
self, layer=self,
x=hidden_states, x=hidden_states,
router_logits=router_logits, router_logits=router_logits,
top_k=self.top_k, top_k=self.top_k,
renormalize=self.renormalize, renormalize=self.renormalize,
use_grouped_topk=self.use_grouped_topk, use_grouped_topk=self.use_grouped_topk,
num_expert_group=self.num_expert_group, topk_group=self.topk_group,
topk_group=self.topk_group) num_expert_group=self.num_expert_group)
if self.reduce_results and self.tp_size > 1: if self.reduce_results and self.tp_size > 1:
final_hidden_states = tensor_model_parallel_all_reduce( final_hidden_states = tensor_model_parallel_all_reduce(
...@@ -267,35 +304,42 @@ class FusedMoE(torch.nn.Module): ...@@ -267,35 +304,42 @@ class FusedMoE(torch.nn.Module):
def make_expert_params_mapping( def make_expert_params_mapping(
cls, ckpt_gate_proj_name: str, ckpt_down_proj_name: str, cls, ckpt_gate_proj_name: str, ckpt_down_proj_name: str,
ckpt_up_proj_name: str, ckpt_up_proj_name: str,
num_experts: int) -> List[Tuple[str, str, int, int]]: num_experts: int) -> List[Tuple[str, str, int, str]]:
gate_up = [ckpt_gate_proj_name, ckpt_up_proj_name]
gate_down_up = [
ckpt_gate_proj_name, ckpt_down_proj_name, ckpt_up_proj_name
]
return [ return [
# These are the weight scales for the experts
# (param_name, weight_name, expert_id, shard_id)
("experts.w13_scale"
if weight_name in gate_up else "experts.w2_scale",
f"experts.{expert_id}.{weight_name}.weight_scale", expert_id,
shard_id) for expert_id in range(num_experts)
for shard_id, weight_name in enumerate(gate_down_up)
] + [
# These are the weights for the experts
# (param_name, weight_name, expert_id, shard_id) # (param_name, weight_name, expert_id, shard_id)
("experts.w13_weight" ("experts.w13_" if weight_name
if weight_name in gate_up else "experts.w2_weight", in [ckpt_gate_proj_name, ckpt_up_proj_name] else "experts.w2_",
f"experts.{expert_id}.{weight_name}.weight", expert_id, shard_id) f"experts.{expert_id}.{weight_name}.", expert_id, shard_id)
for expert_id in range(num_experts) for expert_id in range(num_experts) for shard_id, weight_name in [
for shard_id, weight_name in enumerate(gate_down_up) ("w1", ckpt_gate_proj_name),
] + [ ("w2", ckpt_down_proj_name),
# These are the weight scales for the experts ("w3", ckpt_up_proj_name),
# (param_name, weight_name, expert_id, shard_id) ]
("experts.a13_scale"
if weight_name in gate_up else "experts.a2_scale",
f"experts.{expert_id}.{weight_name}.input_scale", expert_id,
shard_id) for expert_id in range(num_experts)
for shard_id, weight_name in enumerate(gate_down_up)
] ]
def _load_fp8_scale(self, param: torch.nn.Parameter,
loaded_weight: torch.Tensor, weight_name: str,
shard_id: str, expert_id: int) -> None:
param_data = param.data
# Input scales can be loaded directly and should be equal.
if "input_scale" in weight_name:
if param_data[expert_id] != 1 and (param_data[expert_id] -
loaded_weight).abs() > 1e-5:
raise ValueError(
"input_scales of w1 and w3 of a layer "
f"must be equal. But got {param_data[expert_id]} "
f"vs. {loaded_weight}")
param_data[expert_id] = loaded_weight
# Weight scales
elif "weight_scale" in weight_name:
# If we are in merged column case (gate_up_proj)
if shard_id in ("w1", "w3"):
# We have to keep the weight scales of w1 and w3 because
# we need to re-quantize w1/w3 weights after weight loading.
idx = 0 if shard_id == "w1" else 1
param_data[expert_id][idx] = loaded_weight
# If we are in the row parallel case (down_proj)
else:
param_data[expert_id] = loaded_weight
...@@ -131,10 +131,12 @@ class GemmaRMSNorm(CustomOp): ...@@ -131,10 +131,12 @@ class GemmaRMSNorm(CustomOp):
self.weight = nn.Parameter(torch.zeros(hidden_size)) self.weight = nn.Parameter(torch.zeros(hidden_size))
self.variance_epsilon = eps self.variance_epsilon = eps
def forward_native( @staticmethod
self, def forward_static(
weight: torch.Tensor,
variance_epsilon: float,
x: torch.Tensor, x: torch.Tensor,
residual: Optional[torch.Tensor] = None, residual: Optional[torch.Tensor],
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""PyTorch-native implementation equivalent to forward().""" """PyTorch-native implementation equivalent to forward()."""
orig_dtype = x.dtype orig_dtype = x.dtype
...@@ -144,17 +146,32 @@ class GemmaRMSNorm(CustomOp): ...@@ -144,17 +146,32 @@ class GemmaRMSNorm(CustomOp):
x = x.float() x = x.float()
variance = x.pow(2).mean(dim=-1, keepdim=True) variance = x.pow(2).mean(dim=-1, keepdim=True)
x = x * torch.rsqrt(variance + self.variance_epsilon) x = x * torch.rsqrt(variance + variance_epsilon)
# Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16) # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
# See https://github.com/huggingface/transformers/pull/29402 # See https://github.com/huggingface/transformers/pull/29402
x = x * (1.0 + self.weight.float()) x = x * (1.0 + weight.float())
x = x.to(orig_dtype) x = x.to(orig_dtype)
return x if residual is None else (x, residual) return x if residual is None else (x, residual)
def forward_native(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""PyTorch-native implementation equivalent to forward()."""
return self.forward_static(self.weight.data, self.variance_epsilon, x,
residual)
def forward_cuda( def forward_cuda(
self, self,
x: torch.Tensor, x: torch.Tensor,
residual: Optional[torch.Tensor] = None, residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
# TODO(woosuk): Implement an optimized kernel for GemmaRMSNorm. if torch.compiler.is_compiling():
return self.forward_native(x, residual)
if not getattr(self, "_is_compiled", False):
self.forward_static = torch.compile( # type: ignore
self.forward_static)
self._is_compiled = True
return self.forward_native(x, residual) return self.forward_native(x, residual)
...@@ -3,7 +3,7 @@ from typing import Dict, List, Optional, Tuple ...@@ -3,7 +3,7 @@ from typing import Dict, List, Optional, Tuple
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter, UninitializedParameter
from vllm.distributed import (divide, get_tensor_model_parallel_rank, from vllm.distributed import (divide, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
...@@ -13,6 +13,9 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank, ...@@ -13,6 +13,9 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase) QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.parameter import (BasevLLMParameter,
PackedvLLMParameter,
PerTensorScaleParameter)
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
import os import os
...@@ -20,6 +23,12 @@ from vllm.model_executor.utils import gemm_bank_conf ...@@ -20,6 +23,12 @@ from vllm.model_executor.utils import gemm_bank_conf
logger = init_logger(__name__) logger = init_logger(__name__)
WEIGHT_LOADER_V2_SUPPORTED = [
"CompressedTensorsLinearMethod", "AWQMarlinLinearMethod",
"AWQLinearMethod", "GPTQMarlinLinearMethod", "Fp8LinearMethod",
"MarlinLinearMethod"
]
def adjust_marlin_shard(param, shard_size, shard_offset): def adjust_marlin_shard(param, shard_size, shard_offset):
marlin_tile_size = getattr(param, "marlin_tile_size", None) marlin_tile_size = getattr(param, "marlin_tile_size", None)
...@@ -307,6 +316,7 @@ class ColumnParallelLinear(LinearBase): ...@@ -307,6 +316,7 @@ class ColumnParallelLinear(LinearBase):
if output_sizes is None: if output_sizes is None:
output_sizes = [output_size] output_sizes = [output_size]
self.quant_method.create_weights( self.quant_method.create_weights(
layer=self, layer=self,
input_size_per_partition=self.input_size, input_size_per_partition=self.input_size,
...@@ -314,7 +324,9 @@ class ColumnParallelLinear(LinearBase): ...@@ -314,7 +324,9 @@ class ColumnParallelLinear(LinearBase):
input_size=self.input_size, input_size=self.input_size,
output_size=self.output_size, output_size=self.output_size,
params_dtype=self.params_dtype, params_dtype=self.params_dtype,
weight_loader=self.weight_loader, weight_loader=(
self.weight_loader_v2 if self.quant_method.__class__.__name__
in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader),
prefix=prefix) prefix=prefix)
if bias: if bias:
self.bias = Parameter( self.bias = Parameter(
...@@ -330,6 +342,17 @@ class ColumnParallelLinear(LinearBase): ...@@ -330,6 +342,17 @@ class ColumnParallelLinear(LinearBase):
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
tp_rank = get_tensor_model_parallel_rank() tp_rank = get_tensor_model_parallel_rank()
output_dim = getattr(param, "output_dim", None) output_dim = getattr(param, "output_dim", None)
# Special case for GGUF
is_gguf_weight = getattr(param, "is_gguf_weight", False)
is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
if is_gguf_weight_type:
param.weight_type = loaded_weight.item()
# Materialize GGUF UninitializedParameter
if is_gguf_weight and isinstance(param, UninitializedParameter):
param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)
param_data = param.data param_data = param.data
if output_dim is not None: if output_dim is not None:
shard_size = param_data.shape[output_dim] shard_size = param_data.shape[output_dim]
...@@ -345,6 +368,14 @@ class ColumnParallelLinear(LinearBase): ...@@ -345,6 +368,14 @@ class ColumnParallelLinear(LinearBase):
assert param_data.shape == loaded_weight.shape assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight) param_data.copy_(loaded_weight)
def weight_loader_v2(self, param: Parameter, loaded_weight: torch.Tensor):
# Special case for loading scales off disk, which often do not
# have a shape (such as in the case of AutoFP8).
if len(loaded_weight.shape) == 0:
assert loaded_weight.numel() == 1
loaded_weight = loaded_weight.reshape(1)
param.load_column_parallel_weight(loaded_weight=loaded_weight)
def forward(self, input_): def forward(self, input_):
bias = self.bias if not self.skip_bias_add else None bias = self.bias if not self.skip_bias_add else None
...@@ -417,6 +448,27 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -417,6 +448,27 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
loaded_weight: torch.Tensor, loaded_weight: torch.Tensor,
loaded_shard_id: Optional[int] = None): loaded_shard_id: Optional[int] = None):
# Special case for GGUF
# initialize GGUF param after we know the quantize type
is_gguf_weight = getattr(param, "is_gguf_weight", False)
is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
if is_gguf_weight_type:
param.data[loaded_shard_id].copy_(loaded_weight)
param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
return
if is_gguf_weight and isinstance(param, UninitializedParameter):
from gguf.constants import GGML_QUANT_SIZES
ori_shape = param.tensor_shape
weight_types = self.qweight_type.shard_weight_type.values()
row_size = []
for weight_type in weight_types:
block_size, type_size = GGML_QUANT_SIZES[weight_type]
row_size.append(ori_shape[1] // block_size * type_size)
q_shape = (ori_shape[0], max(row_size))
param.materialize(q_shape, dtype=loaded_weight.dtype)
param_data = param.data param_data = param.data
output_dim = getattr(param, "output_dim", None) output_dim = getattr(param, "output_dim", None)
# Special case for AQLM codebooks. # Special case for AQLM codebooks.
...@@ -479,6 +531,18 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -479,6 +531,18 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
shard_offset = loaded_weight.shape[output_dim] * \ shard_offset = loaded_weight.shape[output_dim] * \
loaded_shard_id loaded_shard_id
if is_gguf_weight:
tp_size = get_tensor_model_parallel_world_size()
output_dim = getattr(param, "output_dim", None)
shard_shape = list(loaded_weight.shape)
shard_shape[output_dim] = shard_shape[output_dim] // tp_size
param.shard_id.append(loaded_shard_id)
param.shard_size[loaded_shard_id] = shard_shape
input_dim = getattr(param, "input_dim", None)
input_size = loaded_weight.shape[input_dim]
param_data = param_data.narrow(input_dim, 0, input_size)
param_data = param_data.narrow(output_dim, shard_offset, param_data = param_data.narrow(output_dim, shard_offset,
shard_size) shard_size)
start_idx = tp_rank * shard_size start_idx = tp_rank * shard_size
...@@ -507,6 +571,65 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -507,6 +571,65 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
assert param_data.shape == loaded_weight.shape assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight) param_data.copy_(loaded_weight)
def _load_fused_module_from_checkpoint(self, param: BasevLLMParameter,
loaded_weight: torch.Tensor):
"""
Handle special case for models where MLP layers are already
fused on disk. In this case, we have no shard id. This function
determmines the shard id by splitting these layers and then calls
the weight loader using the shard id.
An example of a model with these fused layers:
https://huggingface.co/microsoft/Phi-3-mini-4k-instruct
"""
current_shard_offset = 0
shard_offsets: List[Tuple[int, int, int]] = []
for i, output_size in enumerate(self.output_sizes):
shard_offsets.append((i, current_shard_offset, output_size))
current_shard_offset += output_size
for shard_id, shard_offset, shard_size in shard_offsets:
# Special case for Quantization.
# If quantized, we need to adjust the offset and size to account
# for the packing.
if isinstance(param, PackedvLLMParameter
) and param.packed_dim == param.output_dim:
shard_size, shard_offset = \
param.adjust_shard_indexes_for_packing(
shard_size=shard_size, shard_offset=shard_offset)
loaded_weight_shard = loaded_weight.narrow(param.output_dim,
shard_offset,
shard_size)
self.weight_loader_v2(param, loaded_weight_shard, shard_id)
def weight_loader_v2(self,
param: BasevLLMParameter,
loaded_weight: torch.Tensor,
loaded_shard_id: Optional[int] = None):
if loaded_shard_id is None:
if isinstance(param, PerTensorScaleParameter):
param.load_merged_column_weight(loaded_weight=loaded_weight,
shard_id=0)
return
elif type(param) is BasevLLMParameter:
param.load_merged_column_weight(loaded_weight=loaded_weight)
return
self._load_fused_module_from_checkpoint(param, loaded_weight)
return
assert loaded_shard_id < len(self.output_sizes)
tp_size = get_tensor_model_parallel_world_size()
shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
shard_size = self.output_sizes[loaded_shard_id] // tp_size
param.load_merged_column_weight(loaded_weight=loaded_weight,
shard_id=loaded_shard_id,
shard_offset=shard_offset,
shard_size=shard_size)
class QKVParallelLinear(ColumnParallelLinear): class QKVParallelLinear(ColumnParallelLinear):
"""Linear layers for the attention's QKV transformation. """Linear layers for the attention's QKV transformation.
...@@ -578,10 +701,112 @@ class QKVParallelLinear(ColumnParallelLinear): ...@@ -578,10 +701,112 @@ class QKVParallelLinear(ColumnParallelLinear):
quant_config=quant_config, quant_config=quant_config,
prefix=prefix) prefix=prefix)
def _get_shard_offset_mapping(self, loaded_shard_id: str):
shard_offset_mapping = {
"q": 0,
"k": self.num_heads * self.head_size,
"v": (self.num_heads + self.num_kv_heads) * self.head_size,
"total": (self.num_heads + 2 * self.num_kv_heads) * self.head_size
}
return shard_offset_mapping.get(loaded_shard_id)
def _get_shard_size_mapping(self, loaded_shard_id: str):
shard_size_mapping = {
"q": self.num_heads * self.head_size,
"k": self.num_kv_heads * self.head_size,
"v": self.num_kv_heads * self.head_size,
}
return shard_size_mapping.get(loaded_shard_id)
def _load_fused_module_from_checkpoint(self, param: BasevLLMParameter,
loaded_weight: torch.Tensor):
"""
Handle special case for models where QKV layers are already
fused on disk. In this case, we have no shard id. This function
determmines the shard id by splitting these layers and then calls
the weight loader using the shard id.
An example of a model with these fused layers:
https://huggingface.co/microsoft/Phi-3-mini-4k-instruct
"""
shard_offsets = [
# (shard_id, shard_offset, shard_size)
("q", 0, self.total_num_heads * self.head_size),
("k", self.total_num_heads * self.head_size,
self.total_num_kv_heads * self.head_size),
("v",
(self.total_num_heads + self.total_num_kv_heads) * self.head_size,
self.total_num_kv_heads * self.head_size),
]
for shard_id, shard_offset, shard_size in shard_offsets:
# Special case for Quantization.
# If quantized, we need to adjust the offset and size to account
# for the packing.
if isinstance(param, PackedvLLMParameter
) and param.packed_dim == param.output_dim:
shard_size, shard_offset = \
param.adjust_shard_indexes_for_packing(
shard_size=shard_size, shard_offset=shard_offset)
loaded_weight_shard = loaded_weight.narrow(param.output_dim,
shard_offset,
shard_size)
self.weight_loader_v2(param, loaded_weight_shard, shard_id)
def weight_loader_v2(self,
param: BasevLLMParameter,
loaded_weight: torch.Tensor,
loaded_shard_id: Optional[str] = None):
if loaded_shard_id is None: # special case for certain models
if isinstance(param, PerTensorScaleParameter):
param.load_merged_column_weight(loaded_weight=loaded_weight,
shard_id=0)
return
elif type(param) is BasevLLMParameter:
param.load_merged_column_weight(loaded_weight=loaded_weight)
return
self._load_fused_module_from_checkpoint(param, loaded_weight)
return
assert loaded_shard_id in ["q", "k", "v"]
shard_offset = self._get_shard_offset_mapping(loaded_shard_id)
shard_size = self._get_shard_size_mapping(loaded_shard_id)
param.load_qkv_weight(loaded_weight=loaded_weight,
num_heads=self.num_kv_head_replicas,
shard_id=loaded_shard_id,
shard_offset=shard_offset,
shard_size=shard_size)
def weight_loader(self, def weight_loader(self,
param: Parameter, param: Parameter,
loaded_weight: torch.Tensor, loaded_weight: torch.Tensor,
loaded_shard_id: Optional[str] = None): loaded_shard_id: Optional[str] = None):
# Special case for GGUF
# initialize GGUF param after we know the quantize type
is_gguf_weight = getattr(param, "is_gguf_weight", False)
is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
if is_gguf_weight_type and loaded_shard_id is not None:
idx_map = {"q": 0, "k": 1, "v": 2}
param.data[idx_map[loaded_shard_id]].copy_(loaded_weight)
param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
return
if is_gguf_weight and isinstance(param, UninitializedParameter):
from gguf.constants import GGML_QUANT_SIZES
ori_shape = param.tensor_shape
weight_types = self.qweight_type.shard_weight_type.values()
row_size = []
for weight_type in weight_types:
block_size, type_size = GGML_QUANT_SIZES[weight_type]
row_size.append(ori_shape[1] // block_size * type_size)
q_shape = (ori_shape[0], max(row_size))
param.materialize(q_shape, dtype=loaded_weight.dtype)
param_data = param.data param_data = param.data
output_dim = getattr(param, "output_dim", None) output_dim = getattr(param, "output_dim", None)
# Special case for AQLM codebooks. # Special case for AQLM codebooks.
...@@ -669,6 +894,18 @@ class QKVParallelLinear(ColumnParallelLinear): ...@@ -669,6 +894,18 @@ class QKVParallelLinear(ColumnParallelLinear):
shard_size, shard_offset = adjust_bitsandbytes_shard( shard_size, shard_offset = adjust_bitsandbytes_shard(
param, orig_qkv_offsets, loaded_shard_id) param, orig_qkv_offsets, loaded_shard_id)
if is_gguf_weight:
tp_size = get_tensor_model_parallel_world_size()
output_dim = getattr(param, "output_dim", None)
shard_shape = list(loaded_weight.shape)
shard_shape[output_dim] = shard_shape[output_dim] // tp_size
param.shard_id.append(loaded_shard_id)
param.shard_size[loaded_shard_id] = shard_shape
input_dim = getattr(param, "input_dim", None)
input_size = loaded_weight.shape[input_dim]
param_data = param_data.narrow(input_dim, 0, input_size)
param_data = param_data.narrow(output_dim, shard_offset, param_data = param_data.narrow(output_dim, shard_offset,
shard_size) shard_size)
if loaded_shard_id == "q": if loaded_shard_id == "q":
...@@ -748,6 +985,7 @@ class RowParallelLinear(LinearBase): ...@@ -748,6 +985,7 @@ class RowParallelLinear(LinearBase):
self.tp_size = get_tensor_model_parallel_world_size() self.tp_size = get_tensor_model_parallel_world_size()
self.input_size_per_partition = divide(input_size, self.tp_size) self.input_size_per_partition = divide(input_size, self.tp_size)
assert self.quant_method is not None assert self.quant_method is not None
self.quant_method.create_weights( self.quant_method.create_weights(
layer=self, layer=self,
input_size_per_partition=self.input_size_per_partition, input_size_per_partition=self.input_size_per_partition,
...@@ -755,7 +993,9 @@ class RowParallelLinear(LinearBase): ...@@ -755,7 +993,9 @@ class RowParallelLinear(LinearBase):
input_size=self.input_size, input_size=self.input_size,
output_size=self.output_size, output_size=self.output_size,
params_dtype=self.params_dtype, params_dtype=self.params_dtype,
weight_loader=self.weight_loader, weight_loader=(
self.weight_loader_v2 if self.quant_method.__class__.__name__
in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader),
prefix=prefix) prefix=prefix)
if not reduce_results and (bias and not skip_bias_add): if not reduce_results and (bias and not skip_bias_add):
raise ValueError("When not reduce the results, adding bias to the " raise ValueError("When not reduce the results, adding bias to the "
...@@ -773,7 +1013,22 @@ class RowParallelLinear(LinearBase): ...@@ -773,7 +1013,22 @@ class RowParallelLinear(LinearBase):
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
tp_rank = get_tensor_model_parallel_rank() tp_rank = get_tensor_model_parallel_rank()
tp_size = get_tensor_model_parallel_world_size()
input_dim = getattr(param, "input_dim", None) input_dim = getattr(param, "input_dim", None)
# Special case for GGUF
is_gguf_weight = getattr(param, "is_gguf_weight", False)
is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
if is_gguf_weight_type:
param.weight_type = loaded_weight.item()
# Materialize GGUF UninitializedParameter
if is_gguf_weight and isinstance(param, UninitializedParameter):
weight_shape = list(loaded_weight.shape)
if input_dim:
weight_shape[input_dim] = weight_shape[input_dim] // tp_size
param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype)
param_data = param.data param_data = param.data
if input_dim is not None: if input_dim is not None:
shard_size = param_data.shape[input_dim] shard_size = param_data.shape[input_dim]
...@@ -789,6 +1044,17 @@ class RowParallelLinear(LinearBase): ...@@ -789,6 +1044,17 @@ class RowParallelLinear(LinearBase):
assert param_data.shape == loaded_weight.shape assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight) param_data.copy_(loaded_weight)
def weight_loader_v2(self, param: BasevLLMParameter,
loaded_weight: torch.Tensor):
# Special case for loading scales off disk, which often do not
# have a shape (such as in the case of AutoFP8).
if len(loaded_weight.shape) == 0:
assert loaded_weight.numel() == 1
loaded_weight = loaded_weight.reshape(1)
param.load_row_parallel_weight(loaded_weight=loaded_weight)
def forward(self, input_): def forward(self, input_):
if self.input_is_parallel: if self.input_is_parallel:
input_parallel = input_ input_parallel = input_
......
...@@ -50,7 +50,7 @@ class LogitsProcessor(nn.Module): ...@@ -50,7 +50,7 @@ class LogitsProcessor(nn.Module):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
embedding_bias: Optional[torch.Tensor] = None, embedding_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> Optional[torch.Tensor]:
if self.logits_as_input: if self.logits_as_input:
logits = hidden_states logits = hidden_states
else: else:
...@@ -73,14 +73,18 @@ class LogitsProcessor(nn.Module): ...@@ -73,14 +73,18 @@ class LogitsProcessor(nn.Module):
return logits return logits
def _get_logits(self, hidden_states: torch.Tensor, def _get_logits(
lm_head: VocabParallelEmbedding, self,
embedding_bias: Optional[torch.Tensor]) -> torch.Tensor: hidden_states: torch.Tensor,
lm_head: VocabParallelEmbedding,
embedding_bias: Optional[torch.Tensor],
) -> Optional[torch.Tensor]:
# Get the logits for the next tokens. # Get the logits for the next tokens.
logits = lm_head.linear_method.apply(lm_head, logits = lm_head.linear_method.apply(lm_head,
hidden_states, hidden_states,
bias=embedding_bias) bias=embedding_bias)
if self.use_gather: if self.use_gather:
# None may be returned for rank > 0
logits = tensor_model_parallel_gather(logits) logits = tensor_model_parallel_gather(logits)
else: else:
# Gather is not supported for some devices such as TPUs. # Gather is not supported for some devices such as TPUs.
...@@ -91,7 +95,7 @@ class LogitsProcessor(nn.Module): ...@@ -91,7 +95,7 @@ class LogitsProcessor(nn.Module):
logits = tensor_model_parallel_all_gather(logits) logits = tensor_model_parallel_all_gather(logits)
# Remove paddings in vocab (if any). # Remove paddings in vocab (if any).
if logits is not None: if logits is not None:
logits = logits[:, :self.org_vocab_size] logits = logits[..., :self.org_vocab_size]
return logits return logits
def extra_repr(self) -> str: def extra_repr(self) -> str:
......
...@@ -11,8 +11,11 @@ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tenso ...@@ -11,8 +11,11 @@ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tenso
CompressedTensorsConfig) CompressedTensorsConfig)
from vllm.model_executor.layers.quantization.deepspeedfp import ( from vllm.model_executor.layers.quantization.deepspeedfp import (
DeepSpeedFPConfig) DeepSpeedFPConfig)
from vllm.model_executor.layers.quantization.experts_int8 import (
ExpertsInt8Config)
from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config
from vllm.model_executor.layers.quantization.fp8 import Fp8Config from vllm.model_executor.layers.quantization.fp8 import Fp8Config
from vllm.model_executor.layers.quantization.gguf import GGUFConfig
from vllm.model_executor.layers.quantization.gptq import GPTQConfig from vllm.model_executor.layers.quantization.gptq import GPTQConfig
from vllm.model_executor.layers.quantization.gptq_marlin import ( from vllm.model_executor.layers.quantization.gptq_marlin import (
GPTQMarlinConfig) GPTQMarlinConfig)
...@@ -21,16 +24,19 @@ from vllm.model_executor.layers.quantization.gptq_marlin_24 import ( ...@@ -21,16 +24,19 @@ from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
from vllm.model_executor.layers.quantization.marlin import MarlinConfig from vllm.model_executor.layers.quantization.marlin import MarlinConfig
from vllm.model_executor.layers.quantization.qqq import QQQConfig from vllm.model_executor.layers.quantization.qqq import QQQConfig
from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig
from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig
QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = { QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
"aqlm": AQLMConfig, "aqlm": AQLMConfig,
"awq": AWQConfig, "awq": AWQConfig,
"deepspeedfp": DeepSpeedFPConfig, "deepspeedfp": DeepSpeedFPConfig,
"tpu_int8": Int8TpuConfig,
"fp8": Fp8Config, "fp8": Fp8Config,
"fbgemm_fp8": FBGEMMFp8Config, "fbgemm_fp8": FBGEMMFp8Config,
# The order of gptq methods is important for config.py iteration over # The order of gptq methods is important for config.py iteration over
# override_quantization_method(..) # override_quantization_method(..)
"marlin": MarlinConfig, "marlin": MarlinConfig,
"gguf": GGUFConfig,
"gptq_marlin_24": GPTQMarlin24Config, "gptq_marlin_24": GPTQMarlin24Config,
"gptq_marlin": GPTQMarlinConfig, "gptq_marlin": GPTQMarlinConfig,
"awq_marlin": AWQMarlinConfig, "awq_marlin": AWQMarlinConfig,
...@@ -39,6 +45,7 @@ QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = { ...@@ -39,6 +45,7 @@ QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
"compressed-tensors": CompressedTensorsConfig, "compressed-tensors": CompressedTensorsConfig,
"bitsandbytes": BitsAndBytesConfig, "bitsandbytes": BitsAndBytesConfig,
"qqq": QQQConfig, "qqq": QQQConfig,
"experts_int8": ExpertsInt8Config,
} }
......
...@@ -95,7 +95,7 @@ def generic_dequantize_gemm( ...@@ -95,7 +95,7 @@ def generic_dequantize_gemm(
codebooks: torch. codebooks: torch.
Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size] Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size]
scales: torch.Tensor, # [num_out_groups, 1, 1, 1] scales: torch.Tensor, # [num_out_groups, 1, 1, 1]
output_partition_sizes: torch.IntTensor, output_partition_sizes: List[int],
bias: Optional[torch.Tensor], bias: Optional[torch.Tensor],
) -> torch.Tensor: ) -> torch.Tensor:
output_shape = input.shape[:-1] + (scales.shape[0], ) output_shape = input.shape[:-1] + (scales.shape[0], )
...@@ -133,7 +133,7 @@ def optimized_dequantize_gemm( ...@@ -133,7 +133,7 @@ def optimized_dequantize_gemm(
codebooks: torch. codebooks: torch.
Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size] Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size]
scales: torch.Tensor, # [num_out_groups, 1, 1, 1] scales: torch.Tensor, # [num_out_groups, 1, 1, 1]
output_partition_sizes: torch.IntTensor, output_partition_sizes: List[int],
bias: Optional[torch.Tensor], bias: Optional[torch.Tensor],
) -> torch.Tensor: ) -> torch.Tensor:
weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes) weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes)
...@@ -288,10 +288,8 @@ class AQLMLinearMethod(LinearMethodBase): ...@@ -288,10 +288,8 @@ class AQLMLinearMethod(LinearMethodBase):
codebooks, codebooks,
{ {
# metadata indicates fixed size concatenated along dim 0 # metadata indicates fixed size concatenated along dim 0
"is_metadata": "is_metadata": True,
True, "output_partition_sizes": output_partition_sizes
"output_partition_sizes":
torch.tensor(output_partition_sizes, device='cpu'),
}, },
) )
...@@ -334,7 +332,7 @@ class AQLMLinearMethod(LinearMethodBase): ...@@ -334,7 +332,7 @@ class AQLMLinearMethod(LinearMethodBase):
codes = layer.codes codes = layer.codes
scales = layer.scales scales = layer.scales
output_partition_sizes = getattr(codebooks, "output_partition_sizes", output_partition_sizes = getattr(codebooks, "output_partition_sizes",
None) [])
nbooks = codes.shape[2] nbooks = codes.shape[2]
ingroups = codebooks.shape[3] ingroups = codebooks.shape[3]
......
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
import torch import torch
from torch.nn.parameter import Parameter
import torch.nn.functional as F import torch.nn.functional as F
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.parameter import (GroupQuantScaleParameter,
PackedvLLMParameter)
class AWQShareWorkSpace: class AWQShareWorkSpace:
...@@ -117,70 +117,64 @@ class AWQLinearMethod(LinearMethodBase): ...@@ -117,70 +117,64 @@ class AWQLinearMethod(LinearMethodBase):
"weight shape. This can be caused by too large " "weight shape. This can be caused by too large "
"tensor parallel size.") "tensor parallel size.")
qweight = Parameter( weight_loader = extra_weight_attrs.get("weight_loader")
torch.empty( qweight = PackedvLLMParameter(
data=torch.empty(
input_size_per_partition, input_size_per_partition,
output_size_per_partition // self.quant_config.pack_factor, output_size_per_partition // self.quant_config.pack_factor,
dtype=torch.int32, dtype=torch.int32,
), ),
requires_grad=False, input_dim=0,
) output_dim=1,
set_weight_attrs( packed_dim=1,
qweight, { packed_factor=self.quant_config.pack_factor,
"input_dim": 0, weight_loader=weight_loader)
"output_dim": 1,
"packed_dim": 1, qzeros = PackedvLLMParameter(
"pack_factor": self.quant_config.pack_factor, data=torch.empty(
})
qzeros = Parameter(
torch.empty(
input_size_per_partition // self.quant_config.group_size, input_size_per_partition // self.quant_config.group_size,
output_size_per_partition // self.quant_config.pack_factor, output_size_per_partition // self.quant_config.pack_factor,
dtype=torch.int32, dtype=torch.int32,
), ),
requires_grad=False, input_dim=0,
) output_dim=1,
set_weight_attrs( packed_dim=1,
qzeros, { packed_factor=self.quant_config.pack_factor,
"input_dim": 0, weight_loader=weight_loader)
"output_dim": 1,
"packed_dim": 1, scales = GroupQuantScaleParameter(data=torch.empty(
"pack_factor": self.quant_config.pack_factor, input_size_per_partition // self.quant_config.group_size,
}) output_size_per_partition,
scales = Parameter( dtype=params_dtype,
torch.empty( ),
input_size_per_partition // self.quant_config.group_size, input_dim=0,
output_size_per_partition, output_dim=1,
dtype=params_dtype, weight_loader=weight_loader)
),
requires_grad=False,
)
set_weight_attrs(scales, {
"input_dim": 0,
"output_dim": 1,
})
zeros_and_scales=Parameter( zeros_and_scales = GroupQuantScaleParameter(data=torch.empty(
torch.empty( input_size_per_partition // self.quant_config.group_size,
(input_size_per_partition // self.quant_config.group_size), output_size_per_partition,
output_size_per_partition, dtype=params_dtype,
dtype=torch.int32, ),
), input_dim=0,
requires_grad=False, output_dim=1,
) weight_loader=weight_loader)
set_weight_attrs(zeros_and_scales, {
"input_dim": 0,
"output_dim": 1,
})
layer.register_parameter("qweight", qweight) layer.register_parameter("qweight", qweight)
set_weight_attrs(qweight, extra_weight_attrs)
layer.register_parameter("qzeros", qzeros) layer.register_parameter("qzeros", qzeros)
set_weight_attrs(qzeros, extra_weight_attrs)
layer.register_parameter("scales", scales) layer.register_parameter("scales", scales)
set_weight_attrs(scales, extra_weight_attrs)
layer.register_parameter("zeros_and_scales", zeros_and_scales) layer.register_parameter("zeros_and_scales", zeros_and_scales)
set_weight_attrs(zeros_and_scales, extra_weight_attrs)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
layer.qweight = torch.nn.Parameter(layer.qweight.data,
requires_grad=False)
layer.qzeros = torch.nn.Parameter(layer.qzeros.data,
requires_grad=False)
layer.scales = torch.nn.Parameter(layer.scales.data,
requires_grad=False)
layer.zeros_and_scales = torch.nn.Parameter(layer.zeros_and_scales.data,
requires_grad=False)
def apply(self, def apply(self,
layer: torch.nn.Module, layer: torch.nn.Module,
......
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
import torch import torch
from torch.nn.parameter import Parameter
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.layers.quantization.utils.marlin_utils import ( from vllm.model_executor.layers.quantization.utils.marlin_utils import (
...@@ -14,6 +12,8 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import ( ...@@ -14,6 +12,8 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
marlin_make_empty_g_idx, marlin_make_workspace, marlin_permute_scales, marlin_make_empty_g_idx, marlin_make_workspace, marlin_permute_scales,
replace_tensor, verify_marlin_supported, verify_marlin_supports_shape) replace_tensor, verify_marlin_supported, verify_marlin_supports_shape)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.parameter import (GroupQuantScaleParameter,
PackedvLLMParameter)
from vllm.scalar_type import scalar_types from vllm.scalar_type import scalar_types
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -126,8 +126,7 @@ class AWQMarlinConfig(QuantizationConfig): ...@@ -126,8 +126,7 @@ class AWQMarlinConfig(QuantizationConfig):
return check_marlin_supported(quant_type=cls.TYPE_MAP[num_bits], return check_marlin_supported(quant_type=cls.TYPE_MAP[num_bits],
group_size=group_size, group_size=group_size,
has_zp=has_zp, has_zp=has_zp)
min_capability=cls.get_min_capability())
class AWQMarlinLinearMethod(LinearMethodBase): class AWQMarlinLinearMethod(LinearMethodBase):
...@@ -152,6 +151,7 @@ class AWQMarlinLinearMethod(LinearMethodBase): ...@@ -152,6 +151,7 @@ class AWQMarlinLinearMethod(LinearMethodBase):
) -> None: ) -> None:
del output_size del output_size
output_size_per_partition = sum(output_partition_sizes) output_size_per_partition = sum(output_partition_sizes)
weight_loader = extra_weight_attrs.get("weight_loader")
# Normalize group_size # Normalize group_size
if self.quant_config.group_size != -1: if self.quant_config.group_size != -1:
...@@ -165,59 +165,44 @@ class AWQMarlinLinearMethod(LinearMethodBase): ...@@ -165,59 +165,44 @@ class AWQMarlinLinearMethod(LinearMethodBase):
input_size=input_size, input_size=input_size,
group_size=group_size) group_size=group_size)
qweight = Parameter( qweight = PackedvLLMParameter(
torch.empty( data=torch.empty(
input_size_per_partition, input_size_per_partition,
output_size_per_partition // self.quant_config.pack_factor, output_size_per_partition // self.quant_config.pack_factor,
dtype=torch.int32, dtype=torch.int32,
), ),
requires_grad=False, input_dim=0,
) output_dim=1,
set_weight_attrs( packed_dim=1,
qweight, { packed_factor=self.quant_config.pack_factor,
"input_dim": 0, weight_loader=weight_loader)
"output_dim": 1,
"packed_dim": 1,
"pack_factor": self.quant_config.pack_factor,
})
num_groups = input_size_per_partition // group_size num_groups = input_size_per_partition // group_size
qzeros = Parameter( qzeros = PackedvLLMParameter(
torch.empty( data=torch.empty(
num_groups, num_groups,
output_size_per_partition // self.quant_config.pack_factor, output_size_per_partition // self.quant_config.pack_factor,
dtype=torch.int32, dtype=torch.int32,
), ),
requires_grad=False, input_dim=0,
) output_dim=1,
set_weight_attrs( packed_dim=1,
qzeros, { packed_factor=self.quant_config.pack_factor,
"input_dim": 0, weight_loader=weight_loader)
"output_dim": 1,
"packed_dim": 1, scales = GroupQuantScaleParameter(data=torch.empty(
"pack_factor": self.quant_config.pack_factor, num_groups,
}) output_size_per_partition,
dtype=params_dtype,
scales = Parameter( ),
torch.empty( input_dim=0,
num_groups, output_dim=1,
output_size_per_partition, weight_loader=weight_loader)
dtype=params_dtype,
),
requires_grad=False,
)
set_weight_attrs(scales, {
"input_dim": 0,
"output_dim": 1,
})
layer.register_parameter("qweight", qweight) layer.register_parameter("qweight", qweight)
set_weight_attrs(qweight, extra_weight_attrs)
layer.register_parameter("qzeros", qzeros) layer.register_parameter("qzeros", qzeros)
set_weight_attrs(qzeros, extra_weight_attrs)
layer.register_parameter("scales", scales) layer.register_parameter("scales", scales)
set_weight_attrs(scales, extra_weight_attrs)
layer.input_size_per_partition = input_size_per_partition layer.input_size_per_partition = input_size_per_partition
layer.output_size_per_partition = output_size_per_partition layer.output_size_per_partition = output_size_per_partition
...@@ -229,6 +214,12 @@ class AWQMarlinLinearMethod(LinearMethodBase): ...@@ -229,6 +214,12 @@ class AWQMarlinLinearMethod(LinearMethodBase):
# Here, we handle the repacking # Here, we handle the repacking
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
device = layer.qweight.device device = layer.qweight.device
layer.qweight = torch.nn.Parameter(layer.qweight.data,
requires_grad=False)
layer.qzeros = torch.nn.Parameter(layer.qzeros.data,
requires_grad=False)
layer.scales = torch.nn.Parameter(layer.scales.data,
requires_grad=False)
# Allocate marlin workspace # Allocate marlin workspace
layer.workspace = marlin_make_workspace( layer.workspace = marlin_make_workspace(
...@@ -279,4 +270,4 @@ class AWQMarlinLinearMethod(LinearMethodBase): ...@@ -279,4 +270,4 @@ class AWQMarlinLinearMethod(LinearMethodBase):
quant_type=self.quant_config.quant_type, quant_type=self.quant_config.quant_type,
output_size_per_partition=layer.output_size_per_partition, output_size_per_partition=layer.output_size_per_partition,
input_size_per_partition=layer.input_size_per_partition, input_size_per_partition=layer.input_size_per_partition,
bias=bias) bias=bias)
\ No newline at end of file
import inspect
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional, Type
import torch import torch
from torch import nn from torch import nn
...@@ -23,6 +24,14 @@ class QuantizeMethodBase(ABC): ...@@ -23,6 +24,14 @@ class QuantizeMethodBase(ABC):
Expects create_weights to have been called before on the layer.""" Expects create_weights to have been called before on the layer."""
raise NotImplementedError raise NotImplementedError
# Not required functions
def embedding(self, layer: torch.nn.Module, *args,
**kwargs) -> torch.Tensor:
"""Gather embeddings in the layer based on indices in the input tensor.
Expects create_weights to have been called before on the layer."""
raise NotImplementedError
def process_weights_after_loading(self, layer: nn.Module) -> None: def process_weights_after_loading(self, layer: nn.Module) -> None:
"""Process the weight after loading. """Process the weight after loading.
...@@ -31,6 +40,21 @@ class QuantizeMethodBase(ABC): ...@@ -31,6 +40,21 @@ class QuantizeMethodBase(ABC):
return return
def method_has_implemented_embedding(
method_class: Type[QuantizeMethodBase]) -> bool:
"""
Not all quant methods have embedding implemented, so we need to check that
it exists for our given method. We check this by making sure the function
has been changed from the base implementation.
"""
base_embedding = inspect.getattr_static(QuantizeMethodBase, "embedding",
None)
class_embedding = inspect.getattr_static(method_class, "embedding", None)
return (class_embedding is not None
and class_embedding is not base_embedding)
class QuantizationConfig(ABC): class QuantizationConfig(ABC):
"""Base class for quantization configs.""" """Base class for quantization configs."""
......
...@@ -19,6 +19,8 @@ from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( ...@@ -19,6 +19,8 @@ from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.platforms import current_platform from vllm.platforms import current_platform
__all__ = ["CompressedTensorsLinearMethod"]
class CompressedTensorsConfig(QuantizationConfig): class CompressedTensorsConfig(QuantizationConfig):
...@@ -146,18 +148,15 @@ class CompressedTensorsConfig(QuantizationConfig): ...@@ -146,18 +148,15 @@ class CompressedTensorsConfig(QuantizationConfig):
if weight_quant is None or input_quant is None: if weight_quant is None or input_quant is None:
return False return False
# Confirm we have floating points.
if not (weight_quant.type == QuantizationType.FLOAT
and input_quant.type == QuantizationType.FLOAT):
return False
# Confirm weight scheme is supported. # Confirm weight scheme is supported.
is_floating_point = (weight_quant.type == QuantizationType.FLOAT
and input_quant.type == QuantizationType.FLOAT)
is_symmetric_weight = weight_quant.symmetric is_symmetric_weight = weight_quant.symmetric
is_static_weight = not weight_quant.dynamic is_static_weight = not weight_quant.dynamic
is_per_tensor_or_channel_weight = (weight_quant.strategy in [ is_per_tensor_or_channel_weight = (weight_quant.strategy in [
QuantizationStrategy.TENSOR, QuantizationStrategy.CHANNEL QuantizationStrategy.TENSOR, QuantizationStrategy.CHANNEL
]) ])
if not (is_symmetric_weight and is_static_weight if not (is_floating_point and is_symmetric_weight and is_static_weight
and is_per_tensor_or_channel_weight): and is_per_tensor_or_channel_weight):
return False return False
...@@ -169,11 +168,7 @@ class CompressedTensorsConfig(QuantizationConfig): ...@@ -169,11 +168,7 @@ class CompressedTensorsConfig(QuantizationConfig):
is_symmetric_activation = input_quant.symmetric is_symmetric_activation = input_quant.symmetric
is_per_tensor_activation = ( is_per_tensor_activation = (
input_quant.strategy == QuantizationStrategy.TENSOR) input_quant.strategy == QuantizationStrategy.TENSOR)
if not (is_symmetric_activation and is_per_tensor_activation): return is_symmetric_activation and is_per_tensor_activation
return False
# All conditions satisfied.
return True
def _is_fp8_w8a16(self, weight_quant: BaseModel, def _is_fp8_w8a16(self, weight_quant: BaseModel,
input_quant: BaseModel) -> bool: input_quant: BaseModel) -> bool:
...@@ -230,6 +225,7 @@ class CompressedTensorsConfig(QuantizationConfig): ...@@ -230,6 +225,7 @@ class CompressedTensorsConfig(QuantizationConfig):
group_size=weight_quant.group_size) group_size=weight_quant.group_size)
# Detect If Activation Quantization. # Detect If Activation Quantization.
# TODO @dsikka: clean-up conditions
if is_activation_quantization_format(self.quant_format): if is_activation_quantization_format(self.quant_format):
if self._is_fp8_w8a8(weight_quant, input_quant): if self._is_fp8_w8a8(weight_quant, input_quant):
is_fp8_w8a8_supported = self._check_scheme_supported( is_fp8_w8a8_supported = self._check_scheme_supported(
...@@ -237,7 +233,8 @@ class CompressedTensorsConfig(QuantizationConfig): ...@@ -237,7 +233,8 @@ class CompressedTensorsConfig(QuantizationConfig):
if is_fp8_w8a8_supported: if is_fp8_w8a8_supported:
return CompressedTensorsW8A8Fp8( return CompressedTensorsW8A8Fp8(
strategy=weight_quant.strategy, strategy=weight_quant.strategy,
is_static_input_scheme=(not input_quant.dynamic)) is_static_input_scheme=(input_quant
and not input_quant.dynamic))
else: else:
return CompressedTensorsW8A16Fp8( return CompressedTensorsW8A16Fp8(
strategy=weight_quant.strategy, strategy=weight_quant.strategy,
......
...@@ -2,11 +2,10 @@ from typing import Callable, List, Optional ...@@ -2,11 +2,10 @@ from typing import Callable, List, Optional
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from torch.nn import Parameter
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme) CompressedTensorsScheme)
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.parameter import ModelWeightParameter
__all__ = ["CompressedTensorsUnquantized"] __all__ = ["CompressedTensorsUnquantized"]
...@@ -24,7 +23,9 @@ class CompressedTensorsUnquantized(CompressedTensorsScheme): ...@@ -24,7 +23,9 @@ class CompressedTensorsUnquantized(CompressedTensorsScheme):
return 70 return 70
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
pass # required by torch.compile to be torch.nn.Parameter
layer.weight = torch.nn.Parameter(layer.weight.data,
requires_grad=False)
def create_weights(self, layer: torch.nn.Module, def create_weights(self, layer: torch.nn.Module,
output_partition_sizes: List[int], output_partition_sizes: List[int],
...@@ -32,14 +33,15 @@ class CompressedTensorsUnquantized(CompressedTensorsScheme): ...@@ -32,14 +33,15 @@ class CompressedTensorsUnquantized(CompressedTensorsScheme):
params_dtype: torch.dtype, weight_loader: Callable, params_dtype: torch.dtype, weight_loader: Callable,
**kwargs): **kwargs):
weight = Parameter(torch.empty(sum(output_partition_sizes), weight = ModelWeightParameter(data=torch.empty(
input_size_per_partition, sum(output_partition_sizes),
dtype=params_dtype), input_size_per_partition,
requires_grad=False) dtype=params_dtype),
input_dim=1,
output_dim=0,
weight_loader=weight_loader)
set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
layer.register_parameter("weight", weight) layer.register_parameter("weight", weight)
set_weight_attrs(weight, {"weight_loader": weight_loader})
def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor, def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor: bias: Optional[torch.Tensor]) -> torch.Tensor:
......
...@@ -8,7 +8,10 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( ...@@ -8,7 +8,10 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme) CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.gptq_marlin_24 import ( from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N) GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N)
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.parameter import (BasevLLMParameter,
ChannelQuantScaleParameter,
GroupQuantScaleParameter,
PackedvLLMParameter)
from vllm.scalar_type import scalar_types from vllm.scalar_type import scalar_types
__all__ = ["CompressedTensorsW4A16Sparse24"] __all__ = ["CompressedTensorsW4A16Sparse24"]
...@@ -45,7 +48,12 @@ class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme): ...@@ -45,7 +48,12 @@ class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme):
return 80 return 80
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
pass # required by torch.compile to be torch.nn.Parameter
layer.weight_packed = Parameter(layer.weight_packed.data,
requires_grad=False)
layer.scale_packed = Parameter(layer.scale_packed.data,
requires_grad=False)
layer.meta = Parameter(layer.meta.data, requires_grad=False)
def create_weights(self, layer: torch.nn.Module, input_size: int, def create_weights(self, layer: torch.nn.Module, input_size: int,
output_partition_sizes: List[int], output_partition_sizes: List[int],
...@@ -56,79 +64,65 @@ class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme): ...@@ -56,79 +64,65 @@ class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme):
pack_factor = 32 // self.quant_type.size_bits pack_factor = 32 // self.quant_type.size_bits
output_size_per_partition = sum(output_partition_sizes) output_size_per_partition = sum(output_partition_sizes)
qweight = Parameter( qweight = PackedvLLMParameter(data=torch.empty(
torch.empty( input_size_per_partition // self.tile_size // 2,
input_size_per_partition // self.tile_size // 2, output_size_per_partition * self.tile_size // pack_factor,
output_size_per_partition * self.tile_size // pack_factor, dtype=torch.int32,
dtype=torch.int32, ),
), input_dim=0,
requires_grad=False, output_dim=1,
) packed_dim=1,
set_weight_attrs( packed_factor=pack_factor,
qweight, marlin_tile_size=self.tile_size,
{ weight_loader=weight_loader)
"input_dim": 0,
"output_dim": 1,
"packed_dim": 1,
"pack_factor": pack_factor,
"marlin_tile_size": self.tile_size,
"weight_loader": weight_loader
},
)
layer.register_parameter("weight_packed", qweight)
input_groups = (1 if self.group_size is None else input_groups = (1 if self.group_size is None else
input_size_per_partition // self.group_size) input_size_per_partition // self.group_size)
scales = Parameter( weight_scale_args = {
"data":
torch.empty( torch.empty(
input_groups, input_groups,
output_size_per_partition, output_size_per_partition,
dtype=params_dtype, dtype=params_dtype,
), ),
requires_grad=False, "weight_loader":
) weight_loader
set_weight_attrs( }
scales,
{ if self.group_size is not None:
"output_dim": 1, scales = GroupQuantScaleParameter(output_dim=1,
"input_dim": None if input_groups == 1 else 0, input_dim=0,
"weight_loader": weight_loader **weight_scale_args)
}, else:
) scales = ChannelQuantScaleParameter(output_dim=1,
layer.register_parameter("scale_packed", scales) **weight_scale_args)
weight_shape = Parameter(torch.empty(2, dtype=torch.int64), weight_shape = BasevLLMParameter(data=torch.empty(2,
requires_grad=False) dtype=torch.int64),
weight_loader=weight_loader)
meta = PackedvLLMParameter(data=torch.empty(
input_size_per_partition // 8 // 2 // 2,
output_size_per_partition * 2,
dtype=torch.int16,
),
input_dim=0,
output_dim=1,
packed_dim=1,
packed_factor=1,
marlin_tile_size=2,
weight_loader=weight_loader)
layer.register_parameter("weight_packed", qweight)
layer.register_parameter("weight_shape", weight_shape) layer.register_parameter("weight_shape", weight_shape)
set_weight_attrs(weight_shape, {"weight_loader": weight_loader}) layer.register_parameter("scale_packed", scales)
meta = Parameter(
torch.empty(
input_size_per_partition // 8 // 2 // 2,
output_size_per_partition * 2,
dtype=torch.int16,
),
requires_grad=False,
)
set_weight_attrs(
meta,
{
"input_dim": 0,
"packed_dim": 1,
"pack_factor": 1,
"output_dim": 1,
"marlin_tile_size": 2,
"weight_loader": weight_loader
},
)
layer.register_parameter("meta", meta) layer.register_parameter("meta", meta)
max_workspace_size = ( max_workspace_size = (
output_size_per_partition // output_size_per_partition //
GPTQ_MARLIN_24_MIN_THREAD_N) * GPTQ_MARLIN_24_MAX_PARALLEL GPTQ_MARLIN_24_MIN_THREAD_N) * GPTQ_MARLIN_24_MAX_PARALLEL
workspace = Parameter(torch.zeros(max_workspace_size, dtype=torch.int), workspace = Parameter(torch.zeros(max_workspace_size, dtype=torch.int),
requires_grad=False) requires_grad=False)
layer.workspace = workspace layer.workspace = workspace
......
...@@ -9,9 +9,10 @@ from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( ...@@ -9,9 +9,10 @@ from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin) apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
convert_to_channelwise, create_per_channel_scale_param, convert_to_channelwise)
create_per_tensor_scale_param) from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
from vllm.model_executor.utils import set_weight_attrs ModelWeightParameter,
PerTensorScaleParameter)
__all__ = ["CompressedTensorsW8A16Fp8"] __all__ = ["CompressedTensorsW8A16Fp8"]
...@@ -40,11 +41,19 @@ class CompressedTensorsW8A16Fp8(CompressedTensorsScheme): ...@@ -40,11 +41,19 @@ class CompressedTensorsW8A16Fp8(CompressedTensorsScheme):
layer.logical_widths) layer.logical_widths)
layer.weight_scale = torch.nn.Parameter(ws_channelwise, layer.weight_scale = torch.nn.Parameter(ws_channelwise,
requires_grad=False) requires_grad=False)
else:
# required by torch.compile to be torch.nn.Parameter
layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data,
requires_grad=False)
# Weights must be transposed for marlin # Weights must be transposed for marlin
layer.weight = torch.nn.Parameter(layer.weight.t(), layer.weight = torch.nn.Parameter(layer.weight.t(),
requires_grad=False) requires_grad=False)
if self.is_static_input_scheme:
# required by torch.compile to be torch.nn.Parameter
layer.input_scale = torch.nn.Parameter(layer.input_scale.data,
requires_grad=False)
prepare_fp8_layer_for_marlin(layer, strategy="channel") prepare_fp8_layer_for_marlin(layer, strategy="channel")
def create_weights(self, layer: torch.nn.Module, input_size: int, def create_weights(self, layer: torch.nn.Module, input_size: int,
...@@ -60,35 +69,39 @@ class CompressedTensorsW8A16Fp8(CompressedTensorsScheme): ...@@ -60,35 +69,39 @@ class CompressedTensorsW8A16Fp8(CompressedTensorsScheme):
layer.orig_dtype = params_dtype layer.orig_dtype = params_dtype
# WEIGHT # WEIGHT
weight = torch.nn.Parameter(torch.empty(output_size_per_partition, weight = ModelWeightParameter(data=torch.empty(
input_size_per_partition, output_size_per_partition,
dtype=torch.float8_e4m3fn), input_size_per_partition,
requires_grad=False) dtype=torch.float8_e4m3fn),
input_dim=1,
output_dim=0,
weight_loader=weight_loader)
layer.register_parameter("weight", weight) layer.register_parameter("weight", weight)
set_weight_attrs(weight, {
"input_dim": 1,
"output_dim": 0,
"weight_loader": weight_loader,
})
# WEIGHT SCALE # WEIGHT SCALE
layer_kwargs = {"weight_loader": weight_loader}
if self.strategy == QuantizationStrategy.CHANNEL: if self.strategy == QuantizationStrategy.CHANNEL:
weight_scale = create_per_channel_scale_param( weight_scale = ChannelQuantScaleParameter(
output_partition_sizes, **layer_kwargs) data=torch.empty((sum(output_partition_sizes), 1),
dtype=torch.float32),
output_dim=0,
weight_loader=weight_loader)
elif self.strategy == QuantizationStrategy.TENSOR: elif self.strategy == QuantizationStrategy.TENSOR:
weight_scale = create_per_tensor_scale_param( weight_scale = PerTensorScaleParameter(data=torch.empty(
output_partition_sizes, **layer_kwargs) len(output_partition_sizes), dtype=torch.float32),
weight_loader=weight_loader)
else: else:
raise ValueError( raise ValueError(
f"Unsupported weight strategy={self.strategy}, " f"Unsupported weight strategy={self.strategy}, "
f"supported strategies are {SUPPORTED_STRATEGIES}") f"supported strategies are {SUPPORTED_STRATEGIES}")
weight_scale[:] = torch.finfo(torch.float32).min
layer.register_parameter("weight_scale", weight_scale) layer.register_parameter("weight_scale", weight_scale)
# INPUT SCALE (to deal with converted checkpoints) # INPUT SCALE (to deal with converted checkpoints)
if self.is_static_input_scheme: if self.is_static_input_scheme:
input_scale = create_per_tensor_scale_param( input_scale = PerTensorScaleParameter(data=torch.empty(
output_partition_sizes, **layer_kwargs) len(output_partition_sizes), dtype=torch.float32),
weight_loader=weight_loader)
layer.register_parameter("input_scale", input_scale) layer.register_parameter("input_scale", input_scale)
def apply_weights(self, def apply_weights(self,
......
...@@ -8,10 +8,10 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( ...@@ -8,10 +8,10 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
QuantizationStrategy) QuantizationStrategy)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
apply_fp8_linear, create_per_channel_scale_param, apply_fp8_linear, cutlass_fp8_supported, requantize_with_max_scale)
create_per_tensor_scale_param, cutlass_fp8_supported, from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
requantize_with_max_scale) ModelWeightParameter,
from vllm.model_executor.utils import set_weight_attrs PerTensorScaleParameter)
__all__ = ["CompressedTensorsW8A8Fp8"] __all__ = ["CompressedTensorsW8A8Fp8"]
...@@ -46,6 +46,9 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): ...@@ -46,6 +46,9 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
elif self.strategy == QuantizationStrategy.CHANNEL: elif self.strategy == QuantizationStrategy.CHANNEL:
weight = layer.weight weight = layer.weight
layer.weight = Parameter(weight.t(), requires_grad=False) layer.weight = Parameter(weight.t(), requires_grad=False)
# required by torch.compile to be torch.nn.Parameter
layer.weight_scale = Parameter(layer.weight_scale.data,
requires_grad=False)
else: else:
raise ValueError(f"Unknown quantization strategy {self.strategy}") raise ValueError(f"Unknown quantization strategy {self.strategy}")
...@@ -66,32 +69,40 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): ...@@ -66,32 +69,40 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
layer.logical_widths = output_partition_sizes layer.logical_widths = output_partition_sizes
# WEIGHT # WEIGHT
weight = torch.nn.Parameter(torch.empty(output_size_per_partition, weight = ModelWeightParameter(data=torch.empty(
input_size_per_partition, output_size_per_partition,
dtype=torch.float8_e4m3fn), input_size_per_partition,
requires_grad=False) dtype=torch.float8_e4m3fn),
input_dim=1,
output_dim=0,
weight_loader=weight_loader)
layer.register_parameter("weight", weight) layer.register_parameter("weight", weight)
set_weight_attrs(weight, {
"input_dim": 1,
"output_dim": 0,
"weight_loader": weight_loader,
})
# WEIGHT SCALE # WEIGHT SCALE
layer_kwargs = {"weight_loader": weight_loader} # TODO: update create_xxx_parameter functions to return
# the newly added parameters
if self.strategy == QuantizationStrategy.CHANNEL: if self.strategy == QuantizationStrategy.CHANNEL:
weight_scale = create_per_channel_scale_param( weight_scale = ChannelQuantScaleParameter(
output_partition_sizes, **layer_kwargs) data=torch.empty((sum(output_partition_sizes), 1),
dtype=torch.float32),
output_dim=0,
weight_loader=weight_loader)
else: else:
assert self.strategy == QuantizationStrategy.TENSOR assert self.strategy == QuantizationStrategy.TENSOR
weight_scale = create_per_tensor_scale_param( weight_scale = PerTensorScaleParameter(data=torch.empty(
output_partition_sizes, **layer_kwargs) len(output_partition_sizes), dtype=torch.float32),
weight_loader=weight_loader)
# min requirement for fp8 kernels
weight_scale[:] = torch.finfo(torch.float32).min
layer.register_parameter("weight_scale", weight_scale) layer.register_parameter("weight_scale", weight_scale)
# INPUT SCALE # INPUT SCALE
if self.is_static_input_scheme: if self.is_static_input_scheme:
input_scale = create_per_tensor_scale_param( input_scale = PerTensorScaleParameter(data=torch.empty(
output_partition_sizes, **layer_kwargs) len(output_partition_sizes), dtype=torch.float32),
weight_loader=weight_loader)
input_scale[:] = torch.finfo(torch.float32).min
layer.register_parameter("input_scale", input_scale) layer.register_parameter("input_scale", input_scale)
def apply_weights(self, def apply_weights(self,
......
...@@ -8,9 +8,11 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( ...@@ -8,9 +8,11 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
QuantizationStrategy) QuantizationStrategy)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
apply_int8_linear, convert_to_channelwise, create_per_channel_scale_param, apply_int8_linear, convert_to_channelwise)
create_per_tensor_scale_param) from vllm.model_executor.parameter import (BasevLLMParameter,
from vllm.model_executor.utils import set_weight_attrs ChannelQuantScaleParameter,
ModelWeightParameter,
PerTensorScaleParameter)
class CompressedTensorsW8A8Int8(CompressedTensorsScheme): class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
...@@ -39,7 +41,9 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme): ...@@ -39,7 +41,9 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
ws_channelwise = convert_to_channelwise(layer.weight_scale, ws_channelwise = convert_to_channelwise(layer.weight_scale,
self.logical_widths) self.logical_widths)
layer.weight_scale = Parameter(ws_channelwise, requires_grad=False) layer.weight_scale = Parameter(ws_channelwise, requires_grad=False)
else:
layer.weight_scale = Parameter(layer.weight_scale.data,
requires_grad=False)
# INPUT SCALE # INPUT SCALE
if self.is_static_input_scheme: if self.is_static_input_scheme:
layer.input_scale = Parameter(layer.input_scale.max(), layer.input_scale = Parameter(layer.input_scale.max(),
...@@ -55,32 +59,35 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme): ...@@ -55,32 +59,35 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
self.logical_widths = output_partition_sizes self.logical_widths = output_partition_sizes
# WEIGHT # WEIGHT
weight = Parameter(torch.empty(sum(output_partition_sizes), weight = ModelWeightParameter(data=torch.empty(
input_size_per_partition, sum(output_partition_sizes),
dtype=torch.int8), input_size_per_partition,
requires_grad=False) dtype=torch.int8),
input_dim=1,
output_dim=0,
weight_loader=weight_loader)
layer.register_parameter("weight", weight) layer.register_parameter("weight", weight)
set_weight_attrs(weight, {
"input_dim": 1,
"output_dim": 0,
"weight_loader": weight_loader,
})
# WEIGHT SCALE # WEIGHT SCALE
layer_kwargs = {"weight_loader": weight_loader}
if self.strategy == QuantizationStrategy.CHANNEL: if self.strategy == QuantizationStrategy.CHANNEL:
weight_scale = create_per_channel_scale_param( weight_scale = ChannelQuantScaleParameter(
output_partition_sizes, **layer_kwargs) data=torch.empty((sum(output_partition_sizes), 1),
dtype=torch.float32),
output_dim=0,
weight_loader=weight_loader)
else: else:
assert self.strategy == QuantizationStrategy.TENSOR assert self.strategy == QuantizationStrategy.TENSOR
weight_scale = create_per_tensor_scale_param( weight_scale = PerTensorScaleParameter(data=torch.empty(
output_partition_sizes, **layer_kwargs) len(output_partition_sizes), dtype=torch.float32),
weight_loader=weight_loader)
layer.register_parameter("weight_scale", weight_scale) layer.register_parameter("weight_scale", weight_scale)
# INPUT SCALE # INPUT SCALE
if self.is_static_input_scheme: if self.is_static_input_scheme:
input_scale = create_per_tensor_scale_param( input_scale = BasevLLMParameter(data=torch.empty(
output_partition_sizes, **layer_kwargs) 1, dtype=torch.float32),
weight_loader=weight_loader)
layer.register_parameter("input_scale", input_scale) layer.register_parameter("input_scale", input_scale)
def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor, def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
......
from typing import Callable, List, Optional from typing import Callable, List, Optional
import torch import torch
from torch.nn import Parameter
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
...@@ -10,7 +9,10 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import ( ...@@ -10,7 +9,10 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
apply_gptq_marlin_linear, marlin_make_empty_g_idx, marlin_make_workspace, apply_gptq_marlin_linear, marlin_make_empty_g_idx, marlin_make_workspace,
marlin_permute_scales, replace_tensor, verify_marlin_supported, marlin_permute_scales, replace_tensor, verify_marlin_supported,
verify_marlin_supports_shape) verify_marlin_supports_shape)
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.parameter import (BasevLLMParameter,
ChannelQuantScaleParameter,
GroupQuantScaleParameter,
PackedvLLMParameter)
from vllm.scalar_type import scalar_types from vllm.scalar_type import scalar_types
__all__ = ["CompressedTensorsWNA16"] __all__ = ["CompressedTensorsWNA16"]
...@@ -30,17 +32,12 @@ class CompressedTensorsWNA16(CompressedTensorsScheme): ...@@ -30,17 +32,12 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
self.pack_factor = 32 // num_bits self.pack_factor = 32 // num_bits
self.strategy = strategy self.strategy = strategy
self.group_size = -1 if group_size is None else group_size
self.group_size: int if self.group_size == -1 and self.strategy != "channel":
if group_size is None: raise ValueError("Marlin kernels require group quantization or "
if self.strategy != "channel": "channelwise quantization, but found no group "
raise ValueError( "size and strategy is not channelwise.")
"Marlin kernels require group quantization or "
"channelwise quantization, but found no group "
"size and strategy is not channelwise.")
self.group_size = -1
else:
self.group_size = group_size
if num_bits not in WNA16_SUPPORTED_TYPES_MAP: if num_bits not in WNA16_SUPPORTED_TYPES_MAP:
raise ValueError( raise ValueError(
...@@ -63,11 +60,12 @@ class CompressedTensorsWNA16(CompressedTensorsScheme): ...@@ -63,11 +60,12 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
input_size_per_partition: int, input_size_per_partition: int,
params_dtype: torch.dtype, weight_loader: Callable, params_dtype: torch.dtype, weight_loader: Callable,
**kwargs): **kwargs):
output_size_per_partition = sum(output_partition_sizes) output_size_per_partition = sum(output_partition_sizes)
# If group_size is -1, we are in channelwise case. # If group_size is -1, we are in channelwise case.
channelwise = (self.group_size == -1) channelwise = (self.group_size == -1)
group_size = input_size if channelwise else self.group_size group_size = self.group_size if self.group_size != -1 else input_size
row_parallel = (input_size != input_size_per_partition) row_parallel = (input_size != input_size_per_partition)
# In the case of channelwise quantization, we need to replicate the # In the case of channelwise quantization, we need to replicate the
# scales across all gpus. # scales across all gpus.
...@@ -79,60 +77,51 @@ class CompressedTensorsWNA16(CompressedTensorsScheme): ...@@ -79,60 +77,51 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
input_size=input_size, input_size=input_size,
group_size=group_size) group_size=group_size)
weight_scale_dim = None
scales_and_zp_size = input_size // group_size scales_and_zp_size = input_size // group_size
if partition_scales: if partition_scales:
assert input_size_per_partition % group_size == 0 assert input_size_per_partition % group_size == 0
weight_scale_dim = 1
scales_and_zp_size = input_size_per_partition // group_size scales_and_zp_size = input_size_per_partition // group_size
weight = Parameter( weight = PackedvLLMParameter(input_dim=1,
torch.empty( output_dim=0,
output_size_per_partition, weight_loader=weight_loader,
input_size_per_partition // self.pack_factor, packed_factor=self.pack_factor,
dtype=torch.int32, packed_dim=1,
), data=torch.empty(
requires_grad=False, output_size_per_partition,
) input_size_per_partition //
self.pack_factor,
set_weight_attrs( dtype=torch.int32,
weight, { ))
"input_dim": 1,
"output_dim": 0, weight_scale_args = {
"packed_dim": 1, "weight_loader":
"pack_factor": self.pack_factor, weight_loader,
"weight_loader": weight_loader "data":
})
layer.register_parameter("weight_packed", weight)
weight_scale = Parameter(
torch.empty( torch.empty(
output_size_per_partition, output_size_per_partition,
scales_and_zp_size, scales_and_zp_size,
dtype=params_dtype, dtype=params_dtype,
), )
requires_grad=False, }
) if not partition_scales:
weight_scale = ChannelQuantScaleParameter(output_dim=0,
set_weight_attrs( **weight_scale_args)
weight_scale, { else:
"weight_loader": weight_loader, weight_scale = GroupQuantScaleParameter(output_dim=0,
"input_dim": weight_scale_dim, input_dim=1,
"output_dim": 0 **weight_scale_args)
})
layer.register_parameter("weight_scale", weight_scale)
# A 2D array defining the original shape of the weights # A 2D array defining the original shape of the weights
# before packing # before packing
weight_shape = Parameter(torch.empty(2, dtype=torch.int64), weight_shape = BasevLLMParameter(data=torch.empty(2,
requires_grad=False) dtype=torch.int64),
weight_loader=weight_loader)
layer.register_parameter("weight_packed", weight)
layer.register_parameter("weight_scale", weight_scale)
layer.register_parameter("weight_shape", weight_shape) layer.register_parameter("weight_shape", weight_shape)
set_weight_attrs(weight_shape, {
"weight_loader": weight_loader,
"ignore_warning": True,
})
layer.input_size_per_partition = input_size_per_partition layer.input_size_per_partition = input_size_per_partition
layer.output_size_per_partition = output_size_per_partition layer.output_size_per_partition = output_size_per_partition
...@@ -154,10 +143,15 @@ class CompressedTensorsWNA16(CompressedTensorsScheme): ...@@ -154,10 +143,15 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
# No zero-point # No zero-point
layer.weight_zp = marlin_make_empty_g_idx(device) layer.weight_zp = marlin_make_empty_g_idx(device)
# Update for kernel
layer.weight_packed = torch.nn.Parameter(
layer.weight_packed.t().contiguous(), requires_grad=False)
layer.weight_scale = torch.nn.Parameter(
layer.weight_scale.squeeze().t().contiguous(), requires_grad=False)
# Repack weights from compressed-tensors format to marlin format. # Repack weights from compressed-tensors format to marlin format.
marlin_qweight = ops.gptq_marlin_repack( marlin_qweight = ops.gptq_marlin_repack(
layer.weight_packed.t().contiguous(), layer.weight_packed,
perm=layer.g_idx_sort_indices, perm=layer.g_idx_sort_indices,
size_k=layer.input_size_per_partition, size_k=layer.input_size_per_partition,
size_n=layer.output_size_per_partition, size_n=layer.output_size_per_partition,
...@@ -166,7 +160,7 @@ class CompressedTensorsWNA16(CompressedTensorsScheme): ...@@ -166,7 +160,7 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
# Permute scales from compressed-tensors format to marlin format. # Permute scales from compressed-tensors format to marlin format.
marlin_scales = marlin_permute_scales( marlin_scales = marlin_permute_scales(
layer.weight_scale.squeeze().t().contiguous(), layer.weight_scale,
size_k=layer.input_size_per_partition, size_k=layer.input_size_per_partition,
size_n=layer.output_size_per_partition, size_n=layer.output_size_per_partition,
group_size=layer.group_size) group_size=layer.group_size)
......
from typing import Any, Dict, List, Optional
import torch
from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group
from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEMethodBase
from vllm.model_executor.layers.linear import (LinearBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.utils import set_weight_attrs
class ExpertsInt8Config(QuantizationConfig):
"""Config class for Int8 experts quantization."""
def __init__(self) -> None:
pass
@classmethod
def get_name(cls) -> str:
return "experts_int8"
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
return [torch.bfloat16, torch.half]
@classmethod
def get_min_capability(cls) -> int:
return 80
@classmethod
def get_config_filenames(cls) -> List[str]:
return []
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "ExpertsInt8Config":
return cls()
def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["QuantizeMethodBase"]:
if isinstance(layer, LinearBase):
return UnquantizedLinearMethod()
elif isinstance(layer, FusedMoE):
return ExpertsInt8MoEMethod(self)
return None
def get_scaled_act_names(self) -> List[str]:
return []
class ExpertsInt8MoEMethod(FusedMoEMethodBase):
def __init__(self, quant_config: ExpertsInt8Config):
self.quant_config = quant_config
def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size: int,
params_dtype: torch.dtype, **extra_weight_attrs):
int8_dtype = torch.int8
assert 'weight_loader' in extra_weight_attrs
weight_loader = extra_weight_attrs['weight_loader']
wrapped_weight_loader = ExpertsInt8MoEMethod.quantizing_weight_loader(
layer, weight_loader)
extra_weight_attrs['weight_loader'] = wrapped_weight_loader
# Fused gate_up_proj (column parallel)
w13_weight = torch.nn.Parameter(torch.empty(num_experts,
2 * intermediate_size,
hidden_size,
dtype=int8_dtype),
requires_grad=False)
layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)
# down_proj (row parallel)
w2_weight = torch.nn.Parameter(torch.empty(num_experts,
hidden_size,
intermediate_size,
dtype=int8_dtype),
requires_grad=False)
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
w13_scale = torch.nn.Parameter(torch.zeros(num_experts,
2 * intermediate_size,
dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w13_scale", w13_scale)
w2_scale = torch.nn.Parameter(torch.zeros(num_experts,
hidden_size,
dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w2_scale", w2_scale)
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool = True,
use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None,
topk_group: Optional[int] = None) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import fused_experts
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group)
return fused_experts(x,
layer.w13_weight,
layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
use_int8_w8a16=True,
w1_scale=layer.w13_scale,
w2_scale=layer.w2_scale)
@staticmethod
def quantizing_weight_loader(layer, weight_loader):
def quantize_and_call_weight_loader(param: torch.nn.Parameter,
loaded_weight: torch.Tensor,
weight_name: str, shard_id: int,
expert_id: int):
tp_rank = get_tensor_model_parallel_rank()
shard_size = layer.intermediate_size_per_partition
shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
device = get_tp_group().device
loaded_weight = loaded_weight.to(device)
# w1, gate_proj case: Load into first shard of w13.
if shard_id == "w1":
scales = quantize_in_place_and_get_scales(
loaded_weight[shard, :])
layer.w13_scale.data[expert_id, 0:shard_size].copy_(scales[:,
0])
# w3, up_proj case: Load into second shard of w13.
elif shard_id == "w3":
scales = quantize_in_place_and_get_scales(
loaded_weight[shard, :])
layer.w13_scale.data[expert_id, shard_size:2 *
shard_size].copy_(scales[:, 0])
# w2, down_proj case: Load into only shard of w2.
elif shard_id == "w2":
scales = quantize_in_place_and_get_scales(loaded_weight[:,
shard])
layer.w2_scale.data[expert_id, :].copy_(scales[:, 0])
else:
raise ValueError(
f"Shard id must be in [0,1,2] but got {shard_id}")
weight_loader(param, loaded_weight, weight_name, shard_id,
expert_id)
return quantize_and_call_weight_loader
def quantize_in_place_and_get_scales(weight: torch.Tensor) -> torch.Tensor:
vmax = torch.iinfo(torch.int8).max
scales = (torch.max(torch.abs(weight), dim=1, keepdim=True)[0] / vmax)
weight.div_(scales)
weight.round_()
weight.clamp_(-vmax, vmax)
return scales
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment