Unverified Commit 766392c6 authored by ronnie_zheng, committed by GitHub

[feature] Ascend quantization support (#7791)


Co-authored-by: ichernob <ichernobnn@gmail.com>
Co-authored-by: liupeng <liupeng374@huawei.com>
parent 4a0d1919
@@ -413,7 +413,9 @@ class ModelConfig:
quant_cfg = self._parse_quant_hf_config()
if quant_cfg is not None:
quant_method = quant_cfg.get("quant_method", "").lower()
quant_method = quant_cfg.get(
"quant_method", "" if not self.quantization else self.quantization
).lower()
# Detect which checkpoint it is
for _, method in QUANTIZATION_METHODS.items():
......
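The change above makes a user-supplied `--quantization` value act as the fallback when the checkpoint's HF quantization config has no `quant_method` key, which is the situation for Ascend-quantized checkpoints. A minimal standalone sketch of the new lookup, with hypothetical values standing in for `quant_cfg` and `self.quantization`:

```python
# Illustrative values only; not taken from a real checkpoint.
quant_cfg = {"group_size": 0, "w_bit": 8}   # hypothetical HF quant config without "quant_method"
self_quantization = "w8a8_int8"             # hypothetical value passed via --quantization

# Old behavior: defaults to an empty method name.
old = quant_cfg.get("quant_method", "").lower()

# New behavior: falls back to the user-specified quantization method.
new = quant_cfg.get(
    "quant_method", "" if not self_quantization else self_quantization
).lower()

print(repr(old), "->", repr(new))  # '' -> 'w8a8_int8'
```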
@@ -34,6 +34,7 @@ from sglang.srt.layers.quantization.base_config import (
from sglang.srt.utils import (
cpu_has_amx_support,
is_cpu,
is_npu,
set_weight_attrs,
use_intel_amx_backend,
)
@@ -60,6 +61,7 @@ WEIGHT_LOADER_V2_SUPPORTED = [
_is_cpu_amx_available = cpu_has_amx_support()
_is_cpu = is_cpu()
_is_npu = is_npu()
def adjust_marlin_shard(param, shard_size, shard_offset):
@@ -297,6 +299,14 @@ class ReplicatedLinear(LinearBase):
if len(loaded_weight.shape) == 0:
loaded_weight = loaded_weight.reshape(1)
# The per-tensor quant scale must be 1-dimensional. Some Ascend checkpoints
# replicate it across elements, so collapse it back to a single element.
if _is_npu:
    if param.size() != loaded_weight.size() and param.size(0) == 1:
        if torch.allclose(loaded_weight, loaded_weight[0]):
            loaded_weight = loaded_weight[:1]
        else:
            raise ValueError(
                f"Expected a replicated per-tensor scale, got {loaded_weight}"
            )
assert param.size() == loaded_weight.size()
param.data.copy_(loaded_weight)
......
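A small self-contained illustration of the branch added to `ReplicatedLinear.weight_loader` above: when a checkpoint stores a per-tensor scale replicated across elements, it is collapsed to a single element so it matches the 1-element parameter. The shapes and values below are made up for the example:

```python
import torch

param = torch.nn.Parameter(torch.empty(1))   # per-tensor scale parameter
loaded_weight = torch.full((128,), 0.0123)   # hypothetical replicated scale from disk

if param.size() != loaded_weight.size() and param.size(0) == 1:
    if torch.allclose(loaded_weight, loaded_weight[0]):
        loaded_weight = loaded_weight[:1]    # collapse to shape (1,)
    else:
        raise ValueError(f"Expected a replicated per-tensor scale, got {loaded_weight}")

assert param.size() == loaded_weight.size()
param.data.copy_(loaded_weight)
print(param)  # Parameter containing: tensor([0.0123], ...)
```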
@@ -12,7 +12,6 @@ from sglang.srt.distributed import (
)
from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe
from sglang.srt.layers.moe.ep_moe.kernels import (
ep_gather,
ep_scatter,
@@ -65,6 +64,8 @@ _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
if not _is_npu:
from sgl_kernel import silu_and_mul
from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe
if _is_hip:
from vllm._custom_ops import scaled_fp8_quant
......
@@ -850,7 +850,7 @@ class FusedMoE(torch.nn.Module):
return
# Case weight scales and zero_points
if "scale" in weight_name or "zero" in weight_name:
if "scale" in weight_name or "zero" in weight_name or "offset" in weight_name:
# load the weight scales and zp based on the quantization scheme
# supported weight scales/zp can be found in
# FusedMoeWeightScaleSupported
......
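Ascend W8A8 checkpoints appear to store zero-points under names containing `offset` rather than `zero`, so the weight-name test above is widened. A toy illustration of how the widened check routes parameter names (the names themselves are made up):

```python
def is_scale_or_zero_point(weight_name: str) -> bool:
    # Widened check from the hunk above.
    return "scale" in weight_name or "zero" in weight_name or "offset" in weight_name

for name in [
    "experts.w13_weight",          # regular weight          -> normal loading path
    "experts.w13_weight_scale",    # quant scale             -> scale/zero-point path
    "experts.w13_weight_offset",   # Ascend-style zero-point -> now also matched
]:
    print(name, is_scale_or_zero_point(name))
```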
@@ -308,7 +308,7 @@ def biased_grouped_topk_gpu(
renormalize: bool,
num_expert_group: int = 0,
topk_group: int = 0,
compiled: bool = True,
compiled: bool = not _is_npu,
num_fused_shared_experts: int = 0,
routed_scaling_factor: Optional[float] = None,
num_token_non_padded: Optional[torch.Tensor] = None,
......
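The default for `compiled` now evaluates to False on Ascend, turning off the compiled code path of the grouped top-k routing (presumably a `torch.compile`-wrapped implementation, which is not usable on NPU). A condensed sketch of the gating pattern; the function body below is a stand-in, not the real grouped top-k:

```python
import torch

_is_npu = False  # assumption for this sketch: running on CUDA, so compilation stays on

def _topk_impl(scores: torch.Tensor, k: int) -> torch.Tensor:
    # Stand-in for the real biased grouped top-k computation.
    return torch.topk(scores, k, dim=-1).indices

def biased_topk(scores: torch.Tensor, k: int, compiled: bool = not _is_npu):
    # Same gating as above: the default is fixed at import time, so NPU builds
    # never wrap the implementation in torch.compile.
    fn = torch.compile(_topk_impl, dynamic=True) if compiled else _topk_impl
    return fn(scores, k)

print(biased_topk(torch.randn(2, 8), 2))
```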
@@ -116,8 +116,7 @@ class MoeWNA16Config(QuantizationConfig):
@classmethod
def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]:
can_convert = cls.is_moe_wna16_compatible(hf_quant_cfg)
if can_convert and user_quant == "moe_wna16":
if user_quant == "moe_wna16" and cls.is_moe_wna16_compatible(hf_quant_cfg):
return cls.get_name()
return None
......
@@ -34,16 +34,18 @@ import torch
import torch.distributed as dist
import triton
import triton.language as tl
from sgl_kernel.kvcacheio import transfer_kv_per_layer, transfer_kv_per_layer_mla
from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.utils import get_bool_env_var, is_cuda, next_power_of_2
from sglang.srt.utils import get_bool_env_var, is_cuda, is_npu, next_power_of_2
logger = logging.getLogger(__name__)
GB = 1024 * 1024 * 1024
_is_cuda = is_cuda()
_is_npu = is_npu()
if not _is_npu:
from sgl_kernel.kvcacheio import transfer_kv_per_layer, transfer_kv_per_layer_mla
class ReqToTokenPool:
......
@@ -64,10 +64,13 @@ from sglang.srt.model_loader.weight_utils import (
from sglang.srt.utils import (
get_bool_env_var,
get_device_capability,
is_npu,
is_pin_memory_available,
set_weight_attrs,
)
_is_npu = is_npu()
@contextmanager
def device_loading_context(module: torch.nn.Module, target_device: torch.device):
@@ -127,18 +130,19 @@ def _get_quantization_config(
# (yizhang2077) workaround for nvidia/Llama-4-Maverick-17B-128E-Eagle3
if quant_config is None:
return None
major, minor = get_device_capability()
if major is not None and minor is not None:
assert 0 <= minor < 10
capability = major * 10 + minor
if capability < quant_config.get_min_capability():
raise ValueError(
f"The quantization method {model_config.quantization} "
"is not supported for the current GPU. "
f"Minimum capability: {quant_config.get_min_capability()}. "
f"Current capability: {capability}."
)
if not _is_npu:
major, minor = get_device_capability()
if major is not None and minor is not None:
assert 0 <= minor < 10
capability = major * 10 + minor
if capability < quant_config.get_min_capability():
raise ValueError(
f"The quantization method {model_config.quantization} "
"is not supported for the current GPU. "
f"Minimum capability: {quant_config.get_min_capability()}. "
f"Current capability: {capability}."
)
supported_dtypes = quant_config.get_supported_act_dtypes()
if model_config.dtype not in supported_dtypes:
raise ValueError(
@@ -157,6 +161,13 @@ def _initialize_model(
"""Initialize a model with the given configurations."""
model_class, _ = get_model_architecture(model_config)
packed_modules_mapping = getattr(model_class, "packed_modules_mapping", {})
if _is_npu:
packed_modules_mapping["fused_qkv_a_proj_with_mqa"] = [
"q_a_proj",
"kv_a_proj_with_mqa",
]
packed_modules_mapping["qkv_proj"] = ["q_proj", "k_proj", "v_proj"]
packed_modules_mapping["gate_up_proj"] = ["gate_proj", "up_proj"]
quant_config = _get_quantization_config(
model_config, load_config, packed_modules_mapping
)
......
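On NPU the loader registers `packed_modules_mapping` entries up front so that quantization configs can map a fused module (for example `qkv_proj`) back to the per-projection names that appear in the quantized checkpoint. A rough sketch of how such a mapping gets used; the `unpack` helper below is illustrative, not the loader's real API:

```python
packed_modules_mapping = {
    "fused_qkv_a_proj_with_mqa": ["q_a_proj", "kv_a_proj_with_mqa"],
    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
    "gate_up_proj": ["gate_proj", "up_proj"],
}

def unpack(module_name: str) -> list[str]:
    """Expand a fused module name into the checkpoint names it packs (illustrative)."""
    for packed, originals in packed_modules_mapping.items():
        if module_name.endswith(packed):
            return [module_name[: -len(packed)] + orig for orig in originals]
    return [module_name]

print(unpack("model.layers.0.self_attn.qkv_proj"))
# ['model.layers.0.self_attn.q_proj', 'model.layers.0.self_attn.k_proj',
#  'model.layers.0.self_attn.v_proj']
```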
@@ -575,6 +575,8 @@ class LlamaForCausalLM(nn.Module):
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
......
@@ -407,6 +407,8 @@ class QuantMixtralForCausalLM(nn.Module):
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
@@ -418,6 +420,8 @@ class QuantMixtralForCausalLM(nn.Module):
# Skip experts that are not assigned to this worker.
if "block_sparse_moe.experts." in name and name not in params_dict:
continue
if name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
......
@@ -538,6 +538,8 @@ class Qwen2ForCausalLM(nn.Module):
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
......
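The same `if name not in params_dict: continue` guard is added to the Llama, quantized Mixtral, and Qwen2 weight loaders above, so checkpoint tensors with no corresponding model parameter (for example, extra tensors emitted by an Ascend quantizer) are skipped instead of raising a `KeyError`. A toy illustration; the tensor names are hypothetical:

```python
params_dict = {
    "model.layers.0.mlp.gate_up_proj.weight": "param-object",
}

checkpoint_tensors = [
    "model.layers.0.mlp.gate_up_proj.weight",
    "model.layers.0.mlp.gate_up_proj.deq_scale",  # hypothetical extra quantization tensor
]

for name in checkpoint_tensors:
    if name not in params_dict:
        continue  # skip tensors the model does not declare instead of KeyError-ing
    print("loading", name)
```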
@@ -197,7 +197,7 @@ def get_int_env_var(name: str, default: int = 0) -> int:
def support_triton(backend: str) -> bool:
return backend not in ["torch_native", "intel_amx"]
return backend not in ["torch_native", "intel_amx", "ascend"]
try:
@@ -2782,3 +2782,101 @@ def lru_cache_frozenset(maxsize=128):
return wrapper
return decorator

def apply_module_patch(target_module, target_function, wrappers):
    original_module, original_function = parse_module_path(
        target_module, target_function, False
    )
    original_function_id = id(original_function)
    candidate = original_function
    for wrapper in wrappers:
        candidate = wrapper(candidate)
    if target_function is not None:
        setattr(original_module, target_function, candidate)
    for key, value in sys.modules.copy().items():
        if (
            target_function is not None
            and hasattr(value, target_function)
            and id(getattr(value, target_function)) == original_function_id
        ):
            setattr(value, target_function, candidate)


def parse_module_path(module_path, function_name, create_dummy):
    from importlib.machinery import ModuleSpec

    def create_dummy_module(full_path, parent=None):
        """Create and register a placeholder module"""
        dummy = types.ModuleType(full_path)
        dummy.__file__ = "vllm_ascend.dummy_module.py"
        dummy.__spec__ = ModuleSpec(full_path, None)
        sys.modules[full_path] = dummy
        if parent:
            setattr(parent, full_path.split(".")[-1], dummy)
        return dummy

    def create_placeholder_function(func_name):
        """Create dummy function that raises when called"""

        def placeholder(*args, **kwargs):
            raise NotImplementedError(f"Function {func_name} is a placeholder")

        placeholder.__name__ = func_name
        return placeholder

    modules = module_path.split(".")
    current_module = None
    processed_path = []
    for idx, part in enumerate(modules):
        current_path = ".".join(modules[: idx + 1])
        parent_path = ".".join(modules[:idx]) if idx > 0 else None
        try:
            current_module = importlib.import_module(current_path)
        except ModuleNotFoundError:
            # Handle missing module
            parent = importlib.import_module(parent_path) if parent_path else None
            if parent and hasattr(parent, part):
                # Use existing attribute from parent
                current_module = getattr(parent, part)
                # Check for early function resolution
                if function_name and hasattr(current_module, function_name):
                    return current_module, getattr(current_module, function_name)
                if function_name and create_dummy:
                    ph_func = create_placeholder_function(function_name)
                    setattr(current_module, function_name, ph_func)
                    return current_module, ph_func
                if function_name:
                    raise AttributeError(
                        f"Function {function_name} missing in {current_path}"
                    )
            else:
                if not create_dummy:
                    raise
                # Create and register dummy module
                current_module = create_dummy_module(
                    current_path,
                    parent=(
                        importlib.import_module(parent_path) if parent_path else None
                    ),
                )
        processed_path.append(part)

    # Final function handling
    final_module = sys.modules[module_path]
    if function_name is not None:
        if not hasattr(final_module, function_name):
            if create_dummy:
                ph_func = create_placeholder_function(function_name)
                setattr(final_module, function_name, ph_func)
            else:
                setattr(final_module, function_name, None)
        return final_module, getattr(final_module, function_name)
    return final_module, None
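The two helpers above give the Ascend integration a way to monkey-patch functions in already-imported modules, including call sites that imported the function by name. A minimal usage sketch, assuming `apply_module_patch` is importable from where it is defined (it is added next to `support_triton` and `lru_cache_frozenset`, which live in `sglang.srt.utils`); the patched target here is `json.dumps`, chosen purely for illustration:

```python
import json

from sglang.srt.utils import apply_module_patch  # assumed module path (defined above)

def force_sorted_keys(original):
    """Wrapper factory: every call to the patched function gets sort_keys=True."""
    def wrapped(obj, *args, **kwargs):
        kwargs.setdefault("sort_keys", True)
        return original(obj, *args, **kwargs)
    return wrapped

# Patch json.dumps and every module attribute still pointing at the original function.
apply_module_patch("json", "dumps", [force_sorted_keys])

print(json.dumps({"b": 1, "a": 2}))  # '{"a": 2, "b": 1}'
```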