Commit 081057de authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.8.5' into v0.8.5-ori

parents 7cf5d5c4 ba41cc90
...@@ -15,7 +15,7 @@ def cutlass_moe_fp8( ...@@ -15,7 +15,7 @@ def cutlass_moe_fp8(
w1_scale: torch.Tensor, w1_scale: torch.Tensor,
w2_scale: torch.Tensor, w2_scale: torch.Tensor,
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids_: torch.Tensor,
ab_strides1: torch.Tensor, ab_strides1: torch.Tensor,
c_strides1: torch.Tensor, c_strides1: torch.Tensor,
ab_strides2: torch.Tensor, ab_strides2: torch.Tensor,
...@@ -23,6 +23,7 @@ def cutlass_moe_fp8( ...@@ -23,6 +23,7 @@ def cutlass_moe_fp8(
a1_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None,
out_dtype: torch.dtype = torch.half, out_dtype: torch.dtype = torch.half,
expert_map: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
...@@ -57,12 +58,19 @@ def cutlass_moe_fp8( ...@@ -57,12 +58,19 @@ def cutlass_moe_fp8(
quantize the intermediate result between the gemms. quantize the intermediate result between the gemms.
Shape: scalar or [M] Shape: scalar or [M]
- out_dtype (torch.Tensor): The output tensor type. - out_dtype (torch.Tensor): The output tensor type.
- expert_map (Optional[torch.Tensor]): In the case of Expert parallel,
every Rank is responsible for a subset of experts. expert_map is a
mapping from global expert-id to local expert-id. When expert_map[i]
is -1, it means that this Rank is not responsible for global
expert-id i.
- apply_router_weight_on_input (bool): When true, the topk weights are
applied directly on the inputs. This is only applicable when topk is 1.
Returns: Returns:
- torch.Tensor: The fp16 output tensor after applying the MoE layer. - torch.Tensor: The fp16 output tensor after applying the MoE layer.
""" """
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" assert topk_weights.shape == topk_ids_.shape, "topk shape mismatch"
assert w1_q.dtype == torch.float8_e4m3fn assert w1_q.dtype == torch.float8_e4m3fn
assert w2_q.dtype == torch.float8_e4m3fn assert w2_q.dtype == torch.float8_e4m3fn
assert a.shape[1] == w1_q.shape[1], "Hidden size mismatch w1" assert a.shape[1] == w1_q.shape[1], "Hidden size mismatch w1"
...@@ -96,7 +104,13 @@ def cutlass_moe_fp8( ...@@ -96,7 +104,13 @@ def cutlass_moe_fp8(
k = w1_q.size(1) k = w1_q.size(1)
n = w2_q.size(1) n = w2_q.size(1)
topk = topk_ids.size(1) local_topk_ids = topk_ids_
if expert_map is not None:
"Translate info from expert_map to topk_ids"
local_topk_ids = torch.where(expert_map[topk_ids_] != -1,
expert_map[topk_ids_], -1)
topk = local_topk_ids.size(1)
per_act_token = a1_scale.numel() != 1 if a1_scale is not None else ( per_act_token = a1_scale.numel() != 1 if a1_scale is not None else (
a2_scale.numel() != 1 if a2_scale is not None else False) a2_scale.numel() != 1 if a2_scale is not None else False)
...@@ -120,10 +134,23 @@ def cutlass_moe_fp8( ...@@ -120,10 +134,23 @@ def cutlass_moe_fp8(
dtype=torch.int32, dtype=torch.int32,
device=device) device=device)
a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device) a_map_initializer = torch.empty
c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device) c2_initializer = torch.empty
if expert_map is not None:
ops.get_cutlass_moe_mm_data(topk_ids, expert_offsets, problem_sizes1, # With expert_map each Rank processes only a subset of experts. As
# a result not all of a_map and c2 tensors are filled. We fill it
# zeros for correctness.
a_map_initializer = torch.zeros
c2_initializer = torch.zeros
a_map = a_map_initializer((local_topk_ids.numel()),
dtype=torch.int32,
device=device)
c_map = torch.empty((local_topk_ids.numel()),
dtype=torch.int32,
device=device)
ops.get_cutlass_moe_mm_data(local_topk_ids, expert_offsets, problem_sizes1,
problem_sizes2, a_map, c_map, num_experts, n, problem_sizes2, a_map, c_map, num_experts, n,
k) k)
...@@ -131,7 +158,7 @@ def cutlass_moe_fp8( ...@@ -131,7 +158,7 @@ def cutlass_moe_fp8(
rep_a1_scales = a1_scale[a_map] if per_act_token else a1_scale rep_a1_scales = a1_scale[a_map] if per_act_token else a1_scale
c1 = torch.empty((m * topk, n * 2), device=device, dtype=out_dtype) c1 = torch.empty((m * topk, n * 2), device=device, dtype=out_dtype)
c2 = torch.empty((m * topk, k), device=device, dtype=out_dtype) c2 = c2_initializer((m * topk, k), device=device, dtype=out_dtype)
ops.cutlass_moe_mm(c1, rep_a_q, w1_q, rep_a1_scales, w1_scale, ops.cutlass_moe_mm(c1, rep_a_q, w1_q, rep_a1_scales, w1_scale,
expert_offsets[:-1], problem_sizes1, ab_strides1, expert_offsets[:-1], problem_sizes1, ab_strides1,
......
...@@ -5,17 +5,16 @@ from typing import Optional ...@@ -5,17 +5,16 @@ from typing import Optional
import torch import torch
import vllm._custom_ops as ops
from vllm.model_executor.layers.fused_moe.fused_moe import ( from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_topk, moe_align_block_size, try_get_optimal_moe_config) fused_topk, moe_align_block_size, try_get_optimal_moe_config)
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types from vllm.scalar_type import scalar_types
from vllm.utils import direct_register_custom_op from vllm.utils import direct_register_custom_op
def get_scalar_type(num_bits: int, has_zp: bool): def get_scalar_type(num_bits: int, has_zp: bool):
if has_zp: if has_zp:
assert num_bits == 4 return scalar_types.uint4 if num_bits == 4 else scalar_types.uint8
return scalar_types.uint4
else: else:
return scalar_types.uint4b8 if num_bits == 4 else scalar_types.uint8b128 return scalar_types.uint4b8 if num_bits == 4 else scalar_types.uint8b128
...@@ -27,9 +26,12 @@ def single_marlin_moe( ...@@ -27,9 +26,12 @@ def single_marlin_moe(
gating_output: torch.Tensor, gating_output: torch.Tensor,
topk: int, topk: int,
renormalize: bool, renormalize: bool,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
g_idx: Optional[torch.Tensor] = None, g_idx: Optional[torch.Tensor] = None,
sort_indices: Optional[torch.Tensor] = None, sort_indices: Optional[torch.Tensor] = None,
w_zeros: Optional[torch.Tensor] = None, w_zeros: Optional[torch.Tensor] = None,
workspace: Optional[torch.Tensor] = None,
num_bits: int = 8, num_bits: int = 8,
is_k_full: bool = True, is_k_full: bool = True,
) -> torch.Tensor: ) -> torch.Tensor:
...@@ -62,7 +64,7 @@ def single_marlin_moe( ...@@ -62,7 +64,7 @@ def single_marlin_moe(
assert gating_output.shape[1] == w.shape[0], "Number of experts mismatch" assert gating_output.shape[1] == w.shape[0], "Number of experts mismatch"
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
assert w.is_contiguous(), "Expert weights must be contiguous" assert w.is_contiguous(), "Expert weights must be contiguous"
assert hidden_states.dtype == torch.float16 assert hidden_states.dtype in [torch.float16, torch.bfloat16]
assert num_bits in [4, 8] assert num_bits in [4, 8]
M, K = hidden_states.shape M, K = hidden_states.shape
...@@ -83,39 +85,54 @@ def single_marlin_moe( ...@@ -83,39 +85,54 @@ def single_marlin_moe(
block_size_m = config['BLOCK_SIZE_M'] block_size_m = config['BLOCK_SIZE_M']
sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E) if global_num_experts == -1:
global_num_experts = E
max_workspace_size = (N // 64) * 16 sorted_token_ids, expert_ids, num_tokens_post_padded = \
workspace = torch.zeros(max_workspace_size, moe_align_block_size(topk_ids, block_size_m, E, expert_map)
dtype=torch.int,
device=hidden_states.device, if workspace is None:
requires_grad=False) max_workspace_size = (max(2 * N, K) // 64) * \
(sorted_token_ids.size(0) // block_size_m)
has_zero_point = w_zeros is not None device = hidden_states.device
if w_zeros is None: sms = torch.cuda.get_device_properties(device).multi_processor_count
w_zeros = torch.empty((0, 0), max_workspace_size = min(max_workspace_size, sms)
dtype=hidden_states.dtype, workspace = torch.zeros(max_workspace_size,
device=hidden_states.device, dtype=torch.int,
requires_grad=False) device=device,
requires_grad=False)
if g_idx is None:
g_idx = torch.empty((0, 0), scalar_type = get_scalar_type(num_bits, w_zeros is not None)
dtype=torch.int32, intermediate_cache = torch.empty(
device=hidden_states.device, (M * topk_ids.shape[1], N),
requires_grad=False) device=hidden_states.device,
dtype=hidden_states.dtype,
if sort_indices is None: )
sort_indices = torch.empty((0),
dtype=torch.int32,
device=hidden_states.device,
requires_grad=False)
scalar_type = get_scalar_type(num_bits, has_zero_point)
intermediate_cache = torch.ops._moe_C.marlin_gemm_moe( ops.moe_wna16_marlin_gemm(hidden_states,
hidden_states, w, sorted_token_ids, topk_weights, topk_ids, scales, intermediate_cache,
w_zeros, g_idx, sort_indices, workspace, scalar_type.id, M, N, K, w,
is_k_full, E, topk, block_size_m, True, False) scales,
w_zeros,
g_idx,
sort_indices,
workspace,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
topk_weights,
moe_block_size=block_size_m,
top_k=topk,
mul_topk_weights=False,
is_ep=expert_map is not None,
b_q_type=scalar_type,
size_m=M,
size_n=N,
size_k=K,
is_k_full=is_k_full,
use_atomic_add=False,
use_fp32_reduce=True,
is_zp_float=False)
intermediate_cache = intermediate_cache.view(-1, topk, N)
return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1) return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1)
...@@ -127,9 +144,12 @@ def single_marlin_moe_fake( ...@@ -127,9 +144,12 @@ def single_marlin_moe_fake(
gating_output: torch.Tensor, gating_output: torch.Tensor,
topk: int, topk: int,
renormalize: bool, renormalize: bool,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
g_idx: Optional[torch.Tensor] = None, g_idx: Optional[torch.Tensor] = None,
sort_indices: Optional[torch.Tensor] = None, sort_indices: Optional[torch.Tensor] = None,
w_zeros: Optional[torch.Tensor] = None, w_zeros: Optional[torch.Tensor] = None,
workspace: Optional[torch.Tensor] = None,
num_bits: int = 8, num_bits: int = 8,
is_k_full: bool = True, is_k_full: bool = True,
) -> torch.Tensor: ) -> torch.Tensor:
...@@ -144,24 +164,26 @@ direct_register_custom_op( ...@@ -144,24 +164,26 @@ direct_register_custom_op(
) )
def fused_marlin_moe( def fused_marlin_moe(hidden_states: torch.Tensor,
hidden_states: torch.Tensor, w1: torch.Tensor,
w1: torch.Tensor, w2: torch.Tensor,
w2: torch.Tensor, w1_scale: torch.Tensor,
w1_scale: torch.Tensor, w2_scale: torch.Tensor,
w2_scale: torch.Tensor, gating_output: torch.Tensor,
gating_output: torch.Tensor, topk_weights: torch.Tensor,
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
topk_ids: torch.Tensor, global_num_experts: int = -1,
g_idx1: Optional[torch.Tensor] = None, expert_map: Optional[torch.Tensor] = None,
g_idx2: Optional[torch.Tensor] = None, g_idx1: Optional[torch.Tensor] = None,
sort_indices1: Optional[torch.Tensor] = None, g_idx2: Optional[torch.Tensor] = None,
sort_indices2: Optional[torch.Tensor] = None, sort_indices1: Optional[torch.Tensor] = None,
w1_zeros: Optional[torch.Tensor] = None, sort_indices2: Optional[torch.Tensor] = None,
w2_zeros: Optional[torch.Tensor] = None, w1_zeros: Optional[torch.Tensor] = None,
num_bits: int = 8, w2_zeros: Optional[torch.Tensor] = None,
is_k_full: bool = True, workspace: Optional[torch.Tensor] = None,
) -> torch.Tensor: num_bits: int = 8,
is_k_full: bool = True,
inplace: bool = False) -> torch.Tensor:
""" """
This function computes a Mixture of Experts (MoE) layer using two sets of This function computes a Mixture of Experts (MoE) layer using two sets of
weights, w1 and w2, and top-k gating mechanism. weights, w1 and w2, and top-k gating mechanism.
...@@ -196,27 +218,12 @@ def fused_marlin_moe( ...@@ -196,27 +218,12 @@ def fused_marlin_moe(
1] == w1.shape[1] * 16, "Hidden size mismatch w1" 1] == w1.shape[1] * 16, "Hidden size mismatch w1"
assert hidden_states.shape[1] == w2.shape[2] // ( assert hidden_states.shape[1] == w2.shape[2] // (
num_bits // 2), "Hidden size mismatch w2" num_bits // 2), "Hidden size mismatch w2"
assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
assert w1.is_contiguous(), "Expert weights1 must be contiguous" assert w1.is_contiguous(), "Expert weights1 must be contiguous"
assert w2.is_contiguous(), "Expert weights2 must be contiguous" assert w2.is_contiguous(), "Expert weights2 must be contiguous"
assert hidden_states.dtype == torch.float16 assert hidden_states.dtype in [torch.float16, torch.bfloat16]
assert num_bits in [4, 8] assert num_bits in [4, 8]
has_no_act_order = (g_idx1 is None and g_idx2 is None
and sort_indices1 is None and sort_indices2 is None)
has_all_act_order = (g_idx1 is not None and g_idx2 is not None
and sort_indices1 is not None
and sort_indices2 is not None)
assert has_no_act_order or has_all_act_order, (
"g_idx and sorted_indices "
"must be all not None or must be all None")
has_no_zp = w1_zeros is None and w2_zeros is None
has_all_zp = w1_zeros is not None and w2_zeros is not None
assert has_no_zp or has_all_zp, ("zero points must be both not None or "
"must be both None")
M, K = hidden_states.shape M, K = hidden_states.shape
E = w1.shape[0] E = w1.shape[0]
N = w2.shape[1] * 16 N = w2.shape[1] * 16
...@@ -234,122 +241,128 @@ def fused_marlin_moe( ...@@ -234,122 +241,128 @@ def fused_marlin_moe(
block_size_m = config["BLOCK_SIZE_M"] block_size_m = config["BLOCK_SIZE_M"]
sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E) if global_num_experts == -1:
global_num_experts = E
max_workspace_size = (max(2 * N, K) // 64) * 16 sorted_token_ids, expert_ids, num_tokens_post_padded = \
workspace = torch.zeros(max_workspace_size, moe_align_block_size(topk_ids, block_size_m, global_num_experts,
dtype=torch.int, expert_map)
device=current_platform.device_type,
requires_grad=False) if workspace is None:
max_workspace_size = (max(2 * N, K) // 64) * \
if has_no_zp: (sorted_token_ids.size(0) // block_size_m)
w1_zeros = torch.empty((0, 0), device = hidden_states.device
dtype=hidden_states.dtype, sms = torch.cuda.get_device_properties(device).multi_processor_count
device=hidden_states.device, max_workspace_size = min(max_workspace_size, sms * 4)
requires_grad=False) workspace = torch.zeros(max_workspace_size,
w2_zeros = torch.empty((0, 0), dtype=torch.int,
dtype=hidden_states.dtype, device=device,
device=hidden_states.device, requires_grad=False)
requires_grad=False)
scalar_type1 = get_scalar_type(num_bits, w1_zeros is not None)
if has_no_act_order: scalar_type2 = get_scalar_type(num_bits, w2_zeros is not None)
g_idx1 = torch.empty((0, 0),
dtype=torch.int32,
device=hidden_states.device,
requires_grad=False)
g_idx2 = torch.empty((0, 0),
dtype=torch.int32,
device=hidden_states.device,
requires_grad=False)
sort_indices1 = torch.empty((0),
dtype=torch.int32,
device=hidden_states.device,
requires_grad=False)
sort_indices2 = torch.empty((0, 0),
dtype=torch.int32,
device=hidden_states.device,
requires_grad=False)
scalar_type1 = get_scalar_type(num_bits, has_all_zp)
scalar_type2 = get_scalar_type(num_bits, has_all_zp)
intermediate_cache2 = torch.empty( intermediate_cache2 = torch.empty(
(M * topk_ids.shape[1], N), (M * topk_ids.shape[1], N),
device=hidden_states.device, device=hidden_states.device,
dtype=hidden_states.dtype, dtype=hidden_states.dtype,
) )
intermediate_cache13 = torch.empty(
(M * topk_ids.shape[1] * max(2 * N, K), ),
device=hidden_states.device,
dtype=hidden_states.dtype,
)
intermediate_cache1 = intermediate_cache13[:M * topk_ids.shape[1] * 2 * N]
intermediate_cache1 = intermediate_cache1.view(-1, 2 * N)
intermediate_cache3 = intermediate_cache13[:M * topk_ids.shape[1] * K]
intermediate_cache3 = intermediate_cache3.view(-1, K)
use_atomic_add = hidden_states.dtype == torch.half or \
torch.cuda.get_device_capability(hidden_states.device)[0] >= 9
intermediate_cache1 = torch.ops._moe_C.marlin_gemm_moe( intermediate_cache1 = ops.moe_wna16_marlin_gemm(
hidden_states, hidden_states,
intermediate_cache1,
w1, w1,
sorted_token_ids,
topk_weights,
topk_ids,
w1_scale, w1_scale,
w1_zeros, w1_zeros,
g_idx1, g_idx1,
sort_indices1, sort_indices1,
workspace, workspace,
scalar_type1.id, sorted_token_ids,
M, expert_ids,
2 * N, num_tokens_post_padded,
K, topk_weights,
is_k_full, moe_block_size=block_size_m,
E, top_k=topk,
topk, mul_topk_weights=False,
block_size_m, is_ep=expert_map is not None,
True, b_q_type=scalar_type1,
False, size_m=M,
) size_n=2 * N,
size_k=K,
is_k_full=is_k_full,
use_atomic_add=use_atomic_add,
use_fp32_reduce=True,
is_zp_float=False)
torch.ops._C.silu_and_mul(intermediate_cache2, torch.ops._C.silu_and_mul(intermediate_cache2,
intermediate_cache1.view(-1, 2 * N)) intermediate_cache1.view(-1, 2 * N))
intermediate_cache3 = torch.ops._moe_C.marlin_gemm_moe( if expert_map is not None:
intermediate_cache3.zero_()
intermediate_cache3 = ops.moe_wna16_marlin_gemm(
intermediate_cache2, intermediate_cache2,
intermediate_cache3,
w2, w2,
sorted_token_ids,
topk_weights,
topk_ids,
w2_scale, w2_scale,
w2_zeros, w2_zeros,
g_idx2, g_idx2,
sort_indices2, sort_indices2,
workspace, workspace,
scalar_type2.id, sorted_token_ids,
M, expert_ids,
K, num_tokens_post_padded,
N, topk_weights,
is_k_full, moe_block_size=block_size_m,
E, top_k=1,
topk, mul_topk_weights=True,
block_size_m, is_ep=expert_map is not None,
False, b_q_type=scalar_type2,
True, size_m=M * topk,
) size_n=K,
size_k=N,
is_k_full=is_k_full,
use_atomic_add=use_atomic_add,
use_fp32_reduce=True,
is_zp_float=False).view(-1, topk, K)
output = hidden_states if inplace else torch.empty_like(hidden_states)
return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
dim=1) dim=1,
out=output)
def fused_marlin_moe_fake(
hidden_states: torch.Tensor, def fused_marlin_moe_fake(hidden_states: torch.Tensor,
w1: torch.Tensor, w1: torch.Tensor,
w2: torch.Tensor, w2: torch.Tensor,
w1_scale: torch.Tensor, w1_scale: torch.Tensor,
w2_scale: torch.Tensor, w2_scale: torch.Tensor,
gating_output: torch.Tensor, gating_output: torch.Tensor,
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
g_idx1: Optional[torch.Tensor] = None, global_num_experts: int = -1,
g_idx2: Optional[torch.Tensor] = None, expert_map: Optional[torch.Tensor] = None,
sort_indices1: Optional[torch.Tensor] = None, g_idx1: Optional[torch.Tensor] = None,
sort_indices2: Optional[torch.Tensor] = None, g_idx2: Optional[torch.Tensor] = None,
w1_zeros: Optional[torch.Tensor] = None, sort_indices1: Optional[torch.Tensor] = None,
w2_zeros: Optional[torch.Tensor] = None, sort_indices2: Optional[torch.Tensor] = None,
num_bits: int = 8, w1_zeros: Optional[torch.Tensor] = None,
is_k_full: bool = True, w2_zeros: Optional[torch.Tensor] = None,
) -> torch.Tensor: workspace: Optional[torch.Tensor] = None,
num_bits: int = 8,
is_k_full: bool = True,
inplace: bool = False) -> torch.Tensor:
return torch.empty_like(hidden_states) return torch.empty_like(hidden_states)
......
...@@ -23,9 +23,7 @@ from vllm.model_executor.layers.quantization.utils.int8_utils import ( ...@@ -23,9 +23,7 @@ from vllm.model_executor.layers.quantization.utils.int8_utils import (
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op from vllm.utils import direct_register_custom_op
from .rocm_aiter_fused_moe import (is_rocm_aiter_moe_enabled, from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled
rocm_aiter_fused_experts,
rocm_aiter_topk_softmax)
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -792,6 +790,18 @@ def get_default_config( ...@@ -792,6 +790,18 @@ def get_default_config(
config = {"BLOCK_SIZE_M": 32, "GROUP_SIZE_M": 1} config = {"BLOCK_SIZE_M": 32, "GROUP_SIZE_M": 1}
else: else:
config = {"BLOCK_SIZE_M": 64, "GROUP_SIZE_M": 1} config = {"BLOCK_SIZE_M": 64, "GROUP_SIZE_M": 1}
elif is_marlin:
for block_size_m in [8, 16, 32, 48, 64]:
if M * topk / E / block_size_m < 0.9:
break
return {"BLOCK_SIZE_M": block_size_m}
elif M <= E:
config = {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
}
else: else:
config = { config = {
"BLOCK_SIZE_M": 64, "BLOCK_SIZE_M": 64,
...@@ -799,14 +809,7 @@ def get_default_config( ...@@ -799,14 +809,7 @@ def get_default_config(
"BLOCK_SIZE_K": 32, "BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8, "GROUP_SIZE_M": 8,
} }
# A heuristic: fused marlin works faster with this config for small M
if M <= E or (is_marlin and M <= 32):
config = {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
}
if use_nn_moe: if use_nn_moe:
config["num_ldmatrixes"] = 1 config["num_ldmatrixes"] = 1
return config return config
...@@ -867,6 +870,7 @@ def vllm_topk_softmax(topk_weights: torch.Tensor, topk_indices: torch.Tensor, ...@@ -867,6 +870,7 @@ def vllm_topk_softmax(topk_weights: torch.Tensor, topk_indices: torch.Tensor,
def dispatch_topk_func() -> Callable[..., tuple[torch.Tensor, ...]]: def dispatch_topk_func() -> Callable[..., tuple[torch.Tensor, ...]]:
if is_rocm_aiter_moe_enabled(): if is_rocm_aiter_moe_enabled():
from .rocm_aiter_fused_moe import rocm_aiter_topk_softmax
return rocm_aiter_topk_softmax return rocm_aiter_topk_softmax
return vllm_topk_softmax return vllm_topk_softmax
...@@ -1127,6 +1131,7 @@ def torch_vllm_outplace_fused_experts(**kwargs) -> torch.Tensor: ...@@ -1127,6 +1131,7 @@ def torch_vllm_outplace_fused_experts(**kwargs) -> torch.Tensor:
def dispatch_fused_experts_func(inplace: bool) -> Callable[..., torch.Tensor]: def dispatch_fused_experts_func(inplace: bool) -> Callable[..., torch.Tensor]:
if is_rocm_aiter_moe_enabled(): if is_rocm_aiter_moe_enabled():
from .rocm_aiter_fused_moe import rocm_aiter_fused_experts
return rocm_aiter_fused_experts return rocm_aiter_fused_experts
if inplace: if inplace:
return torch_vllm_inplace_fused_experts return torch_vllm_inplace_fused_experts
......
...@@ -128,12 +128,9 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -128,12 +128,9 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
super().process_weights_after_loading(layer) super().process_weights_after_loading(layer)
layer.w13_weight = torch.nn.Parameter(self._maybe_pad_weight( # Padding the weight for better performance on ROCm
layer.w13_weight.data), layer.w13_weight.data = self._maybe_pad_weight(layer.w13_weight.data)
requires_grad=False) layer.w2_weight.data = self._maybe_pad_weight(layer.w2_weight.data)
layer.w2_weight = torch.nn.Parameter(self._maybe_pad_weight(
layer.w2_weight.data),
requires_grad=False)
# Lazy import to avoid importing triton. # Lazy import to avoid importing triton.
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
is_rocm_aiter_moe_enabled, shuffle_weights) is_rocm_aiter_moe_enabled, shuffle_weights)
...@@ -142,10 +139,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -142,10 +139,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
shuffled_w13, shuffled_w2 = shuffle_weights( shuffled_w13, shuffled_w2 = shuffle_weights(
layer.w13_weight.data, layer.w2_weight.data) layer.w13_weight.data, layer.w2_weight.data)
layer.w13_weight = torch.nn.Parameter(shuffled_w13, layer.w13_weight.data = shuffled_w13
requires_grad=False) layer.w2_weight.data = shuffled_w2
layer.w2_weight = torch.nn.Parameter(shuffled_w2,
requires_grad=False)
if current_platform.is_cpu(): if current_platform.is_cpu():
if current_platform.get_cpu_architecture() == CpuArchEnum.X86: if current_platform.get_cpu_architecture() == CpuArchEnum.X86:
...@@ -443,6 +438,7 @@ class FusedMoE(torch.nn.Module): ...@@ -443,6 +438,7 @@ class FusedMoE(torch.nn.Module):
if params_dtype is None: if params_dtype is None:
params_dtype = torch.get_default_dtype() params_dtype = torch.get_default_dtype()
self.params_dtype = params_dtype
# Note: here we guard against accessing the TP and DP groups when # Note: here we guard against accessing the TP and DP groups when
# uninitialized (this happens when testing) # uninitialized (this happens when testing)
...@@ -493,6 +489,7 @@ class FusedMoE(torch.nn.Module): ...@@ -493,6 +489,7 @@ class FusedMoE(torch.nn.Module):
self.global_num_experts = num_experts self.global_num_experts = num_experts
assert intermediate_size % self.tp_size == 0 assert intermediate_size % self.tp_size == 0
self.hidden_size = hidden_size
self.intermediate_size_per_partition = intermediate_size // self.tp_size self.intermediate_size_per_partition = intermediate_size // self.tp_size
self.reduce_results = reduce_results self.reduce_results = reduce_results
self.renormalize = renormalize self.renormalize = renormalize
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import List, Optional from functools import cache
from typing import List, Optional, Tuple
import torch import torch
import vllm.envs as envs from vllm import envs
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op
@cache
def is_rocm_aiter_moe_enabled() -> bool: def is_rocm_aiter_moe_enabled() -> bool:
return current_platform.is_rocm() \ return current_platform.is_rocm() \
and envs.VLLM_ROCM_USE_AITER_MOE \ and envs.VLLM_ROCM_USE_AITER_MOE \
and envs.VLLM_ROCM_USE_AITER \ and envs.VLLM_ROCM_USE_AITER
def is_rocm_aiter_block_scaled_moe_enabled() -> bool: def rocm_aiter_asm_moe_tkw1_impl(
return is_rocm_aiter_moe_enabled() and \
envs.VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE
def rocm_aiter_fused_experts(
*,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
w1: torch.Tensor, w1: torch.Tensor,
w2: torch.Tensor, w2: torch.Tensor,
topk_weights: torch.Tensor, topk_weight: torch.Tensor,
topk_ids: torch.Tensor,
fc1_scale: Optional[torch.Tensor] = None,
fc2_scale: Optional[torch.Tensor] = None,
fc1_smooth_scale: Optional[torch.Tensor] = None,
fc2_smooth_scale: Optional[torch.Tensor] = None,
a16: bool = False,
per_tensor_quant_scale: Optional[torch.Tensor] = None,
expert_mask: Optional[torch.Tensor] = None,
activation_str: str = "silu") -> torch.Tensor:
from aiter import ActivationType
from aiter.fused_moe_bf16_asm import asm_moe_tkw1
activation = \
ActivationType.Gelu if activation_str == "gelu" else ActivationType.Silu
return asm_moe_tkw1(hidden_states,
w1,
w2,
topk_weight,
topk_ids,
fc1_scale=fc1_scale,
fc2_scale=fc2_scale,
fc1_smooth_scale=fc1_smooth_scale,
fc2_smooth_scale=fc2_smooth_scale,
a16=a16,
per_tensor_quant_scale=per_tensor_quant_scale,
expert_mask=expert_mask,
activation=activation)
def rocm_aiter_asm_moe_tkw1_fake(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weight: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
use_fp8_w8a8: bool = False, fc1_scale: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None, fc2_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None, fc1_smooth_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None, fc2_smooth_scale: Optional[torch.Tensor] = None,
a16: bool = False,
per_tensor_quant_scale: Optional[torch.Tensor] = None,
expert_mask: Optional[torch.Tensor] = None, expert_mask: Optional[torch.Tensor] = None,
**kwagrs # Ignore additional keyword arguments activation_str: str = "silu") -> torch.Tensor:
) -> torch.Tensor: return torch.empty_like(hidden_states)
def rocm_aiter_ck_moe_impl(hidden_states: torch.Tensor, w1: torch.Tensor,
w2: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor) -> torch.Tensor:
from aiter import ck_moe
return ck_moe(hidden_states=hidden_states,
w1=w1,
w2=w2,
topk_weights=topk_weights,
topk_ids=topk_ids)
import aiter as rocm_aiter def rocm_aiter_ck_moe_fake(hidden_states: torch.Tensor, w1: torch.Tensor,
w2: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor) -> torch.Tensor:
return torch.empty_like(hidden_states)
def rocm_aiter_fmoe_fp8_blockscale_g1u1_impl(
topk_ids: torch.Tensor,
topk_weights: torch.Tensor,
hidden_states_dtype: torch.dtype,
expert_mask: torch.Tensor,
a1: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
a1_scale: torch.Tensor,
block_shape: List[int],
smooth_scale: Optional[torch.Tensor] = None) -> torch.Tensor:
from aiter import fmoe_fp8_blockscale_g1u1
from aiter.fused_moe_bf16_asm import moe_sorting_ck
topk = topk_ids.shape[1]
model_dim = w1.shape[-1]
local_E = E = w1.shape[0]
if expert_mask is not None:
E = expert_mask.numel()
(
sorted_token_ids,
sorted_weight_buf,
sorted_expert_ids,
num_valid_ids,
out_asm,
) = moe_sorting_ck(topk_ids,
topk_weights,
E,
model_dim,
hidden_states_dtype,
expert_mask=expert_mask)
fmoe_fp8_blockscale_g1u1(out_asm, a1, w1, w2, sorted_token_ids,
sorted_weight_buf, sorted_expert_ids,
num_valid_ids, topk, w1_scale.view(local_E, -1),
w2_scale.view(local_E, -1),
a1_scale.t().contiguous(), *block_shape,
smooth_scale)
return out_asm
def rocm_aiter_fmoe_fp8_blockscale_g1u1_fake(
topk_ids: torch.Tensor,
topk_weights: torch.Tensor,
hidden_states_dtype: torch.dtype,
expert_mask: torch.Tensor,
a1: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
a1_scale: torch.Tensor,
block_shape: List[int],
smooth_scale: Optional[torch.Tensor] = None) -> torch.Tensor:
return torch.empty_like(a1, dtype=torch.bf16)
def rocm_aiter_asm_moe_impl(hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weight: torch.Tensor,
topk_ids: torch.Tensor,
fc1_scale: Optional[torch.Tensor] = None,
fc2_scale: Optional[torch.Tensor] = None,
fc1_smooth_scale: Optional[torch.Tensor] = None,
fc2_smooth_scale: Optional[torch.Tensor] = None,
a16: bool = False,
activation: str = "silu") -> torch.Tensor:
import aiter.fused_moe_bf16_asm as rocm_aiter_asm_fmoe import aiter.fused_moe_bf16_asm as rocm_aiter_asm_fmoe
from aiter import ActivationType
assert activation in ["silu", "gelu"], "The given activation:" \
f" {activation}" \
" is not supported in" \
" AITER."
if activation == "silu":
aiter_activation = ActivationType.Silu
else:
aiter_activation = ActivationType.Gelu
return rocm_aiter_asm_fmoe.asm_moe(hidden_states=hidden_states,
w1=w1,
w2=w2,
topk_weight=topk_weight,
topk_ids=topk_ids,
fc1_scale=fc1_scale,
fc2_scale=fc2_scale,
fc1_smooth_scale=fc1_smooth_scale,
fc2_smooth_scale=fc2_smooth_scale,
a16=a16,
activation=aiter_activation)
def rocm_aiter_asm_moe_fake(hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weight: torch.Tensor,
topk_ids: torch.Tensor,
fc1_scale: Optional[torch.Tensor] = None,
fc2_scale: Optional[torch.Tensor] = None,
fc1_smooth_scale: Optional[torch.Tensor] = None,
fc2_smooth_scale: Optional[torch.Tensor] = None,
a16: bool = False,
activation: str = "silu") -> torch.Tensor:
return torch.empty_like(hidden_states)
def rocm_aiter_topk_softmax_impl(topk_weights: torch.Tensor,
topk_indices: torch.Tensor,
token_expert_indices: torch.Tensor,
gating_output: torch.Tensor,
renormalize: bool) -> None:
from aiter import topk_softmax
topk_softmax(topk_weights, topk_indices, token_expert_indices,
gating_output, renormalize)
def rocm_aiter_topk_softmax_fake(topk_weights: torch.Tensor,
topk_indices: torch.Tensor,
token_expert_indices: torch.Tensor,
gating_output: torch.Tensor,
renormalize: bool) -> None:
pass
if current_platform.is_rocm():
direct_register_custom_op(
op_name="rocm_aiter_asm_moe_tkw1",
op_func=rocm_aiter_asm_moe_tkw1_impl,
mutates_args=[],
fake_impl=rocm_aiter_asm_moe_tkw1_fake,
dispatch_key=current_platform.dispatch_key,
)
direct_register_custom_op(
op_name="rocm_aiter_ck_moe",
op_func=rocm_aiter_ck_moe_impl,
mutates_args=[],
fake_impl=rocm_aiter_ck_moe_fake,
dispatch_key=current_platform.dispatch_key,
)
direct_register_custom_op(
op_name="rocm_aiter_fmoe_fp8_blockscale_g1u1",
op_func=rocm_aiter_fmoe_fp8_blockscale_g1u1_impl,
mutates_args=[],
fake_impl=rocm_aiter_fmoe_fp8_blockscale_g1u1_fake,
dispatch_key=current_platform.dispatch_key,
)
direct_register_custom_op(
op_name="rocm_aiter_asm_moe",
op_func=rocm_aiter_asm_moe_impl,
mutates_args=[],
fake_impl=rocm_aiter_asm_moe_fake,
dispatch_key=current_platform.dispatch_key,
)
direct_register_custom_op(
op_name="rocm_aiter_topk_softmax",
op_func=rocm_aiter_topk_softmax_impl,
mutates_args=["topk_weights", "topk_indices", "token_expert_indices"],
fake_impl=rocm_aiter_topk_softmax_fake,
dispatch_key=current_platform.dispatch_key,
)
def rocm_aiter_fused_experts(hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
inplace: bool = False,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
per_channel_quant: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
w1_zp: Optional[torch.Tensor] = None,
w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
allow_deep_gemm: bool = False) -> torch.Tensor:
from vllm.model_executor.layers.quantization.utils.fp8_utils import ( from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8) per_token_group_quant_fp8)
if envs.VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE and use_fp8_w8a8: # All AITER Fused MoE kernels are expecting the following datatypes
topk_weights = topk_weights.to(torch.float32)
topk_ids = topk_ids.to(torch.int32)
# w8a8 block-scaled
if block_shape is not None and use_fp8_w8a8:
assert not apply_router_weight_on_input, (
"apply_router_weight_on_input is not supported for block scaled moe"
)
assert w1_scale is not None assert w1_scale is not None
assert w2_scale is not None assert w2_scale is not None
local_E = E = w1.shape[0]
if expert_mask is not None:
E = expert_mask.numel()
topk = topk_ids.shape[1]
model_dim = w1.shape[-1]
dtype = hidden_states.dtype
# The default block sizes are 128 in AITER. # The default block sizes are 128 in AITER.
if block_shape is None: block_shape = [128, 128] if block_shape is None else block_shape
block_shape = [128, 128]
a1, a1_scale = per_token_group_quant_fp8(hidden_states, block_shape[1])
scale_blk_k = block_shape[1]
return torch.ops.vllm.rocm_aiter_fmoe_fp8_blockscale_g1u1(
( topk_ids, topk_weights, hidden_states.dtype, expert_map, a1, w1,
sorted_token_ids, w2, w1_scale, w2_scale, a1_scale, block_shape, None)
sorted_weight_buf,
sorted_expert_ids, # w8a8 per-channel quantization
num_valid_ids, elif per_channel_quant and apply_router_weight_on_input and use_fp8_w8a8:
out_asm, # AITER tkw1 kernel for FP8 models with `apply_router_weight_on_input`
) = rocm_aiter_asm_fmoe.moe_sorting_ck(topk_ids, # This applies topk_weights on the GEMM output of the first FC layer
topk_weights, # rather than the second FC.
E, assert (topk_weights.dim() == 2
model_dim, ), "`topk_weights` should be in shape (num_tokens, topk)"
dtype, assert topk_weights.shape[-1] == 1, (
expert_mask=expert_mask) "Only support topk=1 when"
" `apply_router_weight_on_input` is True")
a1, a1_scale = per_token_group_quant_fp8(hidden_states, scale_blk_k)
rocm_aiter.fmoe_fp8_blockscale_g1u1( return torch.ops.vllm.rocm_aiter_asm_moe_tkw1(
out_asm, hidden_states,
a1,
w1, w1,
w2, w2,
sorted_token_ids, topk_weights,
sorted_weight_buf, topk_ids,
sorted_expert_ids, fc1_scale=w1_scale,
num_valid_ids, fc2_scale=w2_scale,
topk, fc1_smooth_scale=None,
w1_scale.view(local_E, -1), fc2_smooth_scale=None,
w2_scale.view(local_E, -1), a16=False,
a1_scale.t().contiguous(), per_tensor_quant_scale=None,
block_shape[0], expert_mask=expert_map,
block_shape[1], activation_str=activation)
None,
) # w8a8 per-tensor activation per-tensor weight
return out_asm
elif use_fp8_w8a8: elif use_fp8_w8a8:
return rocm_aiter_asm_fmoe.asm_moe(hidden_states=hidden_states, assert not apply_router_weight_on_input, (
w1=w1, "apply_router_weight_on_input is not supported for fp8_w8a8")
w2=w2, return torch.ops.vllm.rocm_aiter_asm_moe(hidden_states=hidden_states,
topk_weight=topk_weights, w1=w1,
topk_ids=topk_ids, w2=w2,
fc1_scale=w1_scale, topk_weight=topk_weights,
fc2_scale=w2_scale, topk_ids=topk_ids,
fc1_smooth_scale=None, fc1_scale=w1_scale,
fc2_smooth_scale=None, fc2_scale=w2_scale,
a16=False) fc1_smooth_scale=None,
fc2_smooth_scale=None,
return rocm_aiter.ck_moe(hidden_states=hidden_states, a16=False,
w1=w1, activation=activation)
w2=w2, if apply_router_weight_on_input:
topk_weights=topk_weights, assert (topk_weights.dim() == 2
topk_ids=topk_ids) ), "`topk_weights` should be in shape (num_tokens, topk)"
_, topk = topk_weights.shape
assert (
topk == 1
), "Only support topk=1 when `apply_router_weight_on_input` is True"
hidden_states = hidden_states * topk_weights.to(hidden_states.dtype)
topk_ids = topk_ids.to(torch.int32)
topk_weights = torch.ones_like(topk_weights, dtype=torch.float32)
# w16a16 fallback to rocm_aiter_ck_moe w16a16
return torch.ops.vllm.rocm_aiter_ck_moe(hidden_states=hidden_states,
w1=w1,
w2=w2,
topk_weights=topk_weights,
topk_ids=topk_ids)
def rocm_aiter_topk_softmax(topk_weights: torch.Tensor, def rocm_aiter_topk_softmax(topk_weights: torch.Tensor,
topk_indices: torch.Tensor, topk_indices: torch.Tensor,
token_expert_indices: torch.Tensor, token_expert_indices: torch.Tensor,
gating_output: torch.Tensor, gating_output: torch.Tensor,
renormalize: bool) -> tuple[torch.Tensor, ...]: renormalize: bool) -> Tuple[torch.Tensor, ...]:
import aiter as rocm_aiter torch.ops.vllm.rocm_aiter_topk_softmax(topk_weights, topk_indices,
rocm_aiter.topk_softmax(topk_weights, topk_indices, token_expert_indices, token_expert_indices, gating_output,
gating_output, renormalize) renormalize)
return topk_weights, topk_indices return topk_weights, topk_indices
def shuffle_weights(*tensors: torch.Tensor) -> tuple[torch.Tensor, ...]: def shuffle_weights(*tensors: torch.Tensor) -> Tuple[torch.Tensor, ...]:
""" """
Applies shuffle_weight function from AITER to each Applies shuffle_weight function from AITER to each
input tensor and returns them. input tensor and returns them.
...@@ -129,15 +388,14 @@ def shuffle_weights(*tensors: torch.Tensor) -> tuple[torch.Tensor, ...]: ...@@ -129,15 +388,14 @@ def shuffle_weights(*tensors: torch.Tensor) -> tuple[torch.Tensor, ...]:
*tensors: Variable number of torch.Tensor objects. *tensors: Variable number of torch.Tensor objects.
Returns: Returns:
A tuple of shuffled tensors. A Tuple of shuffled tensors.
""" """
from aiter.ops.shuffle import shuffle_weight from aiter.ops.shuffle import shuffle_weight
return tuple(shuffle_weight(tensor) for tensor in tensors) return tuple(shuffle_weight(tensor) for tensor in tensors)
def expand_weights(*tensors: torch.Tensor, def expand_weights(*tensors: torch.Tensor,
expansion_dims: list[int]) -> tuple[torch.Tensor, ...]: expansion_dims: list[int]) -> Tuple[torch.Tensor, ...]:
""" """
Expands the dimensions of input tensors. Expands the dimensions of input tensors.
...@@ -147,7 +405,7 @@ def expand_weights(*tensors: torch.Tensor, ...@@ -147,7 +405,7 @@ def expand_weights(*tensors: torch.Tensor,
corresponding to each tensor. corresponding to each tensor.
Returns: Returns:
A tuple of tensors with expanded dimensions. A Tuple of tensors with expanded dimensions.
""" """
assert len(tensors) == len(expansion_dims), \ assert len(tensors) == len(expansion_dims), \
......
...@@ -168,7 +168,8 @@ class RMSNorm(CustomOp): ...@@ -168,7 +168,8 @@ class RMSNorm(CustomOp):
x: torch.Tensor, x: torch.Tensor,
residual: Optional[torch.Tensor] = None, residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
from vllm_hpu_extension.ops import HPUFusedRMSNorm from vllm_hpu_extension.kernels import rms_norm
HPUFusedRMSNorm = rms_norm()
if HPUFusedRMSNorm is None: if HPUFusedRMSNorm is None:
return self.forward_native(x, residual) return self.forward_native(x, residual)
if residual is not None: if residual is not None:
......
...@@ -6,7 +6,6 @@ from typing import Any, Literal, Optional, Union ...@@ -6,7 +6,6 @@ from typing import Any, Literal, Optional, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter, UninitializedParameter from torch.nn.parameter import Parameter, UninitializedParameter
from vllm.distributed import (divide, get_tensor_model_parallel_rank, from vllm.distributed import (divide, get_tensor_model_parallel_rank,
...@@ -17,6 +16,7 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank, ...@@ -17,6 +16,7 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase) QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.utils import dispatch_unquantized_gemm
# yapf: disable # yapf: disable
from vllm.model_executor.parameter import (BasevLLMParameter, from vllm.model_executor.parameter import (BasevLLMParameter,
BlockQuantScaleParameter, BlockQuantScaleParameter,
...@@ -31,6 +31,8 @@ logger = init_logger(__name__) ...@@ -31,6 +31,8 @@ logger = init_logger(__name__)
WEIGHT_LOADER_V2_SUPPORTED = [ WEIGHT_LOADER_V2_SUPPORTED = [
"CompressedTensorsLinearMethod", "CompressedTensorsLinearMethod",
"BitBLASLinearMethod",
"GPTQBitBLASLinearMethod",
"AWQMarlinLinearMethod", "AWQMarlinLinearMethod",
"AWQLinearMethod", "AWQLinearMethod",
"GPTQMarlinLinearMethod", "GPTQMarlinLinearMethod",
...@@ -50,6 +52,15 @@ WEIGHT_LOADER_V2_SUPPORTED = [ ...@@ -50,6 +52,15 @@ WEIGHT_LOADER_V2_SUPPORTED = [
] ]
def adjust_bitblas_shard(param, shard_size, shard_offset):
bitblas_tile_size = getattr(param, "bitblas_tile_size", None)
if bitblas_tile_size is not None:
return (shard_size // bitblas_tile_size,
shard_offset // bitblas_tile_size)
return shard_size, shard_offset
def adjust_marlin_shard(param, shard_size, shard_offset): def adjust_marlin_shard(param, shard_size, shard_offset):
marlin_tile_size = getattr(param, "marlin_tile_size", None) marlin_tile_size = getattr(param, "marlin_tile_size", None)
if marlin_tile_size is None: if marlin_tile_size is None:
...@@ -188,7 +199,7 @@ class UnquantizedLinearMethod(LinearMethodBase): ...@@ -188,7 +199,7 @@ class UnquantizedLinearMethod(LinearMethodBase):
x: torch.Tensor, x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor: bias: Optional[torch.Tensor] = None) -> torch.Tensor:
return F.linear(x, layer.weight, bias) return dispatch_unquantized_gemm()(x, layer.weight, bias)
class LinearBase(torch.nn.Module): class LinearBase(torch.nn.Module):
...@@ -615,6 +626,9 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -615,6 +626,9 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
shard_size, shard_offset = adjust_marlin_shard( shard_size, shard_offset = adjust_marlin_shard(
param, shard_size, shard_offset) param, shard_size, shard_offset)
shard_size, shard_offset = adjust_bitblas_shard(
param, shard_size, shard_offset)
if use_bitsandbytes_4bit: if use_bitsandbytes_4bit:
index = list(itertools.accumulate([0] + self.output_sizes)) index = list(itertools.accumulate([0] + self.output_sizes))
orig_offsets = { orig_offsets = {
...@@ -646,6 +660,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -646,6 +660,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
# Special case for Marlin. # Special case for Marlin.
shard_size, shard_offset = adjust_marlin_shard( shard_size, shard_offset = adjust_marlin_shard(
param, shard_size, shard_offset) param, shard_size, shard_offset)
shard_size, shard_offset = adjust_bitblas_shard(
param, shard_size, shard_offset)
use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
False) False)
...@@ -913,6 +929,15 @@ class QKVParallelLinear(ColumnParallelLinear): ...@@ -913,6 +929,15 @@ class QKVParallelLinear(ColumnParallelLinear):
shard_offset = self._get_shard_offset_mapping(loaded_shard_id) shard_offset = self._get_shard_offset_mapping(loaded_shard_id)
shard_size = self._get_shard_size_mapping(loaded_shard_id) shard_size = self._get_shard_size_mapping(loaded_shard_id)
# Note(simon): This is needed for Qwen3's fp8 quantization.
if isinstance(param, BlockQuantScaleParameter):
assert self.quant_method is not None
assert hasattr(self.quant_method, "quant_config")
weight_block_size = self.quant_method.quant_config.weight_block_size
block_n, _ = weight_block_size[0], weight_block_size[1]
shard_offset = (shard_offset + block_n - 1) // block_n
shard_size = (shard_size + block_n - 1) // block_n
param.load_qkv_weight(loaded_weight=loaded_weight, param.load_qkv_weight(loaded_weight=loaded_weight,
num_heads=self.num_kv_head_replicas, num_heads=self.num_kv_head_replicas,
shard_id=loaded_shard_id, shard_id=loaded_shard_id,
......
...@@ -10,8 +10,10 @@ from packaging import version ...@@ -10,8 +10,10 @@ from packaging import version
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.attention.backends.utils import PAD_SLOT_ID from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.triton_utils import HAS_TRITON
TRITON3 = version.parse(triton.__version__) >= version.parse("3.0.0") TRITON3 = HAS_TRITON and (version.parse(triton.__version__)
>= version.parse("3.0.0"))
if TRITON3: if TRITON3:
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import Dict, List, Type from typing import Literal, Type, get_args
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
QUANTIZATION_METHODS: List[str] = [ QuantizationMethods = Literal[
"aqlm", "aqlm",
"awq", "awq",
"deepspeedfp", "deepspeedfp",
...@@ -15,12 +15,12 @@ QUANTIZATION_METHODS: List[str] = [ ...@@ -15,12 +15,12 @@ QUANTIZATION_METHODS: List[str] = [
"fbgemm_fp8", "fbgemm_fp8",
"modelopt", "modelopt",
"nvfp4", "nvfp4",
# The order of gptq methods is important for config.py iteration over
# override_quantization_method(..)
"marlin", "marlin",
"bitblas",
"gguf", "gguf",
"gptq_marlin_24", "gptq_marlin_24",
"gptq_marlin", "gptq_marlin",
"gptq_bitblas",
"awq_marlin", "awq_marlin",
"gptq", "gptq",
"compressed-tensors", "compressed-tensors",
...@@ -34,6 +34,7 @@ QUANTIZATION_METHODS: List[str] = [ ...@@ -34,6 +34,7 @@ QUANTIZATION_METHODS: List[str] = [
"moe_wna16", "moe_wna16",
"torchao", "torchao",
] ]
QUANTIZATION_METHODS: list[str] = list(get_args(QuantizationMethods))
# The customized quantization methods which will be added to this dict. # The customized quantization methods which will be added to this dict.
_CUSTOMIZED_METHOD_TO_QUANT_CONFIG = {} _CUSTOMIZED_METHOD_TO_QUANT_CONFIG = {}
...@@ -85,6 +86,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]: ...@@ -85,6 +86,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
from .aqlm import AQLMConfig from .aqlm import AQLMConfig
from .awq import AWQConfig from .awq import AWQConfig
from .awq_marlin import AWQMarlinConfig from .awq_marlin import AWQMarlinConfig
from .bitblas import BitBLASConfig
from .bitsandbytes import BitsAndBytesConfig from .bitsandbytes import BitsAndBytesConfig
from .compressed_tensors.compressed_tensors import ( # noqa: E501 from .compressed_tensors.compressed_tensors import ( # noqa: E501
CompressedTensorsConfig) CompressedTensorsConfig)
...@@ -94,6 +96,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]: ...@@ -94,6 +96,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
from .fp8 import Fp8Config from .fp8 import Fp8Config
from .gguf import GGUFConfig from .gguf import GGUFConfig
from .gptq import GPTQConfig from .gptq import GPTQConfig
from .gptq_bitblas import GPTQBitBLASConfig
from .gptq_marlin import GPTQMarlinConfig from .gptq_marlin import GPTQMarlinConfig
from .gptq_marlin_24 import GPTQMarlin24Config from .gptq_marlin_24 import GPTQMarlin24Config
from .hqq_marlin import HQQMarlinConfig from .hqq_marlin import HQQMarlinConfig
...@@ -107,7 +110,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]: ...@@ -107,7 +110,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
from .torchao import TorchAOConfig from .torchao import TorchAOConfig
from .tpu_int8 import Int8TpuConfig from .tpu_int8 import Int8TpuConfig
method_to_config: Dict[str, Type[QuantizationConfig]] = { method_to_config: dict[str, Type[QuantizationConfig]] = {
"aqlm": AQLMConfig, "aqlm": AQLMConfig,
"awq": AWQConfig, "awq": AWQConfig,
"deepspeedfp": DeepSpeedFPConfig, "deepspeedfp": DeepSpeedFPConfig,
...@@ -116,12 +119,12 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]: ...@@ -116,12 +119,12 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
"fbgemm_fp8": FBGEMMFp8Config, "fbgemm_fp8": FBGEMMFp8Config,
"modelopt": ModelOptFp8Config, "modelopt": ModelOptFp8Config,
"nvfp4": ModelOptNvFp4Config, "nvfp4": ModelOptNvFp4Config,
# The order of gptq methods is important for config.py iteration over
# override_quantization_method(..)
"marlin": MarlinConfig, "marlin": MarlinConfig,
"bitblas": BitBLASConfig,
"gguf": GGUFConfig, "gguf": GGUFConfig,
"gptq_marlin_24": GPTQMarlin24Config, "gptq_marlin_24": GPTQMarlin24Config,
"gptq_marlin": GPTQMarlinConfig, "gptq_marlin": GPTQMarlinConfig,
"gptq_bitblas": GPTQBitBLASConfig,
"awq_marlin": AWQMarlinConfig, "awq_marlin": AWQMarlinConfig,
"gptq": GPTQConfig, "gptq": GPTQConfig,
"compressed-tensors": CompressedTensorsConfig, "compressed-tensors": CompressedTensorsConfig,
...@@ -144,6 +147,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]: ...@@ -144,6 +147,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
__all__ = [ __all__ = [
"QuantizationConfig", "QuantizationConfig",
"QuantizationMethods",
"get_quantization_config", "get_quantization_config",
"QUANTIZATION_METHODS", "QUANTIZATION_METHODS",
] ]
\ No newline at end of file
...@@ -17,14 +17,13 @@ from vllm.model_executor.layers.quantization.awq import (AWQConfig, ...@@ -17,14 +17,13 @@ from vllm.model_executor.layers.quantization.awq import (AWQConfig,
is_layer_skipped_awq) is_layer_skipped_awq)
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase) QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.moe_wna16 import MoeWNA16Config
from vllm.model_executor.layers.quantization.utils import replace_parameter from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.marlin_utils import ( from vllm.model_executor.layers.quantization.utils.marlin_utils import (
apply_awq_marlin_linear, awq_to_marlin_zero_points, check_marlin_supported, apply_awq_marlin_linear, awq_to_marlin_zero_points, check_marlin_supported,
check_marlin_supports_layer, marlin_make_empty_g_idx, check_marlin_supports_layer, check_moe_marlin_supports_layer,
marlin_make_workspace, marlin_moe_permute_scales, marlin_permute_scales, marlin_make_empty_g_idx, marlin_make_workspace, marlin_moe_permute_scales,
moe_awq_to_marlin_zero_points, verify_marlin_supported, marlin_permute_scales, moe_awq_to_marlin_zero_points,
verify_marlin_supports_shape) verify_marlin_supported, verify_marlin_supports_shape)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.parameter import (GroupQuantScaleParameter, from vllm.model_executor.parameter import (GroupQuantScaleParameter,
PackedvLLMParameter) PackedvLLMParameter)
...@@ -136,12 +135,15 @@ class AWQMarlinConfig(QuantizationConfig): ...@@ -136,12 +135,15 @@ class AWQMarlinConfig(QuantizationConfig):
self.full_config).get_quant_method(layer, prefix) self.full_config).get_quant_method(layer, prefix)
return AWQMarlinLinearMethod(self) return AWQMarlinLinearMethod(self)
elif isinstance(layer, FusedMoE): elif isinstance(layer, FusedMoE):
if layer.local_num_experts > 32: from vllm.model_executor.layers.quantization.moe_wna16 import (
# For MoEs with many experts the moe_wna16 kernel is faster MoeWNA16Config)
if not check_moe_marlin_supports_layer(layer, self.group_size):
logger.warning_one(
f"Layer '{prefix}' is not supported by AWQMoeMarlin. "
"Falling back to Moe WNA16 kernels.")
return MoeWNA16Config.from_config( return MoeWNA16Config.from_config(
self.full_config).get_quant_method(layer, prefix) self.full_config).get_quant_method(layer, prefix)
else: return AWQMoEMethod(self)
return AWQMoEMethod(self)
return None return None
@classmethod @classmethod
...@@ -391,6 +393,13 @@ class AWQMoEMethod(FusedMoEMethodBase): ...@@ -391,6 +393,13 @@ class AWQMoEMethod(FusedMoEMethodBase):
layer.register_parameter("w2_qzeros", w2_qzeros) layer.register_parameter("w2_qzeros", w2_qzeros)
set_weight_attrs(w2_qzeros, extra_weight_attrs) set_weight_attrs(w2_qzeros, extra_weight_attrs)
device = layer.w13_qweight.device
sms = torch.cuda.get_device_properties(device).multi_processor_count
layer.workspace = torch.zeros((sms * 4, ),
dtype=torch.int,
device=device,
requires_grad=False)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
num_experts = layer.w13_qweight.shape[0] num_experts = layer.w13_qweight.shape[0]
device = layer.w13_qweight.device device = layer.w13_qweight.device
...@@ -473,10 +482,7 @@ class AWQMoEMethod(FusedMoEMethodBase): ...@@ -473,10 +482,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
activation: str = "silu", activation: str = "silu",
) -> torch.Tensor: ) -> torch.Tensor:
assert activation == "silu", "Only SiLU activation is supported." assert activation == "silu", "Only SiLU activation is supported."
if expert_map is not None:
raise NotImplementedError(
"Expert Parallelism is not supported for "
"fused Marlin MoE method.")
if apply_router_weight_on_input: if apply_router_weight_on_input:
raise NotImplementedError( raise NotImplementedError(
"Apply router weight on input is not supported for" "Apply router weight on input is not supported for"
...@@ -503,7 +509,10 @@ class AWQMoEMethod(FusedMoEMethodBase): ...@@ -503,7 +509,10 @@ class AWQMoEMethod(FusedMoEMethodBase):
router_logits, router_logits,
topk_weights, topk_weights,
topk_ids, topk_ids,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_zeros=layer.w13_qzeros, w1_zeros=layer.w13_qzeros,
w2_zeros=layer.w2_qzeros, w2_zeros=layer.w2_qzeros,
workspace=layer.workspace,
num_bits=self.quant_config.weight_bits, num_bits=self.quant_config.weight_bits,
) )
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, List, Optional
import torch
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.utils.bitblas_utils import (
BITBLAS_OPTIMIZE_FEATURES, BITBLAS_SUPPORTED_NUM_BITS,
BITBLAS_SUPPORTED_SYM, MINIMUM_BITBLAS_VERSION)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.parameter import (BasevLLMParameter,
ChannelQuantScaleParameter,
GroupQuantScaleParameter,
PackedvLLMParameter)
from vllm.model_executor.utils import set_weight_attrs
logger = init_logger(__name__)
class BitBLASConfig(QuantizationConfig):
"""Config class for BitBLAS.
Reference: https://github.com/Microsoft/BitBLAS
"""
TORCH_DTYPE = torch.float16
STORAGE_DTYPE = "int8" # assume int8 storage
TORCH_STORAGE_DTYPE = getattr(torch, STORAGE_DTYPE)
# "original" or "rescale" or "quantized",
# gptq_with_bitblas prefer "quantized implementation"
ZEROS_MODE = "quantized"
def __init__(
self,
weight_bits: int,
group_size: Optional[int],
desc_act: Optional[bool],
is_sym: Optional[bool],
quant_method: Optional[str],
lm_head_quantized: bool,
) -> None:
try:
import bitblas
if bitblas.__version__ < MINIMUM_BITBLAS_VERSION:
raise ImportError(
"bitblas version is wrong. Please "
f"install bitblas>={MINIMUM_BITBLAS_VERSION}")
except ImportError as e:
bitblas_import_exception = e
raise ValueError(
"Trying to use the bitblas backend, but could not import"
f"with the following error: {bitblas_import_exception}. "
"Please install bitblas through the following command: "
f"`pip install bitblas>={MINIMUM_BITBLAS_VERSION}`"
) from bitblas_import_exception
if desc_act and group_size == -1:
# In this case, act_order == True is the same as act_order == False
# (since we have only one group per output channel)
desc_act = False
self.weight_bits = weight_bits
self.group_size = group_size
self.desc_act = desc_act
self.is_sym = is_sym
self.quant_method = quant_method
self.lm_head_quantized = lm_head_quantized
# Verify
if self.weight_bits not in BITBLAS_SUPPORTED_NUM_BITS:
raise ValueError(
f"BitBLAS does not support weight_bits = {self.weight_bits}. "
f"Only weight_bits = {BITBLAS_SUPPORTED_NUM_BITS} "
"are supported.")
if self.is_sym not in BITBLAS_SUPPORTED_SYM:
raise ValueError(
f"BitBLAS does not support is_sym = {self.is_sym}. "
f"Only sym = {BITBLAS_SUPPORTED_SYM} are supported.")
storage_dtype = self.STORAGE_DTYPE
storage_nbit = int("".join(c for c in storage_dtype if c.isdigit()))
self.storage_dtype = storage_dtype
self.storage_torch_dtype = self.TORCH_STORAGE_DTYPE
# 4 Bits packed into 32 bit datatype.
self.pack_factor = storage_nbit // weight_bits
self.nbits = weight_bits
# Zeros type for the quantized weights.
self.zeros_mode = self.ZEROS_MODE
def __repr__(self) -> str:
return (f"BitBLASConfig(weight_bits={self.weight_bits}, "
f"group_size={self.group_size}, "
f"desc_act={self.desc_act}, "
f"is_sym={self.is_sym}, "
f"quant_method={self.quant_method})")
@classmethod
def get_name(cls) -> str:
return "bitblas"
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
return [torch.half, torch.bfloat16]
@classmethod
# Need to figure it out
def get_min_capability(cls) -> int:
return 70
@classmethod
def get_config_filenames(cls) -> List[str]:
return ["quantize_config.json"]
@staticmethod
def get_from_keys(config: Dict[str, Any],
keys: List[str],
default: Any = None) -> Any:
"""Get a value from the model's quantization config."""
for key in keys:
if key in config:
return config[key]
return default
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "BitBLASConfig":
weight_bits = cls.get_from_keys(config, ["bits"])
group_size = cls.get_from_keys(config, ["group_size"], -1)
desc_act = cls.get_from_keys(config, ["desc_act"], False)
is_sym = cls.get_from_keys(config, ["sym"], False)
quant_method = cls.get_from_keys(config, ["quant_method"])
lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
default=False)
return cls(weight_bits, group_size, desc_act, is_sym, quant_method,
lm_head_quantized)
@classmethod
def override_quantization_method(cls, hf_quant_cfg,
user_quant) -> Optional[str]:
# compat: autogptq >=0.8.0 use checkpoint_format: str
# compat: autogptq <=0.7.1 is_bitblas_format: bool
is_bitblas_format = (hf_quant_cfg.get("checkpoint_format") == "bitblas"
or hf_quant_cfg.get("is_bitblas_format", False))
is_valid_user_quant = (user_quant is None or user_quant == "gptq"
or user_quant == "bitblas")
if is_bitblas_format and is_valid_user_quant:
msg = ("The model is serialized in {} format. Using {} kernel.".
format(cls.get_name(), cls.get_name()))
logger.info(msg)
return cls.get_name()
return None
def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["BitBLASLinearMethod"]:
if isinstance(layer, LinearBase) or (isinstance(layer, ParallelLMHead)
and self.lm_head_quantized):
return BitBLASLinearMethod(self)
return None
class BitBLASLinearMethod(LinearMethodBase):
"""Linear method for BitBLAS.
Args:
quant_config: The BitBLAS quantization config.
"""
# USE BITBLAS_OPTIMIZE_FEATURES_CONTIGUOUS
# Instead of BITBLAS_OPTIMIZE_FEATURES
# If you want to high contiguous batching
# performance
OPT_FEATURES = BITBLAS_OPTIMIZE_FEATURES
ENABLE_TUNING = True
BITBLAS_DTYPES = {
torch.float32: "float32",
torch.float16: "float16",
torch.bfloat16: "bfloat16",
torch.half: "float16",
torch.int8: "int8",
}
def __init__(self, quant_config: BitBLASConfig):
self.quant_config = quant_config
def create_weights_gptq(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
"""Creates quantized weights for use in linear operations.
The function initializes and returns a dictionary containing quantized
weights, scales, and zeros
for performing quantized matrix multiplication operations.
Args:
input_size_per_partition: The size of the input partition.
output_size_per_partition: The size of the output partition.
input_size: The total size of the input (unused).
output_size: The total size of the output (unused).
params_dtype:
The data type of the parameters (expected to be torch.float16).
Returns:
A dictionary containing the quantized weights ('qweight'),
scales ('scales'), and zeros ('zeros').
Raises:
ValueError: If `params_dtype` is not `torch.float16` or if the
input size per partition is not divisible by the group size in
`quant_config`.
"""
del input_size, output_size # Unused arguments.
weight_loader = extra_weight_attrs["weight_loader"]
if params_dtype not in self.quant_config.get_supported_act_dtypes():
raise ValueError("Parameter data type must be torch.float16, "
f"but got {params_dtype}")
group_size = self.quant_config.group_size
if group_size is None:
group_size = -1
# Validate output_size_per_partition
output_size_per_partition = sum(output_partition_sizes)
if (group_size != -1 and input_size_per_partition % group_size != 0):
raise ValueError(
f"Input size per partition ({input_size_per_partition}) must "
f"be divisible by group size ({group_size}).")
# Initialize or retrieve the BitBLAS matrix multiplication operator.
self._configure_bitblas_matmul(
input_size_per_partition,
output_size_per_partition,
params_dtype=params_dtype,
enable_tuning=self.ENABLE_TUNING,
bias=False,
layout="nt",
bits=self.quant_config.weight_bits,
)
# Initialize quantized weights with dimensions
# Quantized 4Bit weights packed.
qweight = PackedvLLMParameter(
data=torch.empty(
self.bitblas_matmul.retrieve_weight_shape(),
device="cuda",
dtype=self.quant_config.storage_torch_dtype,
requires_grad=False,
),
input_dim=1,
output_dim=0,
packed_dim=1,
packed_factor=self.quant_config.pack_factor,
bitblas_tile_size=(self.bitblas_matmul.retrieve_weight_shape()[-2]
if self.bitblas_matmul.propagate_b else None),
weight_loader=weight_loader,
)
# Compute the number of input groups for channel-wise quantization.
input_groups = (1 if group_size == -1 else input_size_per_partition //
group_size)
# Initialize scales and zeros for the quantized weights.
weight_scale_args = {
"data":
torch.empty(
output_size_per_partition,
input_groups,
device="cuda",
dtype=params_dtype,
),
"weight_loader":
weight_loader
}
if input_groups == 1:
scales = ChannelQuantScaleParameter(output_dim=0,
**weight_scale_args)
else:
scales = GroupQuantScaleParameter(output_dim=0,
input_dim=1,
**weight_scale_args)
if self.quant_config.zeros_mode == "quantized":
zeros = PackedvLLMParameter(
data=torch.empty(
input_groups,
output_size_per_partition // self.quant_config.pack_factor,
device="cuda",
dtype=self.quant_config.storage_torch_dtype,
requires_grad=False,
),
input_dim=0,
output_dim=1,
packed_dim=1,
packed_factor=self.quant_config.pack_factor,
weight_loader=weight_loader,
)
else:
zeros = BasevLLMParameter(
torch.empty(output_size_per_partition,
input_groups,
device="cuda",
dtype=params_dtype),
weight_loader=weight_loader,
)
# Set attributes to indicate how scales and zeros are applied.
set_weight_attrs(zeros, {
"input_dim": None if input_groups == 1 else 1,
"output_dim": 0,
})
layer.register_parameter("qweight", qweight)
layer.register_parameter("scales", scales)
layer.register_parameter("zeros", zeros)
def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
if self.quant_config.quant_method == "gptq":
return self.create_weights_gptq(layer, input_size_per_partition,
output_partition_sizes, input_size,
output_size, params_dtype,
**extra_weight_attrs)
else:
raise ValueError(
f"Unsupported quant_method {self.quant_config.quant_method}")
def _configure_bitblas_matmul(
self,
infeatures,
outfeatures,
params_dtype,
enable_tuning,
bias,
layout,
bits,
out_dtype="float16",
):
from bitblas import MatmulConfig
bitblas_dtype = self.BITBLAS_DTYPES[params_dtype]
with_scaling = False
with_zeros = False
group_size = self.quant_config.group_size
zeros_mode = self.quant_config.zeros_mode
if self.quant_config.quant_method == "gptq":
with_scaling = True
with_zeros = True
W_dtype = f"uint{bits}"
if self.quant_config.is_sym:
with_zeros = False
W_dtype = f"int{bits}"
else:
raise ValueError(
f"Unsupported quant_method {self.quant_config.quant_method}")
matmul_config = MatmulConfig(
N=outfeatures,
K=infeatures,
A_dtype=bitblas_dtype,
W_dtype=W_dtype,
out_dtype=out_dtype,
accum_dtype="int32" if bitblas_dtype == "int8" else bitblas_dtype,
storage_dtype=self.quant_config.STORAGE_DTYPE,
with_scaling=with_scaling,
with_zeros=with_zeros,
group_size=group_size,
with_bias=bias,
layout=layout,
zeros_mode=zeros_mode,
)
self.bitblas_matmul = self._get_or_create_bitblas_operator(
matmul_config, enable_tuning)
def _get_or_create_bitblas_operator(self, config, enable_tuning):
from bitblas import Matmul, auto_detect_nvidia_target
from bitblas.cache import get_database_path, global_operator_cache
BITBLAS_DATABASE_PATH = get_database_path()
BITBLAS_TARGET = auto_detect_nvidia_target()
if global_operator_cache.size() == 0:
global_operator_cache.load_from_database(BITBLAS_DATABASE_PATH,
BITBLAS_TARGET)
bitblas_matmul = global_operator_cache.get(config)
if bitblas_matmul is None:
bitblas_matmul = Matmul(config,
target=BITBLAS_TARGET,
enable_tuning=False)
if enable_tuning:
TUNING_MESSAGE = (f"BitBLAS Operator {config} is tuning ...")
logger.info(TUNING_MESSAGE)
bitblas_matmul.hardware_aware_finetune(topk=20)
global_operator_cache.add(config, bitblas_matmul)
global_operator_cache.save_into_database(
BITBLAS_DATABASE_PATH, BITBLAS_TARGET)
TUNED_MESSAGE = (
f"BitBLAS Operator {config} tuned and saved to database.")
logger.info(TUNED_MESSAGE)
else:
_message = f"BitBLAS Operator {config} created."
logger.info(_message)
else:
_message = (
f"BitBLAS Operator {config} found in global_operator_cache.")
logger.info(_message)
return bitblas_matmul
def apply_gptq(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
qweight = layer.qweight
scales = layer.scales
qzeros = layer.zeros
x_2d = x.view(-1, x.shape[-1])
if self.quant_config.is_sym:
output_2d = self.bitblas_matmul(x_2d, qweight, scales)
else:
output_2d = self.bitblas_matmul(x_2d, qweight, scales, qzeros)
output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], ))
if bias is not None:
output.add_(bias) # In-place add
return output
def apply(
self,
*args: Any,
**kwargs: Any,
) -> torch.Tensor:
if self.quant_config.quant_method == "gptq":
return self.apply_gptq(*args, **kwargs)
else:
raise ValueError(
f"Unsupported quant_method {self.quant_config.quant_method}")
...@@ -72,7 +72,7 @@ class CompressedTensorsConfig(QuantizationConfig): ...@@ -72,7 +72,7 @@ class CompressedTensorsConfig(QuantizationConfig):
return 70 return 70
def get_name(self) -> str: def get_name(self) -> str:
return "compressed_tensors" return "compressed-tensors"
def get_quant_method( def get_quant_method(
self, self,
...@@ -302,14 +302,12 @@ class CompressedTensorsConfig(QuantizationConfig): ...@@ -302,14 +302,12 @@ class CompressedTensorsConfig(QuantizationConfig):
def _is_wNa16_group_channel(self, weight_quant: BaseModel, def _is_wNa16_group_channel(self, weight_quant: BaseModel,
input_quant: BaseModel) -> bool: input_quant: BaseModel) -> bool:
input_quant_none = input_quant is None input_quant_none = input_quant is None
is_symmetric = weight_quant.symmetric
is_channel_group = ( is_channel_group = (
weight_quant.strategy == QuantizationStrategy.CHANNEL.value weight_quant.strategy == QuantizationStrategy.CHANNEL.value
or weight_quant.strategy == QuantizationStrategy.GROUP.value) or weight_quant.strategy == QuantizationStrategy.GROUP.value)
is_static = not weight_quant.dynamic is_static = not weight_quant.dynamic
return (is_channel_group and input_quant_none and is_symmetric return (is_channel_group and input_quant_none and is_static)
and is_static)
def _get_scheme_from_parts( def _get_scheme_from_parts(
self, weight_quant: BaseModel, self, weight_quant: BaseModel,
...@@ -319,6 +317,7 @@ class CompressedTensorsConfig(QuantizationConfig): ...@@ -319,6 +317,7 @@ class CompressedTensorsConfig(QuantizationConfig):
if self._is_wNa16_group_channel(weight_quant, input_quant): if self._is_wNa16_group_channel(weight_quant, input_quant):
if (self.quant_format == CompressionFormat.marlin_24.value if (self.quant_format == CompressionFormat.marlin_24.value
and weight_quant.num_bits in W4A16SPARSE24_SUPPORTED_BITS): and weight_quant.num_bits in W4A16SPARSE24_SUPPORTED_BITS):
assert weight_quant.symmetric
return CompressedTensorsW4A16Sparse24( return CompressedTensorsW4A16Sparse24(
strategy=weight_quant.strategy, strategy=weight_quant.strategy,
num_bits=weight_quant.num_bits, num_bits=weight_quant.num_bits,
...@@ -328,6 +327,7 @@ class CompressedTensorsConfig(QuantizationConfig): ...@@ -328,6 +327,7 @@ class CompressedTensorsConfig(QuantizationConfig):
return CompressedTensorsWNA16( return CompressedTensorsWNA16(
num_bits=weight_quant.num_bits, num_bits=weight_quant.num_bits,
strategy=weight_quant.strategy, strategy=weight_quant.strategy,
symmetric=weight_quant.symmetric,
group_size=weight_quant.group_size, group_size=weight_quant.group_size,
actorder=weight_quant.actorder) actorder=weight_quant.actorder)
......
...@@ -67,7 +67,7 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase): ...@@ -67,7 +67,7 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
else: else:
return CompressedTensorsWNA16MarlinMoEMethod(quant_config) return CompressedTensorsWNA16MarlinMoEMethod(quant_config)
elif (quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant) elif (quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant)
and layer.activation == "silu" and layer.expert_map is None): and layer.activation == "silu"):
return CompressedTensorsW8A8Fp8MoECutlassMethod(quant_config) return CompressedTensorsW8A8Fp8MoECutlassMethod(quant_config)
elif quant_config._is_fp8_w8a8(weight_quant, input_quant): elif quant_config._is_fp8_w8a8(weight_quant, input_quant):
return CompressedTensorsW8A8Fp8MoEMethod(quant_config) return CompressedTensorsW8A8Fp8MoEMethod(quant_config)
...@@ -250,6 +250,28 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -250,6 +250,28 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales, layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
requires_grad=False) requires_grad=False)
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
is_rocm_aiter_moe_enabled)
# Property to determine if AITER is used
if is_rocm_aiter_moe_enabled():
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa E501
rocm_aiter_fused_experts, shuffle_weights)
# reshaping weights is required for aiter moe kernel.
shuffled_w13, shuffled_w2 = shuffle_weights(
layer.w13_weight.data, layer.w2_weight.data)
layer.w13_weight = torch.nn.Parameter(shuffled_w13,
requires_grad=False)
layer.w2_weight = torch.nn.Parameter(shuffled_w2,
requires_grad=False)
self.fused_experts_func = rocm_aiter_fused_experts
else:
from vllm.model_executor.layers.fused_moe import fused_experts
self.fused_experts_func = fused_experts
def apply( def apply(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
...@@ -268,7 +290,6 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -268,7 +290,6 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
activation: str = "silu", activation: str = "silu",
) -> torch.Tensor: ) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import fused_experts
topk_weights, topk_ids = FusedMoE.select_experts( topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x, hidden_states=x,
...@@ -282,10 +303,10 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -282,10 +303,10 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
scoring_func=scoring_func, scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias) e_score_correction_bias=e_score_correction_bias)
return fused_experts( return self.fused_experts_func(
x, hidden_states=x,
layer.w13_weight, w1=layer.w13_weight,
layer.w2_weight, w2=layer.w2_weight,
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
inplace=True, inplace=True,
...@@ -489,8 +510,6 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod): ...@@ -489,8 +510,6 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
) -> torch.Tensor: ) -> torch.Tensor:
assert activation == "silu" assert activation == "silu"
assert global_num_experts == layer.w13_weight.shape[0]
assert expert_map is None
topk_weights, topk_ids = FusedMoE.select_experts( topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x, hidden_states=x,
...@@ -521,6 +540,7 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod): ...@@ -521,6 +540,7 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
a1_scale=layer.w13_input_scale, a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale, a2_scale=layer.w2_input_scale,
out_dtype=x.dtype, out_dtype=x.dtype,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
) )
......
...@@ -12,11 +12,15 @@ from vllm.model_executor.layers.quantization.kernels.mixed_precision import ( ...@@ -12,11 +12,15 @@ from vllm.model_executor.layers.quantization.kernels.mixed_precision import (
MPLinearLayerConfig, choose_mp_linear_kernel) MPLinearLayerConfig, choose_mp_linear_kernel)
from vllm.model_executor.layers.quantization.utils.marlin_utils import ( from vllm.model_executor.layers.quantization.utils.marlin_utils import (
marlin_repeat_scales_on_all_ranks) marlin_repeat_scales_on_all_ranks)
# yapf conflicts with isort for this block
# yapf: disable
from vllm.model_executor.parameter import (BasevLLMParameter, from vllm.model_executor.parameter import (BasevLLMParameter,
ChannelQuantScaleParameter, ChannelQuantScaleParameter,
GroupQuantScaleParameter, GroupQuantScaleParameter,
PackedColumnParameter,
PackedvLLMParameter, PackedvLLMParameter,
RowvLLMParameter) RowvLLMParameter)
# yapf: enable
from vllm.scalar_type import scalar_types from vllm.scalar_type import scalar_types
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -26,6 +30,7 @@ WNA16_SUPPORTED_TYPES_MAP = { ...@@ -26,6 +30,7 @@ WNA16_SUPPORTED_TYPES_MAP = {
4: scalar_types.uint4b8, 4: scalar_types.uint4b8,
8: scalar_types.uint8b128 8: scalar_types.uint8b128
} }
WNA16_ZP_SUPPORTED_TYPES_MAP = {4: scalar_types.uint4, 8: scalar_types.uint8}
WNA16_SUPPORTED_BITS = list(WNA16_SUPPORTED_TYPES_MAP.keys()) WNA16_SUPPORTED_BITS = list(WNA16_SUPPORTED_TYPES_MAP.keys())
...@@ -36,10 +41,12 @@ class CompressedTensorsWNA16(CompressedTensorsScheme): ...@@ -36,10 +41,12 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
strategy: str, strategy: str,
num_bits: int, num_bits: int,
group_size: Optional[int] = None, group_size: Optional[int] = None,
symmetric: Optional[bool] = True,
actorder: Optional[ActivationOrdering] = None): actorder: Optional[ActivationOrdering] = None):
self.pack_factor = 32 // num_bits self.pack_factor = 32 // num_bits
self.strategy = strategy self.strategy = strategy
self.symmetric = symmetric
self.group_size = -1 if group_size is None else group_size self.group_size = -1 if group_size is None else group_size
self.has_g_idx = actorder == ActivationOrdering.GROUP self.has_g_idx = actorder == ActivationOrdering.GROUP
...@@ -53,7 +60,9 @@ class CompressedTensorsWNA16(CompressedTensorsScheme): ...@@ -53,7 +60,9 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
f"Unsupported num_bits = {num_bits}. " f"Unsupported num_bits = {num_bits}. "
f"Supported num_bits = {WNA16_SUPPORTED_TYPES_MAP.keys()}") f"Supported num_bits = {WNA16_SUPPORTED_TYPES_MAP.keys()}")
self.quant_type = WNA16_SUPPORTED_TYPES_MAP[num_bits] self.quant_type = (WNA16_ZP_SUPPORTED_TYPES_MAP[num_bits]
if not self.symmetric else
WNA16_SUPPORTED_TYPES_MAP[num_bits])
@classmethod @classmethod
def get_min_capability(cls) -> int: def get_min_capability(cls) -> int:
...@@ -75,7 +84,7 @@ class CompressedTensorsWNA16(CompressedTensorsScheme): ...@@ -75,7 +84,7 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
weight_type=self.quant_type, weight_type=self.quant_type,
act_type=params_dtype, act_type=params_dtype,
group_size=self.group_size, group_size=self.group_size,
zero_points=False, zero_points=not self.symmetric,
has_g_idx=self.has_g_idx has_g_idx=self.has_g_idx
) )
...@@ -120,13 +129,37 @@ class CompressedTensorsWNA16(CompressedTensorsScheme): ...@@ -120,13 +129,37 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
dtype=params_dtype, dtype=params_dtype,
) )
} }
zeros_args = {
"weight_loader":
weight_loader,
"data":
torch.zeros(
output_size_per_partition // self.pack_factor,
scales_and_zp_size,
dtype=torch.int32,
)
}
if not partition_scales: if not partition_scales:
weight_scale = ChannelQuantScaleParameter(output_dim=0, weight_scale = ChannelQuantScaleParameter(output_dim=0,
**weight_scale_args) **weight_scale_args)
if not self.symmetric:
qzeros = PackedColumnParameter(output_dim=0,
packed_dim=0,
packed_factor=self.pack_factor,
**zeros_args)
else: else:
weight_scale = GroupQuantScaleParameter(output_dim=0, weight_scale = GroupQuantScaleParameter(output_dim=0,
input_dim=1, input_dim=1,
**weight_scale_args) **weight_scale_args)
if not self.symmetric:
qzeros = PackedvLLMParameter(input_dim=1,
output_dim=0,
packed_dim=0,
packed_factor=self.pack_factor,
**zeros_args)
# A 2D array defining the original shape of the weights # A 2D array defining the original shape of the weights
# before packing # before packing
...@@ -138,6 +171,9 @@ class CompressedTensorsWNA16(CompressedTensorsScheme): ...@@ -138,6 +171,9 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
layer.register_parameter("weight_scale", weight_scale) layer.register_parameter("weight_scale", weight_scale)
layer.register_parameter("weight_shape", weight_shape) layer.register_parameter("weight_shape", weight_shape)
if not self.symmetric:
layer.register_parameter("weight_zero_point", qzeros)
# group index (for activation reordering) # group index (for activation reordering)
if self.has_g_idx: if self.has_g_idx:
weight_g_idx = RowvLLMParameter(data=torch.empty( weight_g_idx = RowvLLMParameter(data=torch.empty(
...@@ -151,7 +187,7 @@ class CompressedTensorsWNA16(CompressedTensorsScheme): ...@@ -151,7 +187,7 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
self.kernel = kernel_type(mp_linear_kernel_config, self.kernel = kernel_type(mp_linear_kernel_config,
w_q_param_name="weight_packed", w_q_param_name="weight_packed",
w_s_param_name="weight_scale", w_s_param_name="weight_scale",
w_zp_param_name=None, w_zp_param_name="weight_zero_point",
w_gidx_param_name="weight_g_idx") w_gidx_param_name="weight_g_idx")
# Checkpoints are serialized in compressed-tensors format, which is # Checkpoints are serialized in compressed-tensors format, which is
......
...@@ -140,6 +140,11 @@ class Fp8Config(QuantizationConfig): ...@@ -140,6 +140,11 @@ class Fp8Config(QuantizationConfig):
return name.replace(".k_proj.output_scale", ".attn.k_scale") return name.replace(".k_proj.output_scale", ".attn.k_scale")
if name.endswith(".output_scale") and ".v_proj" in name: if name.endswith(".output_scale") and ".v_proj" in name:
return name.replace(".v_proj.output_scale", ".attn.v_scale") return name.replace(".v_proj.output_scale", ".attn.v_scale")
if name.endswith(".output_scale") and ".q_proj" in name:
return name.replace(".q_proj.output_scale", ".attn.q_scale")
if name.endswith("self_attn.prob_output_scale"):
return name.replace(".prob_output_scale", ".attn.prob_scale")
# If no matches, return None
return None return None
...@@ -575,8 +580,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -575,8 +580,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
def process_weights_after_loading(self, layer: Module) -> None: def process_weights_after_loading(self, layer: Module) -> None:
# Lazy import to avoid importing triton too early. # Lazy import to avoid importing triton too early.
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
expand_weights, is_rocm_aiter_block_scaled_moe_enabled, expand_weights, is_rocm_aiter_moe_enabled, shuffle_weights)
is_rocm_aiter_moe_enabled, shuffle_weights)
# TODO (rob): refactor block quant into separate class. # TODO (rob): refactor block quant into separate class.
if self.block_quant: if self.block_quant:
...@@ -603,7 +607,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -603,7 +607,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
layer.w2_weight = Parameter(w2_weight, requires_grad=False) layer.w2_weight = Parameter(w2_weight, requires_grad=False)
layer.w2_weight_scale_inv = Parameter(w2_weight_scale_inv, layer.w2_weight_scale_inv = Parameter(w2_weight_scale_inv,
requires_grad=False) requires_grad=False)
if is_rocm_aiter_block_scaled_moe_enabled(): if is_rocm_aiter_moe_enabled():
# reshaping weights is required for aiter moe kernel. # reshaping weights is required for aiter moe kernel.
shuffled_w13, shuffled_w2 = shuffle_weights( shuffled_w13, shuffled_w2 = shuffle_weights(
layer.w13_weight.data, layer.w2_weight.data) layer.w13_weight.data, layer.w2_weight.data)
......
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, List, Optional, Set
import torch
from torch.nn.parameter import Parameter
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.kernels.mixed_precision import (
BitBLASLinearKernel, MPLinearLayerConfig)
from vllm.model_executor.layers.quantization.utils.bitblas_utils import (
BITBLAS_SUPPORTED_NUM_BITS as GPTQ_BITBLAS_SUPPORTED_NUM_BITS)
from vllm.model_executor.layers.quantization.utils.bitblas_utils import (
BITBLAS_SUPPORTED_SYM as GPTQ_BITBLAS_SUPPORTED_SYM)
from vllm.model_executor.layers.quantization.utils.bitblas_utils import (
MINIMUM_BITBLAS_VERSION, bitblas_repeat_scales_on_all_ranks,
check_bitblas_supported, verify_bitblas_supported)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
GroupQuantScaleParameter,
PackedColumnParameter,
PackedvLLMParameter,
RowvLLMParameter)
from vllm.scalar_type import scalar_types
logger = init_logger(__name__)
class GPTQBitBLASConfig(QuantizationConfig):
"""Config class for GPTQ BitBLAS"""
# (num_bits, is_sym) -> quant_type
TYPE_MAP = {
(4, True): scalar_types.uint4b8,
(8, True): scalar_types.uint8b128,
}
TORCH_DTYPE = torch.float16
GPTQ_CKPT_STORAGE_DTYPE = (
"int32" # GPTQ Default Checkpoints use int32 as storage dtype
)
GPTQ_BITBLAS_STORAGE_DTYPE = "int8" # BitBLAS uses int8 as storage dtype
TORCH_BITBLAS_STORAGE_DTYPE = getattr(torch, GPTQ_BITBLAS_STORAGE_DTYPE)
# "original" or "rescale" or "quantized",
# the gptq_bitblas prefer "quantized"
ZEROS_MODE = "quantized"
def __init__(
self,
weight_bits: int,
group_size: int,
desc_act: bool,
is_sym: bool,
quant_method: Optional[str],
lm_head_quantized: bool,
) -> None:
try:
import bitblas
if bitblas.__version__ < MINIMUM_BITBLAS_VERSION:
raise ImportError(
"bitblas version is wrong. Please "
f"install bitblas>={MINIMUM_BITBLAS_VERSION}")
except ImportError as e:
bitblas_import_exception = e
raise ValueError(
"Trying to use the bitblas backend, but could not import"
f"with the following error: {bitblas_import_exception}. "
"Please install bitblas through the following command: "
f"`pip install bitblas>={MINIMUM_BITBLAS_VERSION}`"
) from bitblas_import_exception
if desc_act and group_size == -1:
# In this case, act_order == True is the same as act_order == False
# (since we have only one group per output channel)
desc_act = False
self.weight_bits = weight_bits
self.group_size = group_size
self.desc_act = desc_act
self.is_sym = is_sym
self.quant_method = quant_method
self.lm_head_quantized = lm_head_quantized
# Verify
if self.weight_bits not in GPTQ_BITBLAS_SUPPORTED_NUM_BITS:
raise ValueError(
f"BitBLAS does not support weight_bits = {self.weight_bits}. "
f"Only weight_bits = {GPTQ_BITBLAS_SUPPORTED_NUM_BITS} "
"are supported.")
if self.is_sym not in GPTQ_BITBLAS_SUPPORTED_SYM:
raise ValueError(
f"BitBLAS does not support is_sym = {self.is_sym}. "
f"Only sym = {GPTQ_BITBLAS_SUPPORTED_SYM} are supported.")
self.storage_dtype = self.GPTQ_BITBLAS_STORAGE_DTYPE
storage_nbit = int("".join(c for c in self.GPTQ_CKPT_STORAGE_DTYPE
if c.isdigit()))
# 4 Bits packed into 32 bit datatype.
self.pack_factor = storage_nbit // weight_bits
self.nbits = weight_bits
# Zeros type for the quantized weights.
self.zeros_mode = self.ZEROS_MODE
if (weight_bits, is_sym) not in self.TYPE_MAP:
raise ValueError("Unsupported quantization config: "
f"bits={weight_bits}, sym={is_sym}")
self.quant_type = self.TYPE_MAP[(weight_bits, is_sym)]
def __repr__(self) -> str:
return (f"GPTQBitBLASConfig(weight_bits={self.weight_bits}, "
f"group_size={self.group_size}, "
f"desc_act={self.desc_act})"
f"is_sym={self.is_sym}, "
f"quant_method={self.quant_method})")
@classmethod
def get_name(cls) -> str:
return "gptq_bitblas"
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
return [torch.half, torch.bfloat16]
@classmethod
def get_min_capability(cls) -> int:
return 70
@classmethod
def get_config_filenames(cls) -> List[str]:
return ["quantize_config.json"]
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "GPTQBitBLASConfig":
weight_bits = cls.get_from_keys(config, ["bits"])
group_size = cls.get_from_keys(config, ["group_size"])
desc_act = cls.get_from_keys(config, ["desc_act"])
is_sym = cls.get_from_keys(config, ["sym"])
quant_method = cls.get_from_keys(config, ["quant_method"])
lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
default=False)
return cls(weight_bits, group_size, desc_act, is_sym, quant_method,
lm_head_quantized)
@classmethod
def override_quantization_method(cls, hf_quant_cfg,
user_quant) -> Optional[str]:
can_convert = cls.is_gptq_bitblas_compatible(hf_quant_cfg)
is_valid_user_quant = (user_quant is None or user_quant == "bitblas"
or user_quant == "gptq_bitblas")
if can_convert and is_valid_user_quant:
msg = ("The model is convertible to {} during runtime."
" Using {} kernel.".format(cls.get_name(), cls.get_name()))
logger.info(msg)
return cls.get_name()
if can_convert and user_quant == "gptq":
logger.info("Detected that the model can run with gptq_bitblas"
", however you specified quantization=gptq explicitly,"
" so forcing gptq. Use quantization=gptq_bitblas for"
" faster inference")
return None
def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["GPTQBitBLASLinearMethod"]:
if isinstance(layer, LinearBase) or (isinstance(layer, ParallelLMHead)
and self.lm_head_quantized):
return GPTQBitBLASLinearMethod(self)
return None
@property
def torch_storage_dtype(self) -> torch.dtype:
return self.TORCH_BITBLAS_STORAGE_DTYPE
@classmethod
def is_gptq_bitblas_compatible(cls, quant_config: Dict[str, Any]):
# Extract data from quant config.
num_bits = quant_config.get("bits")
group_size = quant_config.get("group_size")
sym = quant_config.get("sym")
desc_act = quant_config.get("desc_act")
# If we cannot find the info needed in the config, cannot convert.
if (num_bits is None or group_size is None or sym is None
or desc_act is None):
return False
if (num_bits, sym) not in cls.TYPE_MAP:
return False
# If the capability of the device is too low, cannot convert.
major, minor = torch.cuda.get_device_capability()
device_capability = major * 10 + minor
if device_capability < cls.get_min_capability():
return False
# Otherwise, can convert if model satisfies bitblas constraints.
return check_bitblas_supported(quant_type=cls.TYPE_MAP[(num_bits,
sym)],
group_size=group_size)
class GPTQBitBLASLinearMethod(LinearMethodBase):
"""Linear method for GPTQ BitBLAS.
Args:
quant_config: The GPTQ BitBLAS quantization config.
"""
kernel_type = BitBLASLinearKernel
_kernel_backends_being_used: Set[str] = set()
def __init__(self, quant_config: GPTQBitBLASConfig) -> None:
self.quant_config = quant_config
# Verify supported on platform.
verify_bitblas_supported(quant_type=self.quant_config.quant_type,
group_size=self.quant_config.group_size)
def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
) -> None:
"""Creates quantized weights for use in linear operations.
The function initializes and returns a dictionary containing
quantized weights, scales, and zeros
for performing quantized matrix multiplication operations.
Args:
input_size_per_partition: The size of the input partition.
output_partition_sizes: The size of the output partition.
input_size: The total size of the input (unused).
output_size: The total size of the output (unused).
params_dtype:
The data type of the parameters (expected to be torch.float16).
Returns:
A dictionary containing the quantized weights ('qweight'),
scales ('scales'), and zeros ('zeros').
Raises:
ValueError: If `params_dtype` is not `torch.float16` or
if the input size per partition is not divisible by the
group size in `quant_config`.
"""
if params_dtype != torch.float16:
raise ValueError("Parameter data type must be torch.float16, "
f"but got {params_dtype}")
# Normalize group_size
if self.quant_config.group_size != -1:
group_size = self.quant_config.group_size
else:
group_size = input_size
if input_size_per_partition % group_size != 0:
raise ValueError(
f"Input size per partition ({input_size_per_partition}) must "
f"be divisible by group size ({self.quant_config.group_size})."
)
kernel_type = self.kernel_type
# Validate output_size_per_partition
output_size_per_partition = sum(output_partition_sizes)
is_row_parallel = input_size != input_size_per_partition
weight_loader = extra_weight_attrs.get("weight_loader")
mp_linear_kernel_config = MPLinearLayerConfig(
full_weight_shape=(input_size, output_size),
partition_weight_shape=\
(input_size_per_partition, output_size_per_partition),
weight_type=self.quant_config.quant_type,
act_type=params_dtype,
group_size=self.quant_config.group_size,
zero_points=False,
has_g_idx=self.quant_config.desc_act
)
if kernel_type.__name__ not in self._kernel_backends_being_used:
logger.info("Using %s for GPTQBitBLASLinearMethod",
kernel_type.__name__)
self._kernel_backends_being_used.add(kernel_type.__name__)
# Normalize group_size
if self.quant_config.group_size != -1:
group_size = self.quant_config.group_size
else:
group_size = input_size
# Determine sharding
if bitblas_repeat_scales_on_all_ranks(self.quant_config.desc_act,
self.quant_config.group_size,
is_row_parallel):
# By setting scale_dim == None, weight_loader will
# repeat the scales on each GPU in TP>1 case.
scales_and_zp_input_dim = None
scales_and_zp_size = input_size // group_size
else:
# By setting scale_dim == 0, weight_loader will
# shard the scales in TP>1 case.
scales_and_zp_input_dim = 0
scales_and_zp_size = input_size_per_partition // group_size
# Init buffers
# Quantized weights
qweight = PackedvLLMParameter(
data=torch.empty(
input_size_per_partition // self.quant_config.pack_factor,
output_size_per_partition,
dtype=torch.int32,
),
input_dim=0,
output_dim=1,
packed_dim=0,
packed_factor=self.quant_config.pack_factor,
weight_loader=weight_loader)
# Activation order
# Ignore warning from fused linear layers such as QKVParallelLinear.
g_idx = RowvLLMParameter(data=torch.empty(
input_size_per_partition,
dtype=torch.int32,
),
input_dim=0,
weight_loader=weight_loader)
# Scales
scales = Parameter(
torch.empty(
scales_and_zp_size,
output_size_per_partition,
dtype=params_dtype,
),
requires_grad=False,
)
set_weight_attrs(
scales,
{
**extra_weight_attrs,
"input_dim": scales_and_zp_input_dim,
"output_dim": 1,
},
)
# Quantized zero-points
qzeros_args = {
"data":
torch.empty(
scales_and_zp_size,
output_size_per_partition // self.quant_config.pack_factor,
dtype=torch.int32,
),
"weight_loader":
weight_loader
}
weight_scale_args = {
"data":
torch.empty(
scales_and_zp_size,
output_size_per_partition,
dtype=params_dtype,
),
"weight_loader":
weight_loader
}
if scales_and_zp_input_dim is None:
scales = ChannelQuantScaleParameter(output_dim=1,
**weight_scale_args)
qzeros = PackedColumnParameter(
output_dim=1,
packed_dim=1,
packed_factor=self.quant_config.pack_factor,
**qzeros_args)
else:
scales = GroupQuantScaleParameter(output_dim=1,
input_dim=0,
**weight_scale_args)
qzeros = PackedvLLMParameter(
input_dim=0,
output_dim=1,
packed_dim=1,
packed_factor=self.quant_config.pack_factor,
**qzeros_args)
layer.register_parameter("qweight", qweight)
layer.register_parameter("g_idx", g_idx)
layer.register_parameter("scales", scales)
layer.register_parameter("qzeros", qzeros)
self.kernel = kernel_type(
mp_linear_kernel_config,
w_q_param_name="qweight",
w_s_param_name="scales",
w_zp_param_name="qzeros",
w_gidx_param_name="g_idx",
bitblas_quant_config=self.quant_config,
)
# Initialize or retrieve the BitBLAS matrix multiplication operator.
self.kernel.configure_bitblas_matmul(
input_size_per_partition,
output_size_per_partition,
params_dtype=params_dtype,
bias=False,
)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
self.kernel.process_weights_after_loading(layer)
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
out = self.kernel.apply_gptq_bitblas_linear(layer, x)
if bias is not None:
out.add_(bias)
return out
...@@ -15,13 +15,13 @@ from vllm.model_executor.layers.quantization.base_config import ( ...@@ -15,13 +15,13 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase) QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.kernels.mixed_precision import ( from vllm.model_executor.layers.quantization.kernels.mixed_precision import (
MPLinearLayerConfig, choose_mp_linear_kernel) MPLinearLayerConfig, choose_mp_linear_kernel)
from vllm.model_executor.layers.quantization.moe_wna16 import MoeWNA16Config
from vllm.model_executor.layers.quantization.utils import replace_parameter from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.gptq_utils import ( from vllm.model_executor.layers.quantization.utils.gptq_utils import (
get_linear_quant_method) get_linear_quant_method)
from vllm.model_executor.layers.quantization.utils.marlin_utils import ( from vllm.model_executor.layers.quantization.utils.marlin_utils import (
check_marlin_supported, marlin_moe_permute_scales, check_marlin_supported, check_moe_marlin_supports_layer,
marlin_repeat_scales_on_all_ranks, verify_marlin_supported) marlin_moe_permute_scales, marlin_repeat_scales_on_all_ranks,
verify_marlin_supported)
from vllm.model_executor.parameter import (ChannelQuantScaleParameter, from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
GroupQuantScaleParameter, GroupQuantScaleParameter,
PackedColumnParameter, PackedColumnParameter,
...@@ -153,12 +153,15 @@ class GPTQMarlinConfig(QuantizationConfig): ...@@ -153,12 +153,15 @@ class GPTQMarlinConfig(QuantizationConfig):
def get_quant_method(self, layer: torch.nn.Module, def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["QuantizeMethodBase"]: prefix: str) -> Optional["QuantizeMethodBase"]:
if isinstance(layer, FusedMoE): if isinstance(layer, FusedMoE):
if layer.local_num_experts > 32: from vllm.model_executor.layers.quantization.moe_wna16 import (
# For MoEs with many experts the moe_wna16 kernel is faster MoeWNA16Config)
if not check_moe_marlin_supports_layer(layer, self.group_size):
logger.warning_one(
f"Layer '{prefix}' is not supported by GPTQMoeMarlin. "
"Falling back to Moe WNA16 kernels.")
return MoeWNA16Config.from_config( return MoeWNA16Config.from_config(
self.full_config).get_quant_method(layer, prefix) self.full_config).get_quant_method(layer, prefix)
else: return GPTQMarlinMoEMethod(self)
return GPTQMarlinMoEMethod(self)
return get_linear_quant_method(self, layer, prefix, return get_linear_quant_method(self, layer, prefix,
GPTQMarlinLinearMethod) GPTQMarlinLinearMethod)
...@@ -408,7 +411,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): ...@@ -408,7 +411,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
torch.empty(num_experts, torch.empty(num_experts,
scales_size13, scales_size13,
2 * intermediate_size_per_partition, 2 * intermediate_size_per_partition,
dtype=torch.half), dtype=params_dtype),
requires_grad=False, requires_grad=False,
) )
layer.register_parameter("w13_scales", w13_scales) layer.register_parameter("w13_scales", w13_scales)
...@@ -418,7 +421,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): ...@@ -418,7 +421,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
torch.empty(num_experts, torch.empty(num_experts,
scales_size2, scales_size2,
hidden_size, hidden_size,
dtype=torch.half), dtype=params_dtype),
requires_grad=False, requires_grad=False,
) )
layer.register_parameter("w2_scales", w2_scales) layer.register_parameter("w2_scales", w2_scales)
...@@ -493,6 +496,13 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): ...@@ -493,6 +496,13 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
w2_g_idx_sort_indices) w2_g_idx_sort_indices)
set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs) set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs)
device = layer.w13_qweight.device
sms = torch.cuda.get_device_properties(device).multi_processor_count
layer.workspace = torch.zeros((sms * 4, ),
dtype=torch.int,
device=device,
requires_grad=False)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# Process act_order # Process act_order
...@@ -601,10 +611,6 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): ...@@ -601,10 +611,6 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
"Apply router weight on input is not supported for" "Apply router weight on input is not supported for"
"fused Marlin MoE method.") "fused Marlin MoE method.")
# The input must currently be float16
orig_dtype = x.dtype
x = x.half()
topk_weights, topk_ids = FusedMoE.select_experts( topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x, hidden_states=x,
router_logits=router_logits, router_logits=router_logits,
...@@ -626,9 +632,12 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): ...@@ -626,9 +632,12 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
router_logits, router_logits,
topk_weights, topk_weights,
topk_ids, topk_ids,
global_num_experts=global_num_experts,
expert_map=expert_map,
g_idx1=layer.w13_g_idx, g_idx1=layer.w13_g_idx,
g_idx2=layer.w2_g_idx, g_idx2=layer.w2_g_idx,
sort_indices1=layer.w13_g_idx_sort_indices, sort_indices1=layer.w13_g_idx_sort_indices,
sort_indices2=layer.w2_g_idx_sort_indices, sort_indices2=layer.w2_g_idx_sort_indices,
num_bits=self.quant_config.quant_type.size_bits, num_bits=self.quant_config.quant_type.size_bits,
is_k_full=self.is_k_full).to(orig_dtype) workspace=layer.workspace,
is_k_full=self.is_k_full)
...@@ -5,6 +5,8 @@ from typing import List, Optional, Type ...@@ -5,6 +5,8 @@ from typing import List, Optional, Type
import vllm.envs as envs import vllm.envs as envs
from vllm.model_executor.layers.quantization.kernels.mixed_precision.allspark import ( # noqa: E501 from vllm.model_executor.layers.quantization.kernels.mixed_precision.allspark import ( # noqa: E501
AllSparkLinearKernel) AllSparkLinearKernel)
from vllm.model_executor.layers.quantization.kernels.mixed_precision.bitblas import ( # noqa: E501
BitBLASLinearKernel)
from vllm.model_executor.layers.quantization.kernels.mixed_precision.exllama import ( # noqa: E501 from vllm.model_executor.layers.quantization.kernels.mixed_precision.exllama import ( # noqa: E501
ExllamaLinearKernel) ExllamaLinearKernel)
from vllm.model_executor.layers.quantization.kernels.mixed_precision.machete import ( # noqa: E501 from vllm.model_executor.layers.quantization.kernels.mixed_precision.machete import ( # noqa: E501
...@@ -20,6 +22,7 @@ _POSSIBLE_KERNELS: List[Type[MPLinearKernel]] = [ ...@@ -20,6 +22,7 @@ _POSSIBLE_KERNELS: List[Type[MPLinearKernel]] = [
MacheteLinearKernel, MacheteLinearKernel,
AllSparkLinearKernel, AllSparkLinearKernel,
MarlinLinearKernel, MarlinLinearKernel,
BitBLASLinearKernel,
ExllamaLinearKernel, ExllamaLinearKernel,
] ]
...@@ -76,4 +79,4 @@ def choose_mp_linear_kernel( ...@@ -76,4 +79,4 @@ def choose_mp_linear_kernel(
raise ValueError( raise ValueError(
"Failed to find a kernel that can implement the "\ "Failed to find a kernel that can implement the "\
"WNA16 linear layer. Reasons: \n" "WNA16 linear layer. Reasons: \n"
+ '\n'.join(failure_reasons)) + '\n'.join(failure_reasons))
\ No newline at end of file
# SPDX-License-Identifier: Apache-2.0
from typing import Dict, List, Optional, Tuple
import torch
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.bitblas_utils import (
BITBLAS_OPTIMIZE_FEATURES, BITBLAS_SUPPORTED_GROUP_SIZES,
MINIMUM_BITBLAS_VERSION, bitblas_make_empty_g_idx, bitblas_sort_g_idx,
check_bitblas_supports_shape, query_bitblas_supported_quant_types,
unpack_gptq_qweight, unpack_gptq_qzeros)
from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig
logger = init_logger(__name__)
class BitBLASLinearKernel(MPLinearKernel):
OPT_FEATURES: List[int] = BITBLAS_OPTIMIZE_FEATURES
ENABLE_TUNING: bool = True
MATMUL_LAYOUT: str = "nt"
BITBLAS_DTYPES: Dict[torch.dtype, str] = {
torch.float32: "float32",
torch.float16: "float16",
torch.bfloat16: "bfloat16",
torch.half: "float16",
torch.int8: "int8",
}
bitblas_matmul: object = None
def __init__(
self,
c: MPLinearLayerConfig,
w_q_param_name: str,
w_s_param_name: str,
w_zp_param_name: Optional[str] = None,
w_gidx_param_name: Optional[str] = None,
bitblas_quant_config: Optional[QuantizationConfig] = None,
):
self.quant_config = bitblas_quant_config
super().__init__(c, w_q_param_name, w_s_param_name, w_zp_param_name,
w_gidx_param_name)
def repack_bitblas_from_gptq(
self,
b_q_weight: torch.Tensor,
scales: torch.Tensor,
qzeros: Optional[torch.Tensor] = None,
):
from bitblas.quantization.utils import general_compress
assert self.bitblas_matmul is not None, "bitblas_matmul is None"
quant_config = self.quant_config
# qweight in gptq old quant linear stored with
# (outfeatures, infeatures), should be transposed.
qweight = b_q_weight.T.contiguous().view(
quant_config.torch_storage_dtype) # type: ignore[union-attr]
intweight = unpack_gptq_qweight(
qweight,
quant_config.weight_bits).contiguous() # type: ignore[union-attr]
if self.bitblas_matmul.weight_transform is not None: # type: ignore[attr-defined]
qweight = self.bitblas_matmul.weight_transform( # type: ignore[attr-defined]
intweight.cpu()).cuda()
# scales in gptq old quant linear stored with
# (infeatures // group_size, outfeatures), should be transposed.
scales = scales.T.contiguous()
if qzeros is None:
return qweight, scales, None
# qzeros should be de-quantized to int zeros.
weight_bits = quant_config.weight_bits # type: ignore[union-attr]
intzeros = unpack_gptq_qzeros(qzeros, weight_bits).T.contiguous()
zeros: Optional[torch.Tensor] = None
zeros_mode = self.bitblas_matmul.config.zeros_mode # type: ignore[attr-defined]
if zeros_mode == "original":
zeros = intzeros.to(torch.float16).contiguous()
elif zeros_mode == "rescale":
assert zeros is not None, "zeros should not be None"
zeros[:, :] = intzeros.to(torch.float16)[:, :] * scales[:, :]
elif zeros_mode == "quantized":
zeros = (
torch.Tensor(
general_compress(
intzeros.T.contiguous().cpu().numpy(),
weight_bits,
)).to(qweight.device).
to(quant_config.torch_storage_dtype # type: ignore[union-attr]
).contiguous())
else:
raise ValueError("Unsupported zeros type: {}".format(zeros_mode))
return qweight, scales, zeros
@classmethod
def get_min_capability(cls) -> int:
return 70
@classmethod
def can_implement(cls,
c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]:
is_bitblas_installed = True
try:
import bitblas
if bitblas.__version__ < MINIMUM_BITBLAS_VERSION:
raise ImportError(
"bitblas version is wrong. Please "
f"install bitblas>={MINIMUM_BITBLAS_VERSION}")
except ImportError:
is_bitblas_installed = False
if not is_bitblas_installed:
return False, "bitblas is not installed. Please install bitblas "\
"by running `pip install bitblas>="\
f"{MINIMUM_BITBLAS_VERSION}`"
quant_types = query_bitblas_supported_quant_types(c.zero_points)
if c.weight_type not in quant_types:
return False, (f"Quant type ({c.weight_type}) not supported by"
f" BitBLAS, supported types are: {quant_types}")
if c.group_size not in BITBLAS_SUPPORTED_GROUP_SIZES:
return False, (f"Group size ({c.group_size}) not supported by "
"BitBLAS, supported group sizes are: "
f"{BITBLAS_SUPPORTED_GROUP_SIZES}")
return check_bitblas_supports_shape(
c.partition_weight_shape[1], # out_features
c.partition_weight_shape[0], # in_features
c.full_weight_shape[0], # in_features
c.group_size)
# note assumes that
# `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
# `weight_scale` is: {input_dim = 0, output_dim = 1}
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
device = getattr(layer, self.w_q_name).device
c = self.config
quant_config = self.quant_config
# Default names since bitblas requires empty parameters for these,
# TODO: remove this requirement from bitblas (allow optional tensors)
if self.w_gidx_name is None:
self.w_gidx_name = "g_idx"
if self.w_zp_name is None:
self.w_zp_name = "qzeros"
if c.has_g_idx:
g_idx, g_idx_sort_indices = bitblas_sort_g_idx(
getattr(layer, self.w_gidx_name))
self._transform_param(layer, self.w_gidx_name, lambda _: g_idx)
layer.g_idx_sort_indices = g_idx_sort_indices
else:
setattr(layer, self.w_gidx_name, bitblas_make_empty_g_idx(device))
layer.g_idx_sort_indices = bitblas_make_empty_g_idx(device)
if c.zero_points:
raise NotImplementedError("Zero points not supported by BitBLAS")
else:
setattr(layer, self.w_zp_name, bitblas_make_empty_g_idx(device))
# Repack weights
bitblas_qweight, bitblas_scales, bitblas_qzeros = (
self.repack_bitblas_from_gptq(
layer.qweight,
layer.scales,
None if quant_config.is_sym else # type: ignore[union-attr]
layer.qzeros, # type: ignore[union-attr]
))
replace_parameter(layer, self.w_q_name, bitblas_qweight)
replace_parameter(layer, self.w_s_name, bitblas_scales)
if bitblas_qzeros is not None:
replace_parameter(layer, self.w_zp_name, bitblas_qzeros)
def configure_bitblas_matmul(
self,
infeatures: int,
outfeatures: int,
params_dtype: torch.dtype,
bias: bool,
) -> None:
enable_tuning = self.ENABLE_TUNING
layout = self.MATMUL_LAYOUT
bits = self.quant_config.weight_bits # type: ignore[union-attr]
self._configure_bitblas_matmul(
infeatures,
outfeatures,
params_dtype,
enable_tuning,
bias,
layout,
bits,
)
def _configure_bitblas_matmul(
self,
infeatures,
outfeatures,
params_dtype,
enable_tuning,
bias,
layout,
bits,
):
from bitblas import MatmulConfig
bitblas_dtype = self.BITBLAS_DTYPES[params_dtype]
quant_config = self.quant_config
with_scaling = False
with_zeros = False
group_size = quant_config.group_size # type: ignore[union-attr]
zeros_mode = quant_config.zeros_mode # type: ignore[union-attr]
if quant_config.quant_method == "gptq": # type: ignore[union-attr]
with_scaling = True
with_zeros = True
W_dtype = f"uint{bits}"
if quant_config.is_sym: # type: ignore[union-attr]
with_zeros = False
W_dtype = f"int{bits}"
else:
raise ValueError(
f"Unsupported quant_method {quant_config.quant_method}" # type: ignore[union-attr]
) # type: ignore[union-attr]
matmul_config = MatmulConfig(
M=self.OPT_FEATURES,
N=outfeatures,
K=infeatures,
A_dtype=bitblas_dtype,
W_dtype=W_dtype,
out_dtype=bitblas_dtype,
accum_dtype="int32" if bitblas_dtype == "int8" else bitblas_dtype,
storage_dtype=quant_config. # type: ignore[union-attr]
storage_dtype, # type: ignore[union-attr]
with_scaling=with_scaling,
with_zeros=with_zeros,
group_size=group_size,
with_bias=bias,
layout=layout,
zeros_mode=zeros_mode,
)
self.bitblas_matmul = self._get_or_create_bitblas_operator(
matmul_config, enable_tuning)
def _get_or_create_bitblas_operator(self, config, enable_tuning):
from bitblas import Matmul, auto_detect_nvidia_target
from bitblas.cache import get_database_path, global_operator_cache
BITBLAS_DATABASE_PATH = get_database_path()
BITBLAS_TARGET = auto_detect_nvidia_target()
if global_operator_cache.size() == 0:
global_operator_cache.load_from_database(BITBLAS_DATABASE_PATH,
BITBLAS_TARGET)
bitblas_matmul = global_operator_cache.get(config)
if bitblas_matmul is None:
bitblas_matmul = Matmul(config,
target=BITBLAS_TARGET,
enable_tuning=False)
if enable_tuning:
bitblas_matmul.hardware_aware_finetune(topk=20)
global_operator_cache.add(config, bitblas_matmul)
global_operator_cache.save_into_database(
BITBLAS_DATABASE_PATH, BITBLAS_TARGET)
TUNING_MESSAGE = (
f"BitBLAS Operator {config} tuned and saved to database.")
logger.info(TUNING_MESSAGE)
else:
_message = f"BitBLAS Operator {config} created without tuning. "
logger.info(_message)
else:
_message = f"BitBLAS Operator {config} retrieved from cache."
logger.info(_message)
return bitblas_matmul
def apply_gptq_bitblas_linear(
self,
layer: torch.nn.Module,
x: torch.Tensor,
) -> torch.Tensor:
output_size_per_partition = self.config.partition_weight_shape[1]
out_shape = x.shape[:-1] + (output_size_per_partition, )
args = [x, layer.qweight, layer.scales]
if self.bitblas_matmul.config.with_zeros: # type: ignore[attr-defined]
args.append(layer.qzeros)
output = self.bitblas_matmul(*args) # type: ignore[operator]
return output.view(out_shape)
def apply_weights(self, layer, x, bias=None):
NOT_IMPLEMENT_MESSAGE = (
f"{self.__class__.__name__}.apply_weights is not implemented. "
"Please use BitBLASLinearKernel.apply_gptq_bitblas_linear instead")
raise NotImplementedError(NOT_IMPLEMENT_MESSAGE)
...@@ -26,17 +26,14 @@ class MacheteLinearKernel(MPLinearKernel): ...@@ -26,17 +26,14 @@ class MacheteLinearKernel(MPLinearKernel):
@classmethod @classmethod
def can_implement(cls, def can_implement(cls,
c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]: c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]:
if c.has_g_idx and\ if c.has_g_idx and\
c.partition_weight_shape[0] != c.full_weight_shape[0]: c.partition_weight_shape[0] != c.full_weight_shape[0]:
return False, "Act reordering currently not supported by Machete, "\ return False, "Act reordering currently not supported by Machete, "\
"when the input features are partitioned across "\ "when the input features are partitioned across "\
"devices" "devices"
if c.zero_points: if c.zero_points:
return False, "Zero points currently not supported by "\ return False, "Zero points currently not supported by Machete"
" Compressed Tensors + Machete. (Kernel supports it"\
" but CompressedTensorsWNA16 does not so support has"\
" not been added to MacheteWNA16Kernel yet"
if c.weight_type not in query_machete_supported_quant_types( if c.weight_type not in query_machete_supported_quant_types(
c.zero_points): c.zero_points):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment