Unverified Commit 9fcc9a80 authored by Chunyuan WU, committed by GitHub

[CPU] refine CPU integration code (#7647)

parent ac49dac0
import logging

import torch

from sglang.srt.utils import cpu_has_amx_support

logger = logging.getLogger(__name__)


def amx_process_weight_after_loading(weight):
    if weight.device != torch.device("cpu"):
        return weight
    if not cpu_has_amx_support():
        return weight
    return torch.ops.sgl_kernel.convert_weight_packed(weight)


# TODO: currently the gemm kernel has the below requirements:
# OC % TILE_N == 0, where TILE_N = 16
# IC % TILE_K == 0, where TILE_K = 32
def dim_is_supported(weight):
    TILE_N = 16
    TILE_K = 32
    ndim = weight.ndim
    OC = weight.size(1) if ndim == 3 else weight.size(0)
    IC = weight.size(2) if ndim == 3 else weight.size(1)
    return OC % TILE_N == 0 and IC % TILE_K == 0


def _amx_process_weight_after_loading(
    module, weight_names, transpose_dims=None
) -> None:
    # Pack weights to get better performance on CPU
    devices = {getattr(module, weight_name).device for weight_name in weight_names}
    assert len(devices) == 1, "Expects all weights to be on the same device"
    device = devices.pop()

    if transpose_dims:
        assert len(weight_names) == len(
            transpose_dims
        ), "len(weight_names) should be equal to len(transpose_dims)"

    for i, weight_name in enumerate(weight_names):
        weight_tensor = getattr(module, weight_name)

        if transpose_dims and transpose_dims[i]:
            weight_tensor = weight_tensor.transpose(*transpose_dims[i])

        # Don't pack the weight or use the Intel AMX backend if any weight of this module has an unsupported dim.
        if not dim_is_supported(weight_tensor):
            logger.warning(
                f"Unsupported dimension for prepacking for weight '{weight_name}' with shape {weight_tensor.shape} in {module}. "
                "The derived (OC, IC) dimensions must be divisible by (16, 32)."
            )
            module.use_intel_amx_backend = False
            return

        packed_weight = torch.nn.Parameter(
            amx_process_weight_after_loading(weight_tensor),
            requires_grad=False,
        )
        packed_weight.__dict__ = weight_tensor.__dict__
        setattr(module, weight_name, packed_weight)

    module.use_intel_amx_backend = (
        device == torch.device("cpu") and cpu_has_amx_support()
    )

    if (
        module.use_intel_amx_backend
        and hasattr(module, "bias")
        and module.bias is not None
    ):
        module.bias = torch.nn.Parameter(module.bias.data.float(), requires_grad=False)


class PackWeightMethod:
    def __init__(self, weight_names, transpose_dims=None):
        self.weight_names = weight_names
        self.transpose_dims = transpose_dims

    def process_weights_after_loading(self, module) -> None:
        _amx_process_weight_after_loading(
            module, self.weight_names, self.transpose_dims
        )
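For orientation, here is a minimal sketch of how these helpers are meant to be used by a layer. The module below ("GateLike") is hypothetical and not part of this PR; the pattern mirrors the MoEGate change in the deepseek_v2 diff further down: attach PackWeightMethod as the layer's quant_method so the loader packs the weight after loading, then gate the forward path on use_intel_amx_backend.

# Illustrative sketch only -- "GateLike" is a hypothetical module, not part of this PR.
import torch
from torch import nn
from torch.nn import functional as F

from sglang.srt.layers.amx_utils import PackWeightMethod
from sglang.srt.utils import use_intel_amx_backend


class GateLike(nn.Module):
    def __init__(self, hidden_size: int, num_experts: int):
        super().__init__()
        self.weight = nn.Parameter(
            torch.empty(num_experts, hidden_size), requires_grad=False
        )
        # The weight loader calls quant_method.process_weights_after_loading(self),
        # which packs `weight` in place and sets `self.use_intel_amx_backend`.
        self.quant_method = PackWeightMethod(weight_names=["weight"])

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        if use_intel_amx_backend(self):
            # Packed (VNNI) weight path served by sgl_kernel on CPUs with AMX.
            return torch.ops.sgl_kernel.weight_packed_linear(
                hidden_states, self.weight, None, True  # is_vnni
            )
        return F.linear(hidden_states, self.weight)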
@@ -17,6 +17,7 @@ from sglang.srt.distributed import (
     tensor_model_parallel_all_gather,
     tensor_model_parallel_all_reduce,
 )
+from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
 from sglang.srt.layers.parameter import (
     BasevLLMParameter,
     BlockQuantScaleParameter,
@@ -31,10 +32,10 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizeMethodBase,
 )
 from sglang.srt.utils import (
-    _process_weight_after_loading,
     cpu_has_amx_support,
     is_cpu,
     set_weight_attrs,
+    use_intel_amx_backend,
 )

 logger = logging.getLogger(__name__)
@@ -175,7 +176,7 @@ class UnquantizedLinearMethod(LinearMethodBase):

     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         if _is_cpu and _is_cpu_amx_available:
-            _process_weight_after_loading(layer, ["weight"])
+            _amx_process_weight_after_loading(layer, ["weight"])

     def apply(
         self,
@@ -184,7 +185,7 @@ class UnquantizedLinearMethod(LinearMethodBase):
         bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        if getattr(layer, "use_intel_amx_backend", False):
+        if use_intel_amx_backend(layer):
             return torch.ops.sgl_kernel.weight_packed_linear(
                 x, layer.weight, bias, True  # is_vnni
             )
...
@@ -42,7 +42,7 @@ from sglang.srt.model_executor.forward_batch_info import (
     ForwardBatch,
     ForwardMode,
 )
-from sglang.srt.utils import dump_to_file
+from sglang.srt.utils import dump_to_file, use_intel_amx_backend

 logger = logging.getLogger(__name__)
@@ -442,7 +442,7 @@ class LogitsProcessor(nn.Module):
             dp_gather_replicate(hidden_states, local_hidden_states, logits_metadata)

         if hasattr(lm_head, "weight"):
-            if getattr(lm_head, "use_intel_amx_backend", False):
+            if use_intel_amx_backend(lm_head):
                 logits = torch.ops.sgl_kernel.weight_packed_linear(
                     hidden_states.to(lm_head.weight.dtype),
                     lm_head.weight,
...
@@ -12,6 +12,7 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
+from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
 from sglang.srt.layers.moe.fused_moe_native import moe_forward_native
 from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.base_config import (
@@ -19,12 +20,12 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizeMethodBase,
 )
 from sglang.srt.utils import (
-    _process_weight_after_loading,
     cpu_has_amx_support,
     get_bool_env_var,
     is_cpu,
     is_hip,
     set_weight_attrs,
+    use_intel_amx_backend,
 )

 if torch.cuda.is_available():
@@ -129,7 +130,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         # Pack weight for get better performance on CPU
         if _is_cpu and _is_cpu_amx_available:
-            _process_weight_after_loading(layer, ["w13_weight", "w2_weight"])
+            _amx_process_weight_after_loading(layer, ["w13_weight", "w2_weight"])
             return
@@ -264,10 +265,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
     ) -> torch.Tensor:
         assert activation == "silu", f"activation = {activation} is not supported."

-        if (
-            getattr(layer, "use_intel_amx_backend", False)
-            and not apply_router_weight_on_input
-        ):
+        if use_intel_amx_backend(layer) and not apply_router_weight_on_input:
             topk_weights, topk_ids = select_experts(
                 hidden_states=x,
                 router_logits=router_logits,
...
@@ -27,6 +27,7 @@ except ImportError:

 from sglang.srt.distributed import get_tensor_model_parallel_world_size
+from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
 from sglang.srt.layers.linear import (
     LinearBase,
     LinearMethodBase,
@@ -64,7 +65,6 @@ from sglang.srt.layers.quantization.utils import (
 )
 from sglang.srt.layers.utils import is_sm100_supported
 from sglang.srt.utils import (
-    _process_weight_after_loading,
     cpu_has_amx_support,
     get_bool_env_var,
     is_cpu,
@@ -74,6 +74,7 @@ from sglang.srt.utils import (
     log_info_on_rank0,
     print_warning_once,
     set_weight_attrs,
+    use_intel_amx_backend,
 )

 _is_hip = is_hip()
@@ -335,7 +336,7 @@ class Fp8LinearMethod(LinearMethodBase):
            assert (
                _is_cpu_amx_available
            ), "Fp8LinearMethod on CPU requires that CPU has AMX support"
-           _process_weight_after_loading(layer, ["weight"])
+           _amx_process_weight_after_loading(layer, ["weight"])
            return
        else:
            weight, weight_scale = layer.weight.data, layer.weight_scale_inv.data
@@ -433,7 +434,7 @@ class Fp8LinearMethod(LinearMethodBase):
         )

         if self.block_quant:
-            if getattr(layer, "use_intel_amx_backend", False):
+            if use_intel_amx_backend(layer):
                 return torch.ops.sgl_kernel.fp8_scaled_mm_cpu(
                     x,
                     layer.weight,
@@ -769,7 +770,7 @@ class Fp8MoEMethod:
            assert (
                _is_cpu_amx_available
            ), "Fp8MoEMethod on CPU requires that CPU has AMX support"
-           _process_weight_after_loading(layer, ["w13_weight", "w2_weight"])
+           _amx_process_weight_after_loading(layer, ["w13_weight", "w2_weight"])
            return
@@ -996,7 +997,7 @@ class Fp8MoEMethod:
             routed_scaling_factor=routed_scaling_factor,
         )

-        if getattr(layer, "use_intel_amx_backend", False):
+        if use_intel_amx_backend(layer):
             return torch.ops.sgl_kernel.fused_experts_cpu(
                 x,
                 layer.w13_weight,
...
@@ -4,6 +4,7 @@ import torch
 from torch.nn.parameter import Parameter

 from sglang.srt.distributed import get_tensor_model_parallel_world_size
+from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
 from sglang.srt.layers.linear import LinearMethodBase
 from sglang.srt.layers.parameter import ChannelQuantScaleParameter, ModelWeightParameter
 from sglang.srt.layers.quantization.base_config import (
@@ -12,11 +13,11 @@ from sglang.srt.layers.quantization.base_config import (
 )
 from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
 from sglang.srt.utils import (
-    _process_weight_after_loading,
     cpu_has_amx_support,
     is_cpu,
     is_cuda,
     set_weight_attrs,
+    use_intel_amx_backend,
 )

 _is_cuda = is_cuda()
@@ -84,7 +85,7 @@ class W8A8Int8LinearMethod(LinearMethodBase):
            assert (
                _is_cpu_amx_available
            ), "W8A8Int8LinearMethod on CPU requires that CPU has AMX support"
-           _process_weight_after_loading(layer, ["weight"])
+           _amx_process_weight_after_loading(layer, ["weight"])
            return

        layer.weight = Parameter(layer.weight.t(), requires_grad=False)
@@ -127,7 +128,7 @@ class W8A8Int8LinearMethod(LinearMethodBase):
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ):
-       if getattr(layer, "use_intel_amx_backend", False):
+       if use_intel_amx_backend(layer):
            return torch.ops.sgl_kernel.int8_scaled_mm_with_quant(
                x,
                layer.weight,
@@ -235,7 +236,7 @@ class W8A8Int8MoEMethod:
            assert (
                _is_cpu_amx_available
            ), "W8A8Int8MoEMethod on CPU requires that CPU has AMX support"
-           _process_weight_after_loading(layer, ["w13_weight", "w2_weight"])
+           _amx_process_weight_after_loading(layer, ["w13_weight", "w2_weight"])
            return

        layer.w13_weight = Parameter(layer.w13_weight, requires_grad=False)
@@ -284,7 +285,7 @@ class W8A8Int8MoEMethod:
            routed_scaling_factor=routed_scaling_factor,
        )

-       if getattr(layer, "use_intel_amx_backend", False):
+       if use_intel_amx_backend(layer):
            return torch.ops.sgl_kernel.fused_experts_cpu(
                x,
                layer.w13_weight,
...
@@ -13,6 +13,7 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
+from sglang.srt.layers.amx_utils import PackWeightMethod
 from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size
 from sglang.srt.layers.parameter import BasevLLMParameter
 from sglang.srt.layers.quantization.base_config import (
@@ -20,12 +21,7 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizeMethodBase,
     method_has_implemented_embedding,
 )
-from sglang.srt.utils import (
-    PackWeightMethod,
-    cpu_has_amx_support,
-    is_cpu,
-    set_weight_attrs,
-)
+from sglang.srt.utils import cpu_has_amx_support, is_cpu, set_weight_attrs

 DEFAULT_VOCAB_PADDING_SIZE = 64
...
@@ -36,6 +36,7 @@ from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_r
 from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
 from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
 from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.amx_utils import PackWeightMethod
 from sglang.srt.layers.communicator import (
     LayerCommunicator,
     LayerScatterModes,
@@ -91,7 +92,6 @@ from sglang.srt.utils import (
     BumpAllocator,
     DeepEPMode,
     LazyValue,
-    PackWeightMethod,
     add_prefix,
     bind_or_assign,
     cpu_has_amx_support,
@@ -103,6 +103,7 @@ from sglang.srt.utils import (
     is_hip,
     is_non_idle_and_non_empty,
     log_info_on_rank0,
+    use_intel_amx_backend,
 )

 _is_hip = is_hip()
@@ -224,7 +225,7 @@ class MoEGate(nn.Module):
            self.quant_method = PackWeightMethod(weight_names=["weight"])

    def forward(self, hidden_states):
-       if getattr(self, "use_intel_amx_backend", False):
+       if use_intel_amx_backend(self):
            return torch.ops.sgl_kernel.weight_packed_linear(
                hidden_states,
                self.weight,
@@ -437,8 +438,8 @@ class DeepseekV2MoE(nn.Module):
        return final_hidden_states

    def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
-       if hasattr(self, "shared_experts") and getattr(
-           self.shared_experts.gate_up_proj, "use_intel_amx_backend", False
+       if hasattr(self, "shared_experts") and use_intel_amx_backend(
+           self.shared_experts.gate_up_proj
        ):
            return self.forward_cpu(hidden_states)
@@ -464,9 +465,9 @@ class DeepseekV2MoE(nn.Module):
            hidden_states=hidden_states, router_logits=router_logits
        )

-       assert getattr(
-           self.shared_experts.gate_up_proj, "use_intel_amx_backend", False
-       ) == getattr(self.shared_experts.down_proj, "use_intel_amx_backend", False)
+       assert use_intel_amx_backend(
+           self.shared_experts.gate_up_proj
+       ) == use_intel_amx_backend(self.shared_experts.down_proj)

        # [Note] inplace should be False in fused_experts.
        # If inplace is True in fused_experts (self.experts), hidden_states will be changed after fused_experts
        # While hidden_states is still needed in shared_expert.
@@ -928,15 +929,23 @@ class DeepseekV2AttentionMLA(nn.Module):
        )

        self.weight_block_size = None
-       if self.qkv_proj_with_rope_is_fp8:
-           assert (
-               self.fused_qkv_a_proj_with_mqa.quant_method.quant_config.weight_block_size
-               == self.q_b_proj.quant_method.quant_config.weight_block_size
-           )
-           self.weight_block_size = (
-               self.fused_qkv_a_proj_with_mqa.quant_method.quant_config.weight_block_size
-           )
+       if self.qkv_proj_with_rope_is_fp8 and _is_cpu and _is_cpu_amx_available:
+           assert getattr(
+               self.fused_qkv_a_proj_with_mqa.quant_method, "block_quant", False
+           ) == getattr(self.q_b_proj.quant_method, "block_quant", False)
+           use_block_quant = getattr(
+               self.fused_qkv_a_proj_with_mqa.quant_method, "block_quant", False
+           )
+
+           if use_block_quant:
+               assert (
+                   self.fused_qkv_a_proj_with_mqa.quant_method.quant_config.weight_block_size
+                   == self.q_b_proj.quant_method.quant_config.weight_block_size
+               )
+               self.weight_block_size = (
+                   self.fused_qkv_a_proj_with_mqa.quant_method.quant_config.weight_block_size
+               )

    def dispatch_attn_forward_method(
        self, forward_batch: ForwardBatch
    ) -> AttnForwardMethod:
@@ -950,8 +959,8 @@ class DeepseekV2AttentionMLA(nn.Module):
            else:
                return AttnForwardMethod.MLA
        else:
-           if hasattr(self, "fused_qkv_a_proj_with_mqa") and getattr(
-               self, "use_intel_amx_backend", False
+           if hasattr(self, "fused_qkv_a_proj_with_mqa") and use_intel_amx_backend(
+               self
            ):
                return AttnForwardMethod.MLA_FUSED_ROPE_CPU
            else:
@@ -1426,8 +1435,8 @@ class DeepseekV2AttentionMLA(nn.Module):
        forward_batch: ForwardBatch,
        zero_allocator: BumpAllocator,
    ):
-       assert self.q_lora_rank is not None and getattr(
-           self, "use_intel_amx_backend", False
+       assert self.q_lora_rank is not None and use_intel_amx_backend(
+           self
        ), "forward_absorb_fused_mla_rope_cpu_prepare requires q_lora_rank is not None and use_intel_amx_backend"

        q_input, k_input, v_input = (
@@ -1546,8 +1555,8 @@ class DeepseekV2AttentionMLA(nn.Module):
    def forward_absorb_fused_mla_rope_cpu_core(
        self, q_input, k_input, v_input, forward_batch, zero_allocator
    ):
-       assert self.q_lora_rank is not None and getattr(
-           self, "use_intel_amx_backend", False
+       assert self.q_lora_rank is not None and use_intel_amx_backend(
+           self
        ), "forward_absorb_fused_mla_rope_cpu_core requires q_lora_rank is not None and use_intel_amx_backend"

        attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch)
...
@@ -2416,75 +2416,8 @@ def cpu_has_amx_support():
     return torch._C._cpu._is_amx_tile_supported() and is_intel_amx_backend_available


-def prepack_weight_if_needed(weight):
-    if weight.device != torch.device("cpu"):
-        return weight
-    if not cpu_has_amx_support():
-        return weight
-    return torch.ops.sgl_kernel.convert_weight_packed(weight)
-
-
-# TODO: currently gemm kernel has the below requirements:
-# OC % TILE_N == 0, where TILE_N = 16
-# IC % TILE_K == 0, where TILE_K = 32
-def dim_is_supported(weight):
-    return weight.size(0) % 16 == 0 and weight.size(1) % 32 == 0
-
-
-def _process_weight_after_loading(module, weight_names, transpose_dims=None) -> None:
-    # Pack weight for get better performance on CPU
-    devices = {getattr(module, weight_name).device for weight_name in weight_names}
-    assert len(devices) == 1, f"Expects all weights to be on the same device"
-    device = devices.pop()
-
-    if transpose_dims:
-        assert len(weight_names) == len(
-            transpose_dims
-        ), "len(weight_names) should be equal to len(transpose_dims)"
-
-    for i, weight_name in enumerate(weight_names):
-        weight_tensor = getattr(module, weight_name)
-
-        # We don't pack weight or use intel amx backend if any weight of this module has unsupported dim.
-        if not dim_is_supported(weight_tensor):
-            logger.warning(
-                f"Expects weight.size(0) % 16 == 0 and weight.size(1) % 32 == 0 "
-                f"but {weight_tensor.size(0)=} and {weight_tensor.size(1)=} in {module}. "
-                f"{module} won't use intel amx backend."
-            )
-            module.use_intel_amx_backend = False
-            return
-
-        if transpose_dims and transpose_dims[i]:
-            weight_tensor = weight_tensor.transpose(*transpose_dims[i])
-
-        packed_weight = torch.nn.Parameter(
-            prepack_weight_if_needed(weight_tensor),
-            requires_grad=False,
-        )
-        packed_weight.__dict__ = weight_tensor.__dict__
-        setattr(module, weight_name, packed_weight)
-
-    module.use_intel_amx_backend = (
-        device == torch.device("cpu") and cpu_has_amx_support()
-    )
-
-    if (
-        module.use_intel_amx_backend
-        and hasattr(module, "bias")
-        and module.bias is not None
-    ):
-        module.bias = torch.nn.Parameter(module.bias.data.float(), requires_grad=False)
-
-
-class PackWeightMethod:
-    def __init__(self, weight_names, transpose_dims=None):
-        self.weight_names = weight_names
-        self.transpose_dims = transpose_dims
-
-    def process_weights_after_loading(self, module) -> None:
-        _process_weight_after_loading(module, self.weight_names, self.transpose_dims)
+def use_intel_amx_backend(layer):
+    return getattr(layer, "use_intel_amx_backend", False)


 class LazyValue:
...
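One behavioral note on the relocated helper: the new dim_is_supported in amx_utils also handles 3-D MoE weights, deriving (OC, IC) from dims 1 and 2, whereas the removed utils.py version only checked 2-D weights. A quick sketch under assumed shapes (the shapes below are made up for illustration):

# Sketch of the new (OC, IC) derivation; shapes are hypothetical.
import torch

from sglang.srt.layers.amx_utils import dim_is_supported

# 2-D linear weight: OC = size(0), IC = size(1)
linear_weight = torch.empty(4096, 2048)
assert dim_is_supported(linear_weight)  # 4096 % 16 == 0 and 2048 % 32 == 0

# 3-D MoE weight (num_experts, OC, IC): OC = size(1), IC = size(2)
w13_weight = torch.empty(8, 1408, 2048)
assert dim_is_supported(w13_weight)  # 1408 % 16 == 0 and 2048 % 32 == 0

# A weight violating the gemm kernel requirement falls back to the unpacked path.
odd_weight = torch.empty(100, 200)
assert not dim_is_supported(odd_weight)  # 100 % 16 != 0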