Commit 8d75f22e authored by zhuwenwen
Browse files

Merge tag 'v0.13.0rc1' into v0.13.0rc1-ori

parents ce888aa4 7d80c73d
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
from typing import Any, Optional from typing import Any, Optional
import torch import torch
...@@ -60,7 +59,7 @@ class MoeWNA16Config(QuantizationConfig): ...@@ -60,7 +59,7 @@ class MoeWNA16Config(QuantizationConfig):
if self.linear_quant_method == "gptq": if self.linear_quant_method == "gptq":
self.use_marlin = GPTQMarlinConfig.is_gptq_marlin_compatible(full_config) self.use_marlin = GPTQMarlinConfig.is_gptq_marlin_compatible(full_config)
elif self.linear_quant_method == "awq": elif self.linear_quant_method in ("awq", "awq_marlin"):
capability_tuple = current_platform.get_device_capability() capability_tuple = current_platform.get_device_capability()
device_capability = ( device_capability = (
-1 if capability_tuple is None else capability_tuple.to_int() -1 if capability_tuple is None else capability_tuple.to_int()
...@@ -107,7 +106,7 @@ class MoeWNA16Config(QuantizationConfig): ...@@ -107,7 +106,7 @@ class MoeWNA16Config(QuantizationConfig):
if linear_quant_method == "gptq": if linear_quant_method == "gptq":
has_zp = not cls.get_from_keys(config, ["sym"]) has_zp = not cls.get_from_keys(config, ["sym"])
modules_to_not_convert = [] modules_to_not_convert = []
elif linear_quant_method == "awq": elif linear_quant_method in ("awq", "awq_marlin"):
has_zp = cls.get_from_keys(config, ["zero_point"]) has_zp = cls.get_from_keys(config, ["zero_point"])
modules_to_not_convert = cls.get_from_keys_or( modules_to_not_convert = cls.get_from_keys_or(
config, ["modules_to_not_convert"], None config, ["modules_to_not_convert"], None
...@@ -184,7 +183,7 @@ class MoeWNA16Config(QuantizationConfig): ...@@ -184,7 +183,7 @@ class MoeWNA16Config(QuantizationConfig):
return GPTQConfig.from_config(self.full_config).get_quant_method( return GPTQConfig.from_config(self.full_config).get_quant_method(
layer, prefix layer, prefix
) )
elif self.linear_quant_method == "awq": elif self.linear_quant_method in ("awq", "awq_marlin"):
if self.use_marlin and check_marlin_supports_layer( if self.use_marlin and check_marlin_supports_layer(
layer, self.group_size layer, self.group_size
): ):
...@@ -362,27 +361,10 @@ class MoeWNA16Method(FusedMoEMethodBase): ...@@ -362,27 +361,10 @@ class MoeWNA16Method(FusedMoEMethodBase):
layer: FusedMoE, layer: FusedMoE,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
assert activation == "silu", "Only SiLU activation is supported." assert layer.activation == "silu", "Only SiLU activation is supported."
topk_weights, topk_ids, _ = layer.select_experts( topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x, hidden_states=x,
router_logits=router_logits, router_logits=router_logits,
...@@ -395,9 +377,9 @@ class MoeWNA16Method(FusedMoEMethodBase): ...@@ -395,9 +377,9 @@ class MoeWNA16Method(FusedMoEMethodBase):
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
inplace=True, inplace=True,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=global_num_experts, global_num_experts=layer.global_num_experts,
expert_map=expert_map, expert_map=layer.expert_map,
quant_config=self.moe_quant_config, quant_config=self.moe_quant_config,
) )
...@@ -468,7 +450,8 @@ class MoeWNA16Method(FusedMoEMethodBase): ...@@ -468,7 +450,8 @@ class MoeWNA16Method(FusedMoEMethodBase):
shard_size = layer.intermediate_size_per_partition shard_size = layer.intermediate_size_per_partition
# convert gptq and awq weight to a standard format # convert gptq and awq weight to a standard format
if layer.quant_config.linear_quant_method == "awq": # awq_marlin uses the same weight format as awq
if layer.quant_config.linear_quant_method in ("awq", "awq_marlin"):
assert layer.quant_config.weight_bits == 4 assert layer.quant_config.weight_bits == 4
if "weight" in weight_name: if "weight" in weight_name:
loaded_weight = convert_awq_tensor(loaded_weight, "qweight") loaded_weight = convert_awq_tensor(loaded_weight, "qweight")
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
from enum import Enum from enum import Enum
from typing import Optional from typing import Optional
...@@ -892,25 +891,8 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): ...@@ -892,25 +891,8 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
layer: FusedMoE, layer: FusedMoE,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if enable_eplb: if layer.enable_eplb:
raise NotImplementedError("EPLB is not supported for mxfp4") raise NotImplementedError("EPLB is not supported for mxfp4")
if self.mxfp4_backend == Mxfp4Backend.MARLIN: if self.mxfp4_backend == Mxfp4Backend.MARLIN:
...@@ -933,26 +915,26 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): ...@@ -933,26 +915,26 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
global_scale1=None, global_scale1=None,
global_scale2=None, global_scale2=None,
quant_type_id=scalar_types.float4_e2m1f.id, quant_type_id=scalar_types.float4_e2m1f.id,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=global_num_experts, global_num_experts=layer.global_num_experts,
activation=activation, activation=layer.activation,
expert_map=expert_map, expert_map=layer.expert_map,
input_dtype=self.marlin_input_dtype, input_dtype=self.marlin_input_dtype,
) )
assert _can_support_mxfp4( assert _can_support_mxfp4(
use_grouped_topk, layer.use_grouped_topk,
topk_group, layer.topk_group,
num_expert_group, layer.num_expert_group,
expert_map, layer.expert_map,
custom_routing_function, layer.custom_routing_function,
e_score_correction_bias, layer.e_score_correction_bias,
apply_router_weight_on_input, layer.apply_router_weight_on_input,
scoring_func, layer.scoring_func,
activation, layer.activation,
expert_load_view, layer.expert_load_view,
logical_to_physical_map, layer.logical_to_physical_map,
logical_replica_count, layer.logical_replica_count,
), "MXFP4 are not supported with this configuration." ), "MXFP4 are not supported with this configuration."
if ( if (
...@@ -988,8 +970,8 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): ...@@ -988,8 +970,8 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
None, # output1_scale_scalar None, # output1_scale_scalar
None, # output1_scale_gate_scalar None, # output1_scale_gate_scalar
None, # output2_scale_scalar None, # output2_scale_scalar
global_num_experts, layer.global_num_experts,
top_k, layer.top_k,
None, # n_group None, # n_group
None, # topk_group None, # topk_group
self.intermediate_size, # padded to multiple of 256 self.intermediate_size, # padded to multiple of 256
...@@ -997,7 +979,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): ...@@ -997,7 +979,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
self.num_experts, # local num experts self.num_experts, # local num experts
None, None,
None, None,
1 if renormalize else 0, # routing_method_type, renormalize 1 if layer.renormalize else 0, # routing_method_type, renormalize
True, # do finalize True, # do finalize
tune_max_num_tokens=max(self.max_capture_size, 1), tune_max_num_tokens=max(self.max_capture_size, 1),
)[0] )[0]
...@@ -1081,12 +1063,12 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): ...@@ -1081,12 +1063,12 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
w1=layer.w13_weight, w1=layer.w13_weight,
w2=layer.w2_weight, w2=layer.w2_weight,
gating_output=router_logits, gating_output=router_logits,
topk=top_k, topk=layer.top_k,
renormalize=renormalize, renormalize=layer.renormalize,
global_num_experts=global_num_experts, global_num_experts=layer.global_num_experts,
expert_map=expert_map, expert_map=layer.expert_map,
quant_config=self.moe_quant_config, quant_config=self.moe_quant_config,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
) )
else: else:
raise ValueError(f"Unsupported backend: {self.mxfp4_backend}") raise ValueError(f"Unsupported backend: {self.mxfp4_backend}")
...@@ -1138,37 +1120,20 @@ class IpexMxfp4MoEMethod(Mxfp4MoEMethod): ...@@ -1138,37 +1120,20 @@ class IpexMxfp4MoEMethod(Mxfp4MoEMethod):
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
assert activation == "swigluoai", ( assert layer.activation == "swigluoai", (
"Only swiglu_oai activation is supported for IPEX MXFP4 MoE" "Only swiglu_oai activation is supported for IPEX MXFP4 MoE"
) )
hidden_size_pad = round_up(self.original_hidden_size, 128) hidden_size_pad = round_up(self.original_hidden_size, 128)
x_pad = torch.nn.functional.pad(x, (0, hidden_size_pad - x.size(-1))) x_pad = torch.nn.functional.pad(x, (0, hidden_size_pad - x.size(-1)))
hidden_states = layer.ipex_fusion( hidden_states = layer.ipex_fusion(
x_pad, x_pad,
use_grouped_topk, layer.use_grouped_topk,
top_k, layer.top_k,
router_logits, router_logits,
renormalize, layer.renormalize,
topk_group, layer.topk_group,
num_expert_group, layer.num_expert_group,
activation="swiglu_oai", activation="swiglu_oai",
) )
hidden_states = hidden_states[..., : self.original_hidden_size].contiguous() hidden_states = hidden_states[..., : self.original_hidden_size].contiguous()
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
from typing import Any from typing import Any
import torch import torch
...@@ -337,23 +336,6 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod): ...@@ -337,23 +336,6 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
layer: FusedMoE, layer: FusedMoE,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
topk_weights, topk_ids, _ = layer.select_experts( topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x, hidden_states=x,
...@@ -371,13 +353,15 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod): ...@@ -371,13 +353,15 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
w2=layer.w2_weight, w2=layer.w2_weight,
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
activation=activation, activation=layer.activation,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
quant_config=self.moe_quant_config, quant_config=self.moe_quant_config,
expert_map=expert_map, expert_map=layer.expert_map,
) )
elif self.use_marlin: elif self.use_marlin:
assert activation == "silu", f"{activation} not supported for Marlin MoE." assert layer.activation == "silu", (
f"{layer.activation} not supported for Marlin MoE."
)
return fused_marlin_moe( return fused_marlin_moe(
x, x,
layer.w13_weight, layer.w13_weight,
...@@ -390,9 +374,9 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod): ...@@ -390,9 +374,9 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
topk_weights, topk_weights,
topk_ids, topk_ids,
quant_type_id=scalar_types.float8_e4m3fn.id, quant_type_id=scalar_types.float8_e4m3fn.id,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=global_num_experts, global_num_experts=layer.global_num_experts,
expert_map=expert_map, expert_map=layer.expert_map,
) )
else: else:
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
...@@ -404,10 +388,10 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod): ...@@ -404,10 +388,10 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
inplace=True, inplace=True,
activation=activation, activation=layer.activation,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=global_num_experts, global_num_experts=layer.global_num_experts,
expert_map=expert_map, expert_map=layer.expert_map,
quant_config=self.moe_quant_config, quant_config=self.moe_quant_config,
) )
...@@ -597,23 +581,6 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod): ...@@ -597,23 +581,6 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
layer: FusedMoE, layer: FusedMoE,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
topk_weights, topk_ids, _ = layer.select_experts( topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x, hidden_states=x,
...@@ -631,8 +598,9 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod): ...@@ -631,8 +598,9 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
layer.w2_weight, layer.w2_weight,
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
activation=activation, activation=layer.activation,
quant_config=self.moe_quant_config, quant_config=self.moe_quant_config,
expert_map=layer.expert_map,
) )
else: else:
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
...@@ -644,10 +612,11 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod): ...@@ -644,10 +612,11 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
inplace=True, inplace=True,
activation=activation, activation=layer.activation,
global_num_experts=global_num_experts, global_num_experts=layer.global_num_experts,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
expert_map=expert_map, expert_map=layer.expert_map,
quant_config=self.moe_quant_config, quant_config=self.moe_quant_config,
) )
return out return out
...@@ -3,7 +3,6 @@ ...@@ -3,7 +3,6 @@
# Copyright © 2025, Oracle and/or its affiliates. # Copyright © 2025, Oracle and/or its affiliates.
import os import os
from collections.abc import Callable
from typing import Any, Optional from typing import Any, Optional
import numpy as np import numpy as np
...@@ -359,23 +358,6 @@ class RTNMoEMethod(FusedMoEMethodBase): ...@@ -359,23 +358,6 @@ class RTNMoEMethod(FusedMoEMethodBase):
layer: FusedMoE, layer: FusedMoE,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
topk_weights, topk_ids, _ = layer.select_experts( topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x, hidden_states=x,
...@@ -394,9 +376,9 @@ class RTNMoEMethod(FusedMoEMethodBase): ...@@ -394,9 +376,9 @@ class RTNMoEMethod(FusedMoEMethodBase):
topk_weights, topk_weights,
topk_ids, topk_ids,
quant_type_id=self.quant_config.quant_type.id, quant_type_id=self.quant_config.quant_type.id,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=global_num_experts, global_num_experts=layer.global_num_experts,
expert_map=expert_map, expert_map=layer.expert_map,
workspace=workspace, workspace=workspace,
) )
......
...@@ -247,6 +247,11 @@ def flashinfer_cutlass_moe_fp8( ...@@ -247,6 +247,11 @@ def flashinfer_cutlass_moe_fp8(
assert quant_config is not None assert quant_config is not None
# Construct modular kernel with block-scale support when requested. # Construct modular kernel with block-scale support when requested.
parallel_config = getattr(
getattr(layer, "vllm_config", None),
"parallel_config",
None,
)
fused_experts = mk.FusedMoEModularKernel( fused_experts = mk.FusedMoEModularKernel(
build_flashinfer_fp8_cutlass_moe_prepare_finalize( build_flashinfer_fp8_cutlass_moe_prepare_finalize(
moe=moe, use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale moe=moe, use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale
...@@ -257,6 +262,7 @@ def flashinfer_cutlass_moe_fp8( ...@@ -257,6 +262,7 @@ def flashinfer_cutlass_moe_fp8(
out_dtype=hidden_states.dtype, out_dtype=hidden_states.dtype,
use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale, use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale,
), ),
parallel_config=parallel_config,
) )
return fused_experts( return fused_experts(
......
...@@ -27,6 +27,7 @@ from vllm.model_executor.parameter import ( ...@@ -27,6 +27,7 @@ from vllm.model_executor.parameter import (
ChannelQuantScaleParameter, ChannelQuantScaleParameter,
PerTensorScaleParameter, PerTensorScaleParameter,
) )
from vllm.model_executor.utils import replace_parameter
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
from vllm.utils.deep_gemm import ( from vllm.utils.deep_gemm import (
...@@ -194,6 +195,39 @@ direct_register_custom_op( ...@@ -194,6 +195,39 @@ direct_register_custom_op(
) )
def _triton_per_token_group_quant_fp8_impl(
    x: torch.Tensor,
    group_size: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Real implementation of the ``triton_per_token_group_quant_fp8`` op.

    Thin wrapper that forwards to ``per_token_group_quant_fp8`` with the
    scale layout and rounding mode pinned, so the custom op only exposes
    ``x`` and ``group_size``.

    Args:
        x: Tensor to quantize.
        group_size: Number of elements sharing one scale.

    Returns:
        Whatever ``per_token_group_quant_fp8`` returns for these arguments
        (the quantized tensor and its scales).
    """
    # Row-major scales and no UE8M0 rounding are fixed for this op.
    result = per_token_group_quant_fp8(
        x,
        group_size,
        column_major_scales=False,
        use_ue8m0=False,
    )
    return result
def _triton_per_token_group_quant_fp8_fake(
    x: torch.Tensor,
    group_size: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Fake implementation of the ``triton_per_token_group_quant_fp8`` op.

    Allocates uninitialized outputs with the expected shapes and dtypes
    without running any quantization.

    Args:
        x: Two-dimensional input; its shape is unpacked as ``(rows, cols)``.
        group_size: Number of elements sharing one scale.

    Returns:
        ``(x_fp8, scales)`` where ``x_fp8`` has the shape of ``x`` in the
        platform FP8 dtype and ``scales`` is float32 with one column per
        group of ``group_size`` input columns.
    """
    num_rows, num_cols = x.shape
    # Ceiling division: a trailing partial group still needs a scale slot.
    num_groups = (num_cols + group_size - 1) // group_size

    quantized = torch.empty(
        (num_rows, num_cols),
        dtype=current_platform.fp8_dtype(),
        device=x.device,
    )
    scales = torch.empty(
        (num_rows, num_groups),
        dtype=torch.float32,
        device=x.device,
    )
    return quantized, scales
# Register the Triton per-token-group FP8 quant wrapper as a custom op and
# pair it with its fake (allocation-only) implementation. Elsewhere in this
# file it is invoked as torch.ops.vllm.triton_per_token_group_quant_fp8.
direct_register_custom_op(
    "triton_per_token_group_quant_fp8",
    _triton_per_token_group_quant_fp8_impl,
    fake_impl=_triton_per_token_group_quant_fp8_fake,
)
# TODO fix ROCm->Triton custom path: # TODO fix ROCm->Triton custom path:
# https://github.com/vllm-project/vllm/issues/14397 # https://github.com/vllm-project/vllm/issues/14397
class W8A8BlockFp8LinearOp: class W8A8BlockFp8LinearOp:
...@@ -213,6 +247,7 @@ class W8A8BlockFp8LinearOp: ...@@ -213,6 +247,7 @@ class W8A8BlockFp8LinearOp:
self.act_quant_group_shape = act_quant_group_shape self.act_quant_group_shape = act_quant_group_shape
self.is_deep_gemm_supported = is_deep_gemm_supported() self.is_deep_gemm_supported = is_deep_gemm_supported()
self.is_hopper = current_platform.is_device_capability(90) self.is_hopper = current_platform.is_device_capability(90)
self.is_blackwell = current_platform.is_device_capability(100)
self.use_deep_gemm_e8m0 = is_deep_gemm_e8m0_used() self.use_deep_gemm_e8m0 = is_deep_gemm_e8m0_used()
# Get the correct blockscale mul and input quant operations. # Get the correct blockscale mul and input quant operations.
...@@ -268,8 +303,15 @@ class W8A8BlockFp8LinearOp: ...@@ -268,8 +303,15 @@ class W8A8BlockFp8LinearOp:
weight: torch.Tensor, weight: torch.Tensor,
weight_scale: torch.Tensor, weight_scale: torch.Tensor,
) -> torch.Tensor: ) -> torch.Tensor:
assert self.deepgemm_input_quant_op is not None if self.use_deep_gemm_e8m0 and self.is_blackwell:
q_input, input_scale = self.deepgemm_input_quant_op(input_2d) q_input, input_scale = per_token_group_quant_fp8_packed_for_deepgemm(
input_2d,
group_size=self.act_quant_group_shape.col,
use_ue8m0=True,
)
else:
assert self.deepgemm_input_quant_op is not None
q_input, input_scale = self.deepgemm_input_quant_op(input_2d)
output = torch.empty( output = torch.empty(
(q_input.shape[0], weight.shape[0]), (q_input.shape[0], weight.shape[0]),
dtype=torch.bfloat16, dtype=torch.bfloat16,
...@@ -332,17 +374,15 @@ class W8A8BlockFp8LinearOp: ...@@ -332,17 +374,15 @@ class W8A8BlockFp8LinearOp:
if input_scale is not None: if input_scale is not None:
q_input = input_2d q_input = input_2d
# MI350 case uses triton kernel
elif use_triton: elif use_triton:
q_input, input_scale = per_token_group_quant_fp8( q_input, input_scale = torch.ops.vllm.triton_per_token_group_quant_fp8(
input_2d, input_2d,
self.act_quant_group_shape.col, self.act_quant_group_shape.col,
column_major_scales=False,
use_ue8m0=False,
) )
# MI300 uses tuned AITER ASM/C++ kernel
else: else:
q_input, input_scale = rocm_aiter_ops.group_fp8_quant(input_2d) q_input, input_scale = rocm_aiter_ops.group_fp8_quant(
input_2d, self.act_quant_group_shape.col
)
return gemm_a8w8_blockscale_op( return gemm_a8w8_blockscale_op(
q_input, q_input,
...@@ -492,6 +532,139 @@ def _per_token_group_quant_fp8( ...@@ -492,6 +532,139 @@ def _per_token_group_quant_fp8(
tl.store(y_s_ptr, y_s) tl.store(y_s_ptr, y_s)
@triton.jit
def _silu_mul_per_token_group_quant_fp8_colmajor(
    y_ptr,  # [M, N]
    y_q_ptr,  # [M, N // 2]
    y_s_ptr,  # [M, (N // 2) // GROUP_SIZE]
    M,  # num tokens
    N,  # intermediate size
    # Stride
    y_s_col_stride: tl.int64,
    # Information for float8
    eps,
    fp8_min,
    fp8_max,
    use_ue8m0: tl.constexpr,
    # Meta-parameters
    GROUP_SIZE: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # TODO(varun) : Add expert_ids so we may early-exit no-op thread blocks.
    """
    Each thread block (BLOCK_N) computes [BLOCK_M, GROUP_SIZE] act-mul outputs. Then
    the thread block quantizes the [BLOCK_M, GROUP_SIZE] block of values and fills
    the outputs tensors at the right positions.

    Program (pid_m, pid_n) covers rows starting at pid_m * BLOCK_M and output
    columns starting at pid_n * BLOCK_N. Within a row of `y_ptr`, the element
    at column c is the SiLU input and the element at column c + N // 2 is the
    multiplier.

    One scale is produced per row per tile (the abs-max is reduced across all
    BLOCK_N columns), so the result is a per-GROUP_SIZE scale only when
    BLOCK_N == GROUP_SIZE.

    Loads and stores are unmasked, and addresses use a row stride of N for the
    input and N // 2 for the quantized output, so the launch must tile M and
    N // 2 evenly and both tensors must be laid out with those row strides.
    """
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    N_2 = N // 2

    # First row / first output column handled by this program.
    m_offset = pid_m * BLOCK_M
    n_offset = pid_n * BLOCK_N
    if m_offset >= M:
        return

    # Offsets are widened to int64 before being used in pointer arithmetic.
    offs_n = tl.arange(0, BLOCK_N).to(tl.int64)
    offs_m = tl.arange(0, BLOCK_M).to(tl.int64)

    base_y_ptr = y_ptr + m_offset * N + n_offset
    act_in_ptrs = base_y_ptr + offs_m[:, None] * N + offs_n[None, :]

    # The multiplier lives N_2 elements to the right of the activation input.
    act_in = tl.load(act_in_ptrs)
    mul_in = tl.load(act_in_ptrs + N_2)

    # silu & mul
    # SiLU is evaluated in fp32, cast back to the input element type, and the
    # product with `mul_in` is then promoted to fp32 for quantization.
    act_in = act_in.to(tl.float32)
    one_f32 = tl.cast(1, tl.float32)
    silu_out = (act_in / (one_f32 + tl.exp(-act_in))).to(y_ptr.dtype.element_ty)
    y = (silu_out * mul_in).to(tl.float32)

    # quant
    # Row-wise abs-max over the tile, floored at `eps` so the scale is never 0.
    _absmax = tl.maximum(tl.max(tl.abs(y), axis=1), eps)
    scale_raw = _absmax / fp8_max
    # With use_ue8m0 the scale is rounded up to the next power of two.
    y_s = tl.math.exp2(tl.ceil(tl.log2(scale_raw))) if use_ue8m0 else scale_raw
    y_s = tl.reshape(y_s, (BLOCK_M, 1))
    y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)

    # store y_q
    base_y_q_ptr = y_q_ptr + m_offset * N_2 + n_offset
    y_q_ptrs = base_y_q_ptr + offs_m[:, None] * N_2 + offs_n[None, :]
    tl.store(y_q_ptrs, y_q)

    # store y_s
    # Scales are addressed as [group, row]: consecutive rows are adjacent and
    # consecutive groups are `y_s_col_stride` elements apart.
    group_id = n_offset // GROUP_SIZE
    base_y_s_ptr = y_s_ptr + group_id * y_s_col_stride + m_offset
    y_s_ptrs = base_y_s_ptr + offs_m
    y_s = tl.reshape(y_s, (BLOCK_M,))
    tl.store(y_s_ptrs, y_s)
def silu_mul_per_token_group_quant_fp8_colmajor(
input: torch.Tensor, # [M, N]
output: torch.Tensor | None = None, # [M, N // 2]
use_ue8m0: bool | None = None,
eps: float = 1e-10,
):
"""
silu+mul + block-fp8 quant with group size 128.
"""
GROUP_SIZE = 128
assert input.ndim == 2
if output is not None:
assert output.ndim == 2
assert input.size(0) % GROUP_SIZE == 0
assert input.size(1) % (GROUP_SIZE * 2) == 0
if use_ue8m0 is None:
use_ue8m0 = is_deep_gemm_e8m0_used()
M, N = input.size()
N_2 = N // 2
if output is None:
output = torch.empty((M, N_2), dtype=torch.float8_e4m3fn, device=input.device)
output_scales = torch.empty(
((N_2 // GROUP_SIZE), M), dtype=torch.float32, device=input.device
).transpose(0, 1)
BLOCK_M = 8
BLOCK_N = GROUP_SIZE
assert M % BLOCK_M == 0
assert N_2 % BLOCK_N == 0
finfo = torch.finfo(torch.float8_e4m3fn)
fp8_min = finfo.min
fp8_max = finfo.max
# Force even division so we can avoid edgecases within the kernel.
assert M % BLOCK_M == 0
assert N_2 % BLOCK_N == 0
grid = (M // BLOCK_M, N_2 // BLOCK_N)
_silu_mul_per_token_group_quant_fp8_colmajor[grid](
input,
output,
output_scales,
M,
N,
output_scales.stride(-1),
eps,
fp8_min,
fp8_max,
use_ue8m0,
GROUP_SIZE,
BLOCK_M,
BLOCK_N,
)
return output, output_scales
@triton.jit @triton.jit
def _per_token_group_quant_fp8_colmajor( def _per_token_group_quant_fp8_colmajor(
# Pointers to inputs and output # Pointers to inputs and output
...@@ -596,7 +769,7 @@ def per_token_group_quant_fp8( ...@@ -596,7 +769,7 @@ def per_token_group_quant_fp8(
assert out_q is None or out_q.shape == x.shape assert out_q is None or out_q.shape == x.shape
x_q = out_q x_q = out_q
if x_q is None: if x_q is None:
x_q = torch.empty_like(x, device=x.device, dtype=dtype) x_q = torch.empty(x.shape, device=x.device, dtype=dtype)
# Allocate the scale tensor in either row- or column-major format. # Allocate the scale tensor in either row- or column-major format.
if column_major_scales: if column_major_scales:
...@@ -658,6 +831,80 @@ def per_token_group_quant_fp8( ...@@ -658,6 +831,80 @@ def per_token_group_quant_fp8(
return x_q, x_s return x_q, x_s
def per_token_group_quant_fp8_packed_for_deepgemm(
x: torch.Tensor,
group_size: int,
eps: float = 1e-10,
use_ue8m0: bool | None = None,
out_q: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""FP8 per-token-group quantization for DeepGEMM.
Returns:
(x_q, x_s_packed)
x_q: FP8 activations, same shape as `x`.
x_s_packed: Int32 tensor with logical shape
[mn, ceil(num_groups_per_row / 4)], laid out with
TMA-aligned stride along the packed-K dimension
"""
if use_ue8m0 is None:
use_ue8m0 = is_deep_gemm_e8m0_used()
# for DeepGEMM UE8M0-packed layout we *require* UE8M0 scales.
assert use_ue8m0, (
"per_token_group_quant_fp8_packed_for_deepgemm requires UE8M0 scales."
)
dtype = current_platform.fp8_dtype()
assert x.shape[-1] % group_size == 0, (
f"the last dimension of `x` {x.shape[-1]} must be divisible "
f"by `group_size` {group_size}"
)
assert x.stride(-1) == 1, "`x` groups must be contiguous"
finfo = torch.finfo(dtype)
fp8_min, fp8_max = finfo.min, finfo.max
# compute DeepGEMM-style packed scale tensor shape.
hidden_dim = x.shape[-1]
mn = x.numel() // hidden_dim
num_groups_per_row = hidden_dim // group_size
k_num_packed_sf_k = (num_groups_per_row + 3) // 4
tma_aligned_mn = ((mn + 3) // 4) * 4
x_s_packed = torch.empty_strided(
(mn, k_num_packed_sf_k),
(1, tma_aligned_mn),
device=x.device,
dtype=torch.int32,
)
# CUDA kernel path only (DeepGEMM + E8M0 is CUDA-specific).
assert current_platform.is_cuda(), (
"per_token_group_quant_fp8_packed_for_deepgemm is only valid on CUDA "
"platforms using DeepGEMM."
)
x_contiguous = x.contiguous()
if out_q is not None:
x_q_local = out_q
else:
x_q_local = torch.empty_like(x_contiguous, device=x.device, dtype=dtype)
torch.ops._C.per_token_group_fp8_quant_packed(
x_contiguous,
x_q_local,
x_s_packed,
group_size,
eps,
fp8_min,
fp8_max,
)
# return a tensor with the original logical shape.
x_q = x_q_local.view_as(x)
return x_q, x_s_packed
@triton.jit @triton.jit
def _w8a8_triton_block_scaled_mm( def _w8a8_triton_block_scaled_mm(
# Pointers to inputs and output # Pointers to inputs and output
...@@ -1189,12 +1436,12 @@ def maybe_post_process_fp8_weight_block(layer: torch.nn.Module): ...@@ -1189,12 +1436,12 @@ def maybe_post_process_fp8_weight_block(layer: torch.nn.Module):
if should_use_deepgemm: if should_use_deepgemm:
dg_weight, dg_weight_scale = deepgemm_post_process_fp8_weight_block( dg_weight, dg_weight_scale = deepgemm_post_process_fp8_weight_block(
wq=layer.weight.data, wq=layer.weight.data,
ws=layer.weight_scale.data, ws=layer.weight_scale_inv.data,
quant_block_shape=tuple(layer.weight_block_size), quant_block_shape=tuple(layer.weight_block_size),
use_e8m0=is_deep_gemm_e8m0_used(), use_e8m0=is_deep_gemm_e8m0_used(),
) )
layer.weight = torch.nn.Parameter(dg_weight, requires_grad=False) replace_parameter(layer, "weight", dg_weight)
layer.weight_scale = torch.nn.Parameter(dg_weight_scale, requires_grad=False) replace_parameter(layer, "weight_scale_inv", dg_weight_scale)
def expert_weight_is_col_major(x: torch.Tensor) -> bool: def expert_weight_is_col_major(x: torch.Tensor) -> bool:
......
...@@ -83,26 +83,11 @@ def block_dequant( ...@@ -83,26 +83,11 @@ def block_dequant(
if current_platform.is_rocm(): if current_platform.is_rocm():
from triton.language import core
# NOTE: This can be removed when hip.libdevice.round() is available.
@core.extern
def round_f32(arg0, _builder=None):
    """Round ``arg0`` elementwise through the ``llvm.round`` extern symbol.

    Two overloads are registered, one keyed on ``fp32`` and one on ``fp64``;
    each maps to ``llvm.round`` and keeps the input dtype as the result
    dtype. The library name and path passed to ``extern_elementwise`` are
    both empty strings.
    """
    fp32 = core.dtype("fp32")
    fp64 = core.dtype("fp64")
    # Maps the tuple of argument dtypes to (symbol name, return dtype).
    overloads = {
        (fp32,): ("llvm.round", fp32),
        (fp64,): ("llvm.round", fp64),
    }
    return core.extern_elementwise(
        "",
        "",
        [arg0],
        overloads,
        is_pure=True,
        _builder=_builder,
    )
@triton.jit @triton.jit
def round_int8(x): def round_int8(x):
return round_f32(x).to(tl.int8) return tl.extra.hip.libdevice.round(x).to(tl.int8)
else: else:
@triton.jit @triton.jit
......
...@@ -179,6 +179,8 @@ def check_marlin_supports_shape( ...@@ -179,6 +179,8 @@ def check_marlin_supports_shape(
def check_marlin_supports_layer(layer: LinearBase, group_size: int) -> bool: def check_marlin_supports_layer(layer: LinearBase, group_size: int) -> bool:
if current_platform.is_rocm():
return False
output_size_per_partition = ( output_size_per_partition = (
getattr(layer, "output_size_per_partition", None) or layer.output_size getattr(layer, "output_size_per_partition", None) or layer.output_size
) )
...@@ -195,6 +197,8 @@ def check_marlin_supports_layer(layer: LinearBase, group_size: int) -> bool: ...@@ -195,6 +197,8 @@ def check_marlin_supports_layer(layer: LinearBase, group_size: int) -> bool:
def check_moe_marlin_supports_layer(layer: LinearBase, group_size: int) -> bool: def check_moe_marlin_supports_layer(layer: LinearBase, group_size: int) -> bool:
if current_platform.is_rocm():
return False
hidden_size = layer.hidden_size hidden_size = layer.hidden_size
intermediate_size_per_partition = layer.intermediate_size_per_partition intermediate_size_per_partition = layer.intermediate_size_per_partition
# apply_router_weight_on_input is not supported for moe marlin # apply_router_weight_on_input is not supported for moe marlin
......
...@@ -14,6 +14,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import ( ...@@ -14,6 +14,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
marlin_quant_input, marlin_quant_input,
should_use_atomic_add_reduce, should_use_atomic_add_reduce,
) )
from vllm.model_executor.utils import replace_parameter
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types from vllm.scalar_type import scalar_types
...@@ -130,7 +131,7 @@ def prepare_fp8_layer_for_marlin( ...@@ -130,7 +131,7 @@ def prepare_fp8_layer_for_marlin(
size_n=part_size_n, size_n=part_size_n,
num_bits=8, num_bits=8,
) )
layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False) replace_parameter(layer, "weight", marlin_qweight)
# WEIGHT SCALES # WEIGHT SCALES
# Permute scales # Permute scales
...@@ -138,7 +139,6 @@ def prepare_fp8_layer_for_marlin( ...@@ -138,7 +139,6 @@ def prepare_fp8_layer_for_marlin(
scales = layer.weight_scale.to(layer.orig_dtype) scales = layer.weight_scale.to(layer.orig_dtype)
elif "weight_scale_inv" in dir(layer): elif "weight_scale_inv" in dir(layer):
scales = layer.weight_scale_inv.to(layer.orig_dtype) scales = layer.weight_scale_inv.to(layer.orig_dtype)
del layer.weight_scale_inv
group_size = -1 if weight_block_size is None else weight_block_size[1] group_size = -1 if weight_block_size is None else weight_block_size[1]
...@@ -177,12 +177,15 @@ def prepare_fp8_layer_for_marlin( ...@@ -177,12 +177,15 @@ def prepare_fp8_layer_for_marlin(
) )
if input_dtype != torch.float8_e4m3fn: if input_dtype != torch.float8_e4m3fn:
marlin_scales = fp8_fused_exponent_bias_into_scales(marlin_scales) marlin_scales = fp8_fused_exponent_bias_into_scales(marlin_scales)
layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False) if hasattr(layer, "weight_scale"):
replace_parameter(layer, "weight_scale", marlin_scales)
elif hasattr(layer, "weight_scale_inv"):
replace_parameter(layer, "weight_scale_inv", marlin_scales)
if hasattr(layer, "bias") and layer.bias is not None: if hasattr(layer, "bias") and layer.bias is not None:
assert layer.bias.shape == (part_size_n,) assert layer.bias.shape == (part_size_n,)
bias = marlin_permute_bias(layer.bias) bias = marlin_permute_bias(layer.bias)
layer.bias = torch.nn.Parameter(bias, requires_grad=False) replace_parameter(layer, "bias", bias)
def prepare_moe_fp8_layer_for_marlin( def prepare_moe_fp8_layer_for_marlin(
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""This file is used for /tests and /benchmarks""" """This file is used for /tests and /benchmarks"""
from collections.abc import Mapping from collections.abc import Callable, Mapping
from dataclasses import dataclass from dataclasses import dataclass
from types import MappingProxyType from types import MappingProxyType
from typing import ClassVar, NamedTuple from typing import ClassVar, NamedTuple
...@@ -115,6 +115,12 @@ kFp8DynamicTokenSym = QuantKey(FP8_DTYPE, kDynamicTokenScale, symmetric=True) ...@@ -115,6 +115,12 @@ kFp8DynamicTokenSym = QuantKey(FP8_DTYPE, kDynamicTokenScale, symmetric=True)
kNvfp4GroupScale = ScaleDesc(FP8_DTYPE, False, GroupShape(1, 16)) kNvfp4GroupScale = ScaleDesc(FP8_DTYPE, False, GroupShape(1, 16))
kNvfp4Quant = QuantKey(FP4_DTYPE, scale=kNvfp4GroupScale, scale2=kStaticTensorScale) kNvfp4Quant = QuantKey(FP4_DTYPE, scale=kNvfp4GroupScale, scale2=kStaticTensorScale)
kDynamic128Scale = ScaleDesc(torch.float32, False, GroupShape(1, 128))
kFp8Dynamic128Sym = QuantKey(FP8_DTYPE, kDynamic128Scale, symmetric=True)
kDynamic64Scale = ScaleDesc(torch.float32, False, GroupShape(1, 64))
kFp8Dynamic64Sym = QuantKey(FP8_DTYPE, kDynamic64Scale, symmetric=True)
# Normalize the group_shape to the full extent for any dims that are -1 # Normalize the group_shape to the full extent for any dims that are -1
def _normalize_quant_group_shape(x: torch.Tensor, group_shape: GroupShape): def _normalize_quant_group_shape(x: torch.Tensor, group_shape: GroupShape):
...@@ -685,3 +691,51 @@ def cutlass_fp4_supported() -> bool: ...@@ -685,3 +691,51 @@ def cutlass_fp4_supported() -> bool:
capability_tuple = current_platform.get_device_capability() capability_tuple = current_platform.get_device_capability()
capability = -1 if capability_tuple is None else capability_tuple.to_int() capability = -1 if capability_tuple is None else capability_tuple.to_int()
return cutlass_scaled_mm_supports_fp4(capability) return cutlass_scaled_mm_supports_fp4(capability)
def convert_bf16_scales_to_fp8(
    quant_fp8: Callable, scales: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Split a scale tensor into the (fp8_scales, channel_scales) pair used by
    W4A8 GEMM kernels.

    Args:
        quant_fp8: Callable invoked with a 2-D view of ``scales``; it must
            return a ``(quantized, scales)`` pair of tensors.
        scales: Contiguous GPU tensor whose last dimension indexes the
            K-groups. Presumably BF16 (per the function name); the dtype is
            not checked here.

    Returns:
        ``(fp8_scales, chan_scales)``: the quantized values divided by 8 and
        cast to ``torch.float8_e4m3fn`` in the original shape, and the
        companion scales multiplied by 8.

    Raises:
        AssertionError: if ``scales`` is not contiguous or not on a GPU.
    """
    assert scales.is_contiguous(), (
        f"scale tensor must be contiguous, got {scales.stride()=}"
    )
    assert scales.is_cuda, "scales must be on gpu"

    full_shape = scales.shape
    leading_dims = full_shape[:-1]

    # Collapse all leading dims so quant_fp8 sees a (rows, k_groups) matrix.
    rows = scales.view(-1, full_shape[-1])
    quantized, row_scales = quant_fp8(rows)

    # Shift a factor of 8 out of the FP8 values and into the companion
    # scales; the product of the two outputs is unchanged up to rounding.
    quantized = quantized.float()
    quantized = quantized / 8.0
    quantized = quantized.to(torch.float8_e4m3fn)
    row_scales *= 8.0

    # Undo the flattening.
    fp8_out = quantized.view(full_shape)
    # NOTE(review): the leading-dims Size and -1 are passed as two separate
    # positional arguments, exactly as before -- confirm the resulting shape
    # is the intended (*leading_dims, -1).
    chan_out = row_scales.view(leading_dims, -1)
    return fp8_out, chan_out
def convert_packed_uint4b8_to_signed_int4_inplace(t: torch.Tensor) -> torch.Tensor:
    """
    Convert int4b8 (packed to int32) to signed int4, in place.

    Each int32 entry holds eight 4-bit nibbles. A nibble in int4b8 encoding
    stores ``value + 8`` in [0..15]; the signed int4 two's-complement nibble
    is ``(nibble - 8) & 0xF``, which equals ``nibble ^ 0x8`` for every
    nibble value. All eight nibbles are therefore converted with a single
    XOR against the bit pattern 0x88888888.

    Args:
        t: GPU tensor of dtype ``torch.int32`` holding the packed weights.
            It is modified in place.

    Returns:
        The same tensor ``t`` after conversion.

    Raises:
        AssertionError: if ``t`` is not on a GPU or is not ``torch.int32``.
    """
    assert t.is_cuda, "tensor must be on gpu"
    assert t.dtype == torch.int32, f"expected int32 packed weights but got {t.dtype}"

    # 0x88888888 does not fit in a signed 32-bit integer, so use the int32
    # value with the same two's-complement bit pattern.
    nibble_sign_bits = -0x77777778

    # Flip bit 3 of every nibble in one elementwise pass; no nibble carries
    # into its neighbour because XOR is bitwise.
    t ^= nibble_sign_bits
    return t
...@@ -118,8 +118,11 @@ def requantize_with_max_scale( ...@@ -118,8 +118,11 @@ def requantize_with_max_scale(
# from disk in this case. Skip requantization in this case (since) # from disk in this case. Skip requantization in this case (since)
# we already are quantized with the single scale. # we already are quantized with the single scale.
# * Sample Model: nm-testing/Phi-3-mini-128k-instruct-FP8 # * Sample Model: nm-testing/Phi-3-mini-128k-instruct-FP8
#
# Extra note: upon weight reloading weight_scale.ndim == 0
unfused_module_in_checkpoint = ( unfused_module_in_checkpoint = (
weight_scale[-1] > torch.finfo(torch.float8_e4m3fn).min weight_scale.ndim != 0
and weight_scale[-1] > torch.finfo(torch.float8_e4m3fn).min
) )
# If unfused checkpoint, need to requantize with the single scale. # If unfused checkpoint, need to requantize with the single scale.
......
...@@ -30,7 +30,6 @@ def get_rope( ...@@ -30,7 +30,6 @@ def get_rope(
is_neox_style: bool = True, is_neox_style: bool = True,
rope_parameters: dict[str, Any] | None = None, rope_parameters: dict[str, Any] | None = None,
dtype: torch.dtype | None = None, dtype: torch.dtype | None = None,
partial_rotary_factor: float = 1.0,
dual_chunk_attention_config: dict[str, Any] | None = None, dual_chunk_attention_config: dict[str, Any] | None = None,
) -> RotaryEmbedding: ) -> RotaryEmbedding:
if dtype is None: if dtype is None:
...@@ -55,6 +54,10 @@ def get_rope( ...@@ -55,6 +54,10 @@ def get_rope(
else: else:
dual_chunk_attention_args = None dual_chunk_attention_args = None
partial_rotary_factor = 1.0
if rope_parameters is not None:
partial_rotary_factor = rope_parameters.get("partial_rotary_factor", 1.0)
if partial_rotary_factor < 1.0: if partial_rotary_factor < 1.0:
rotary_dim = int(rotary_dim * partial_rotary_factor) rotary_dim = int(rotary_dim * partial_rotary_factor)
key = ( key = (
......
...@@ -4,6 +4,7 @@ import os ...@@ -4,6 +4,7 @@ import os
from collections.abc import Generator from collections.abc import Generator
import gguf import gguf
import regex as re
import torch import torch
import torch.nn as nn import torch.nn as nn
from huggingface_hub import hf_hub_download from huggingface_hub import hf_hub_download
...@@ -94,6 +95,7 @@ class GGUFModelLoader(BaseModelLoader): ...@@ -94,6 +95,7 @@ class GGUFModelLoader(BaseModelLoader):
hasattr(config, "vision_config") and config.vision_config is not None hasattr(config, "vision_config") and config.vision_config is not None
) )
gguf_to_hf_name_map = {} gguf_to_hf_name_map = {}
sideload_params: list[re.Pattern] = []
# hack: ggufs have a different name than transformers # hack: ggufs have a different name than transformers
if model_type == "cohere": if model_type == "cohere":
model_type = "command-r" model_type = "command-r"
...@@ -118,6 +120,12 @@ class GGUFModelLoader(BaseModelLoader): ...@@ -118,6 +120,12 @@ class GGUFModelLoader(BaseModelLoader):
gguf_to_hf_name_map[f"blk.{idx}.ffn_up_exps.weight"] = ( gguf_to_hf_name_map[f"blk.{idx}.ffn_up_exps.weight"] = (
f"model.layers.{idx}.mlp.experts.0.up_proj.weight" f"model.layers.{idx}.mlp.experts.0.up_proj.weight"
) )
sideload_params.append(
re.compile(
f"model\\.layers\\.{idx}"
r"\.mlp\.experts\.[0-9]+\.(gate|up|down)_proj\.weight"
)
)
if model_type in ("qwen2_moe", "qwen3_moe"): if model_type in ("qwen2_moe", "qwen3_moe"):
model_type = model_type.replace("_", "") model_type = model_type.replace("_", "")
# GGUF layer map assumes that we will have a merged expert weights # GGUF layer map assumes that we will have a merged expert weights
...@@ -132,6 +140,12 @@ class GGUFModelLoader(BaseModelLoader): ...@@ -132,6 +140,12 @@ class GGUFModelLoader(BaseModelLoader):
gguf_to_hf_name_map[f"blk.{idx}.ffn_up_exps.weight"] = ( gguf_to_hf_name_map[f"blk.{idx}.ffn_up_exps.weight"] = (
f"model.layers.{idx}.mlp.experts.0.up_proj.weight" f"model.layers.{idx}.mlp.experts.0.up_proj.weight"
) )
sideload_params.append(
re.compile(
f"model\\.layers\\.{idx}"
r"\.mlp\.experts\.[0-9]+\.(gate|up|down)_proj\.weight"
)
)
arch = None arch = None
for key, value in gguf.MODEL_ARCH_NAMES.items(): for key, value in gguf.MODEL_ARCH_NAMES.items():
...@@ -241,7 +255,15 @@ class GGUFModelLoader(BaseModelLoader): ...@@ -241,7 +255,15 @@ class GGUFModelLoader(BaseModelLoader):
# Parameter not in manual overrides either # Parameter not in manual overrides either
unmapped_params.append(hf_name) unmapped_params.append(hf_name)
# All parameters must be mapped: both vision/projector and backbone # All parameters (except those initialized by other means) must be mapped:
# both vision/projector and backbone
if unmapped_params:
unmapped_params = list(
filter(
lambda x: not any(re.fullmatch(p, x) for p in sideload_params),
unmapped_params,
)
)
if unmapped_params: if unmapped_params:
raise RuntimeError( raise RuntimeError(
f"Failed to map GGUF parameters " f"Failed to map GGUF parameters "
......
...@@ -167,7 +167,6 @@ _MODEL_ARCH_BY_HASH = dict[int, tuple[type[nn.Module], str]]() ...@@ -167,7 +167,6 @@ _MODEL_ARCH_BY_HASH = dict[int, tuple[type[nn.Module], str]]()
def _get_model_architecture(model_config: ModelConfig) -> tuple[type[nn.Module], str]: def _get_model_architecture(model_config: ModelConfig) -> tuple[type[nn.Module], str]:
from vllm.model_executor.models.adapters import ( from vllm.model_executor.models.adapters import (
as_embedding_model, as_embedding_model,
as_reward_model,
as_seq_cls_model, as_seq_cls_model,
try_create_mm_pooling_model_cls, try_create_mm_pooling_model_cls,
) )
...@@ -207,9 +206,6 @@ def _get_model_architecture(model_config: ModelConfig) -> tuple[type[nn.Module], ...@@ -207,9 +206,6 @@ def _get_model_architecture(model_config: ModelConfig) -> tuple[type[nn.Module],
elif convert_type == "classify": elif convert_type == "classify":
logger.debug_once("Converting to sequence classification model.") logger.debug_once("Converting to sequence classification model.")
model_cls = as_seq_cls_model(model_cls) model_cls = as_seq_cls_model(model_cls)
elif convert_type == "reward":
logger.debug_once("Converting to reward model.")
model_cls = as_reward_model(model_cls)
else: else:
assert_never(convert_type) assert_never(convert_type)
......
...@@ -641,7 +641,6 @@ def safetensors_weights_iterator( ...@@ -641,7 +641,6 @@ def safetensors_weights_iterator(
if safetensors_load_strategy == "eager": if safetensors_load_strategy == "eager":
loading_desc += " (eager)" loading_desc += " (eager)"
state_dict = {}
leftover_state_dict: dict[str, torch.Tensor] = {} leftover_state_dict: dict[str, torch.Tensor] = {}
for st_file in tqdm( for st_file in tqdm(
...@@ -667,6 +666,7 @@ def safetensors_weights_iterator( ...@@ -667,6 +666,7 @@ def safetensors_weights_iterator(
) )
with safe_open(st_file, framework="pt") as f: with safe_open(st_file, framework="pt") as f:
state_dict = {}
for name in f.keys(): # noqa: SIM118 for name in f.keys(): # noqa: SIM118
state_dict[name] = f.get_tensor(name) state_dict[name] = f.get_tensor(name)
...@@ -921,7 +921,17 @@ def gguf_quant_weights_iterator( ...@@ -921,7 +921,17 @@ def gguf_quant_weights_iterator(
name = gguf_to_hf_name_map[tensor.name] name = gguf_to_hf_name_map[tensor.name]
if weight_type.name not in ("F32", "BF16", "F16"): if weight_type.name not in ("F32", "BF16", "F16"):
name = name.replace("weight", "qweight") name = name.replace("weight", "qweight")
param = torch.tensor(weight) if weight_type.name == "BF16" and tensor.data.dtype == np.uint8:
# BF16 is currently the only "quantization" type that isn't
# actually quantized but is read as a raw byte tensor.
# Reinterpret as `torch.bfloat16` tensor.
weight = weight.view(np.uint16)
if reader.byte_order == "S":
# GGUF endianness != system endianness
weight = weight.byteswap()
param = torch.tensor(weight).view(torch.bfloat16)
else:
param = torch.tensor(weight)
yield name, param yield name, param
......
...@@ -175,9 +175,14 @@ def _create_pooling_model_cls(orig_cls: _T) -> _T: ...@@ -175,9 +175,14 @@ def _create_pooling_model_cls(orig_cls: _T) -> _T:
self.vllm_config = vllm_config self.vllm_config = vllm_config
# These are not used in pooling models # These are not used in pooling models
for attr in ("lm_head", "logits_processor"): objects_to_clean = [self]
if hasattr(self, attr): if language_model := getattr(self, "language_model", None):
delattr(self, attr) objects_to_clean.append(language_model)
for obj in objects_to_clean:
for attr in ("lm_head", "logits_processor"):
if hasattr(obj, attr):
delattr(obj, attr)
# If the model already defines a pooler instance, don't overwrite it # If the model already defines a pooler instance, don't overwrite it
if not getattr(self, "pooler", None): if not getattr(self, "pooler", None):
...@@ -346,44 +351,6 @@ def as_seq_cls_model(cls: _T) -> _T: ...@@ -346,44 +351,6 @@ def as_seq_cls_model(cls: _T) -> _T:
return ModelForSequenceClassification # type: ignore return ModelForSequenceClassification # type: ignore
def as_reward_model(cls: _T) -> _T:
    """
    Derive a reward-modeling variant of an existing vLLM model class.

    The returned subclass installs a ``DispatchPooler`` with a single
    ``"token_classify"`` entry built from the model's pooler config. Per the
    original documentation, the default is to return the hidden states of
    each token directly.

    Note:
        This assumes the original model needs no extra layers; implement a
        dedicated model class if that is not the case.

    Args:
        cls: The model class to adapt.

    Returns:
        ``cls`` itself when it is already a pooling model, otherwise a new
        subclass named via ``_get_pooling_model_name(..., "ForReward")``.
    """
    # Classes that are already pooling models are handed back untouched.
    if is_pooling_model(cls):
        return cls

    # Deferred imports, kept local to this function.
    from vllm.model_executor.layers.pooler import DispatchPooler, Pooler

    from .interfaces_base import default_pooling_type

    @default_pooling_type("ALL")
    class ModelForReward(_create_pooling_model_cls(cls)):
        def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""):
            cfg = vllm_config.model_config.pooler_config
            assert cfg is not None

            token_pooler = Pooler.for_token_classify(pooler_config=cfg)
            self.pooler = DispatchPooler({"token_classify": token_pooler})

    reward_name = _get_pooling_model_name(cls.__name__, "ForReward")
    ModelForReward.__name__ = reward_name

    return ModelForReward  # type: ignore
class SequenceClassificationConfig(VerifyAndUpdateConfig): class SequenceClassificationConfig(VerifyAndUpdateConfig):
@staticmethod @staticmethod
def verify_and_update_config(vllm_config: "VllmConfig") -> None: def verify_and_update_config(vllm_config: "VllmConfig") -> None:
......
...@@ -148,8 +148,6 @@ class ApertusAttention(nn.Module): ...@@ -148,8 +148,6 @@ class ApertusAttention(nn.Module):
if head_dim is None: if head_dim is None:
head_dim = self.hidden_size // self.total_num_heads head_dim = self.hidden_size // self.total_num_heads
self.head_dim = head_dim self.head_dim = head_dim
# Phi models introduced a partial_rotary_factor parameter in the config
self.partial_rotary_factor = getattr(config, "partial_rotary_factor", 1)
self.q_size = self.num_heads * self.head_dim self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5 self.scaling = self.head_dim**-0.5
...@@ -228,11 +226,10 @@ class ApertusAttention(nn.Module): ...@@ -228,11 +226,10 @@ class ApertusAttention(nn.Module):
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=int(self.partial_rotary_factor * self.head_dim), rotary_dim=self.head_dim,
max_position=self.max_position_embeddings, max_position=self.max_position_embeddings,
rope_parameters=config.rope_parameters, rope_parameters=config.rope_parameters,
is_neox_style=is_neox_style, is_neox_style=is_neox_style,
partial_rotary_factor=self.partial_rotary_factor,
) )
......
...@@ -499,8 +499,6 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -499,8 +499,6 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
model to perform tasks that involve both image and text inputs. model to perform tasks that involve both image and text inputs.
""" """
merge_by_field_config = True
hf_to_vllm_mapper = WeightsMapper( hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={ orig_to_new_prefix={
# mapping for new names in checkpoint saved after transformers v4.52 # mapping for new names in checkpoint saved after transformers v4.52
......
...@@ -318,8 +318,6 @@ def _get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int: ...@@ -318,8 +318,6 @@ def _get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int:
dummy_inputs=AyaVisionDummyInputsBuilder, dummy_inputs=AyaVisionDummyInputsBuilder,
) )
class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
merge_by_field_config = True
hf_to_vllm_mapper = WeightsMapper( hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={ orig_to_new_prefix={
# mapping for new names in checkpoint saved after transformers v4.52 # mapping for new names in checkpoint saved after transformers v4.52
......
...@@ -127,8 +127,6 @@ class BailingAttention(nn.Module): ...@@ -127,8 +127,6 @@ class BailingAttention(nn.Module):
prefix=f"{prefix}.dense", prefix=f"{prefix}.dense",
) )
self.partial_rotary_factor = getattr(config, "partial_rotary_factor", 1.0)
self.rotary_dim = getattr(config, "rotary_dim", self.head_dim) self.rotary_dim = getattr(config, "rotary_dim", self.head_dim)
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
...@@ -137,7 +135,6 @@ class BailingAttention(nn.Module): ...@@ -137,7 +135,6 @@ class BailingAttention(nn.Module):
max_position=config.max_position_embeddings, max_position=config.max_position_embeddings,
rope_parameters=config.rope_parameters, rope_parameters=config.rope_parameters,
is_neox_style=True, is_neox_style=True,
partial_rotary_factor=self.partial_rotary_factor,
) )
self.attn = Attention( self.attn = Attention(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment