sglang · Commit 9fcc9a80 (Unverified)

[CPU] refine CPU integration code (#7647)

Authored Jul 04, 2025 by Chunyuan WU; committed by GitHub on Jul 03, 2025
Parent: ac49dac0

Showing 9 changed files with 141 additions and 116 deletions:
python/sglang/srt/layers/amx_utils.py                   +86   -0
python/sglang/srt/layers/linear.py                       +4   -3
python/sglang/srt/layers/logits_processor.py             +2   -2
python/sglang/srt/layers/moe/fused_moe_triton/layer.py   +4   -6
python/sglang/srt/layers/quantization/fp8.py             +6   -5
python/sglang/srt/layers/quantization/w8a8_int8.py       +6   -5
python/sglang/srt/layers/vocab_parallel_embedding.py     +2   -6
python/sglang/srt/models/deepseek_v2.py                 +29  -20
python/sglang/srt/utils.py                               +2  -69
python/sglang/srt/layers/amx_utils.py (new file, 0 → 100644)

```python
import logging

import torch

from sglang.srt.utils import cpu_has_amx_support

logger = logging.getLogger(__name__)


def amx_process_weight_after_loading(weight):
    if weight.device != torch.device("cpu"):
        return weight
    if not cpu_has_amx_support():
        return weight

    return torch.ops.sgl_kernel.convert_weight_packed(weight)


# TODO: currently gemm kernel has the below requirements:
# OC % TILE_N == 0, where TILE_N = 16
# IC % TILE_K == 0, where TILE_K = 32
def dim_is_supported(weight):
    TILE_N = 16
    TILE_K = 32
    ndim = weight.ndim
    OC = weight.size(1) if ndim == 3 else weight.size(0)
    IC = weight.size(2) if ndim == 3 else weight.size(1)
    return OC % TILE_N == 0 and IC % TILE_K == 0


def _amx_process_weight_after_loading(module, weight_names, transpose_dims=None) -> None:
    # Pack weight for get better performance on CPU
    devices = {getattr(module, weight_name).device for weight_name in weight_names}
    assert len(devices) == 1, f"Expects all weights to be on the same device"
    device = devices.pop()

    if transpose_dims:
        assert len(weight_names) == len(
            transpose_dims
        ), "len(weight_names) should be equal to len(transpose_dims)"

    for i, weight_name in enumerate(weight_names):
        weight_tensor = getattr(module, weight_name)

        if transpose_dims and transpose_dims[i]:
            weight_tensor = weight_tensor.transpose(*transpose_dims[i])

        # We don't pack weight or use intel amx backend if any weight of this module has unsupported dim.
        if not dim_is_supported(weight_tensor):
            logger.warning(
                f"Unsupported dimension for prepacking for weight '{weight_name}' with shape {weight_tensor.shape} in {module}. "
                f"The derived (OC, IC) dimensions must be divisible by (16, 32). "
            )
            module.use_intel_amx_backend = False
            return

        packed_weight = torch.nn.Parameter(
            amx_process_weight_after_loading(weight_tensor),
            requires_grad=False,
        )
        packed_weight.__dict__ = weight_tensor.__dict__
        setattr(module, weight_name, packed_weight)

    module.use_intel_amx_backend = (
        device == torch.device("cpu") and cpu_has_amx_support()
    )

    if (
        module.use_intel_amx_backend
        and hasattr(module, "bias")
        and module.bias is not None
    ):
        module.bias = torch.nn.Parameter(module.bias.data.float(), requires_grad=False)


class PackWeightMethod:
    def __init__(self, weight_names, transpose_dims=None):
        self.weight_names = weight_names
        self.transpose_dims = transpose_dims

    def process_weights_after_loading(self, module) -> None:
        _amx_process_weight_after_loading(
            module, self.weight_names, self.transpose_dims
        )
```
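To make the packing contract above concrete, here is a small stand-alone sketch (illustrative only, not part of the commit; the `Gate` module and its sizes are made up) that attaches `PackWeightMethod` to a module whose 2D weight satisfies the OC % 16 / IC % 32 requirement and runs the post-load hook. On a CPU with AMX support and sgl-kernel installed, the weight is converted to the packed layout and `use_intel_amx_backend` becomes True; otherwise the weight is left unchanged and the flag ends up False.

```python
# Illustrative only -- not part of the commit. Assumes sglang and sgl-kernel
# are installed; the "Gate" module and its sizes are made-up stand-ins.
import torch

from sglang.srt.layers.amx_utils import PackWeightMethod, dim_is_supported


class Gate(torch.nn.Module):
    def __init__(self, hidden_size=64, num_experts=16):
        super().__init__()
        # 2D weight: OC = size(0) = 16, IC = size(1) = 64 -> divisible by (16, 32)
        self.weight = torch.nn.Parameter(
            torch.randn(num_experts, hidden_size), requires_grad=False
        )
        self.quant_method = PackWeightMethod(weight_names=["weight"])


gate = Gate()
assert dim_is_supported(gate.weight)

# Runs the same hook the model loader would invoke after weight loading.
gate.quant_method.process_weights_after_loading(gate)
print(gate.use_intel_amx_backend)  # True only on an AMX-capable CPU
```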
python/sglang/srt/layers/linear.py

```diff
@@ -17,6 +17,7 @@ from sglang.srt.distributed import (
     tensor_model_parallel_all_gather,
     tensor_model_parallel_all_reduce,
 )
+from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
 from sglang.srt.layers.parameter import (
     BasevLLMParameter,
     BlockQuantScaleParameter,
@@ -31,10 +32,10 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizeMethodBase,
 )
 from sglang.srt.utils import (
-    _process_weight_after_loading,
     cpu_has_amx_support,
     is_cpu,
     set_weight_attrs,
+    use_intel_amx_backend,
 )
 
 logger = logging.getLogger(__name__)
@@ -175,7 +176,7 @@ class UnquantizedLinearMethod(LinearMethodBase):
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         if _is_cpu and _is_cpu_amx_available:
-            _process_weight_after_loading(layer, ["weight"])
+            _amx_process_weight_after_loading(layer, ["weight"])
 
     def apply(
         self,
@@ -184,7 +185,7 @@ class UnquantizedLinearMethod(LinearMethodBase):
         bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        if getattr(layer, "use_intel_amx_backend", False):
+        if use_intel_amx_backend(layer):
             return torch.ops.sgl_kernel.weight_packed_linear(
                 x, layer.weight, bias, True  # is_vnni
             )
```
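The same call-site change repeats across the files below, so a brief stand-alone illustration of the new helper may help (illustrative only, not from the commit): `use_intel_amx_backend(layer)` wraps the previous `getattr(layer, "use_intel_amx_backend", False)` pattern behind a named function, so modules that never went through AMX weight packing simply report False.

```python
# Illustrative only -- assumes sglang is installed. Mirrors the helper added to
# python/sglang/srt/utils.py in this commit.
import torch

from sglang.srt.utils import use_intel_amx_backend

layer = torch.nn.Linear(64, 16)
print(use_intel_amx_backend(layer))  # False: the attribute was never set

layer.use_intel_amx_backend = True   # normally set by _amx_process_weight_after_loading
print(use_intel_amx_backend(layer))  # True
```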
python/sglang/srt/layers/logits_processor.py

```diff
@@ -42,7 +42,7 @@ from sglang.srt.model_executor.forward_batch_info import (
     ForwardBatch,
     ForwardMode,
 )
-from sglang.srt.utils import dump_to_file
+from sglang.srt.utils import dump_to_file, use_intel_amx_backend
 
 logger = logging.getLogger(__name__)
@@ -442,7 +442,7 @@ class LogitsProcessor(nn.Module):
             dp_gather_replicate(hidden_states, local_hidden_states, logits_metadata)
 
         if hasattr(lm_head, "weight"):
-            if getattr(lm_head, "use_intel_amx_backend", False):
+            if use_intel_amx_backend(lm_head):
                 logits = torch.ops.sgl_kernel.weight_packed_linear(
                     hidden_states.to(lm_head.weight.dtype),
                     lm_head.weight,
```
python/sglang/srt/layers/moe/fused_moe_triton/layer.py

```diff
@@ -12,6 +12,7 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
+from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
 from sglang.srt.layers.moe.fused_moe_native import moe_forward_native
 from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.base_config import (
@@ -19,12 +20,12 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizeMethodBase,
 )
 from sglang.srt.utils import (
-    _process_weight_after_loading,
     cpu_has_amx_support,
     get_bool_env_var,
     is_cpu,
     is_hip,
     set_weight_attrs,
+    use_intel_amx_backend,
 )
 
 if torch.cuda.is_available():
@@ -129,7 +130,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         # Pack weight for get better performance on CPU
         if _is_cpu and _is_cpu_amx_available:
-            _process_weight_after_loading(layer, ["w13_weight", "w2_weight"])
+            _amx_process_weight_after_loading(layer, ["w13_weight", "w2_weight"])
 
         return
@@ -264,10 +265,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
     ) -> torch.Tensor:
         assert activation == "silu", f"activation = {activation} is not supported."
 
-        if (
-            getattr(layer, "use_intel_amx_backend", False)
-            and not apply_router_weight_on_input
-        ):
+        if use_intel_amx_backend(layer) and not apply_router_weight_on_input:
             topk_weights, topk_ids = select_experts(
                 hidden_states=x,
                 router_logits=router_logits,
```
python/sglang/srt/layers/quantization/fp8.py

```diff
@@ -27,6 +27,7 @@ except ImportError:
 from sglang.srt.distributed import get_tensor_model_parallel_world_size
+from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
 from sglang.srt.layers.linear import (
     LinearBase,
     LinearMethodBase,
@@ -64,7 +65,6 @@ from sglang.srt.layers.quantization.utils import (
 )
 from sglang.srt.layers.utils import is_sm100_supported
 from sglang.srt.utils import (
-    _process_weight_after_loading,
     cpu_has_amx_support,
     get_bool_env_var,
     is_cpu,
@@ -74,6 +74,7 @@ from sglang.srt.utils import (
     log_info_on_rank0,
     print_warning_once,
     set_weight_attrs,
+    use_intel_amx_backend,
 )
 
 _is_hip = is_hip()
@@ -335,7 +336,7 @@ class Fp8LinearMethod(LinearMethodBase):
                 assert (
                     _is_cpu_amx_available
                 ), "Fp8LinearMethod on CPU requires that CPU has AMX support"
-                _process_weight_after_loading(layer, ["weight"])
+                _amx_process_weight_after_loading(layer, ["weight"])
                 return
             else:
                 weight, weight_scale = layer.weight.data, layer.weight_scale_inv.data
@@ -433,7 +434,7 @@ class Fp8LinearMethod(LinearMethodBase):
             )
 
         if self.block_quant:
-            if getattr(layer, "use_intel_amx_backend", False):
+            if use_intel_amx_backend(layer):
                 return torch.ops.sgl_kernel.fp8_scaled_mm_cpu(
                     x,
                     layer.weight,
@@ -769,7 +770,7 @@ class Fp8MoEMethod:
             assert (
                 _is_cpu_amx_available
             ), "Fp8MoEMethod on CPU requires that CPU has AMX support"
-            _process_weight_after_loading(layer, ["w13_weight", "w2_weight"])
+            _amx_process_weight_after_loading(layer, ["w13_weight", "w2_weight"])
             return
@@ -996,7 +997,7 @@ class Fp8MoEMethod:
             routed_scaling_factor=routed_scaling_factor,
         )
 
-        if getattr(layer, "use_intel_amx_backend", False):
+        if use_intel_amx_backend(layer):
             return torch.ops.sgl_kernel.fused_experts_cpu(
                 x,
                 layer.w13_weight,
```
python/sglang/srt/layers/quantization/w8a8_int8.py

```diff
@@ -4,6 +4,7 @@ import torch
 from torch.nn.parameter import Parameter
 
 from sglang.srt.distributed import get_tensor_model_parallel_world_size
+from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
 from sglang.srt.layers.linear import LinearMethodBase
 from sglang.srt.layers.parameter import ChannelQuantScaleParameter, ModelWeightParameter
 from sglang.srt.layers.quantization.base_config import (
@@ -12,11 +13,11 @@ from sglang.srt.layers.quantization.base_config import (
 )
 from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
 from sglang.srt.utils import (
-    _process_weight_after_loading,
     cpu_has_amx_support,
     is_cpu,
     is_cuda,
     set_weight_attrs,
+    use_intel_amx_backend,
 )
 
 _is_cuda = is_cuda()
@@ -84,7 +85,7 @@ class W8A8Int8LinearMethod(LinearMethodBase):
             assert (
                 _is_cpu_amx_available
             ), "W8A8Int8LinearMethod on CPU requires that CPU has AMX support"
-            _process_weight_after_loading(layer, ["weight"])
+            _amx_process_weight_after_loading(layer, ["weight"])
             return
 
         layer.weight = Parameter(layer.weight.t(), requires_grad=False)
@@ -127,7 +128,7 @@ class W8A8Int8LinearMethod(LinearMethodBase):
         x: torch.Tensor,
         bias: Optional[torch.Tensor] = None,
     ):
-        if getattr(layer, "use_intel_amx_backend", False):
+        if use_intel_amx_backend(layer):
             return torch.ops.sgl_kernel.int8_scaled_mm_with_quant(
                 x,
                 layer.weight,
@@ -235,7 +236,7 @@ class W8A8Int8MoEMethod:
             assert (
                 _is_cpu_amx_available
             ), "W8A8Int8MoEMethod on CPU requires that CPU has AMX support"
-            _process_weight_after_loading(layer, ["w13_weight", "w2_weight"])
+            _amx_process_weight_after_loading(layer, ["w13_weight", "w2_weight"])
             return
 
         layer.w13_weight = Parameter(layer.w13_weight, requires_grad=False)
@@ -284,7 +285,7 @@ class W8A8Int8MoEMethod:
             routed_scaling_factor=routed_scaling_factor,
         )
 
-        if getattr(layer, "use_intel_amx_backend", False):
+        if use_intel_amx_backend(layer):
             return torch.ops.sgl_kernel.fused_experts_cpu(
                 x,
                 layer.w13_weight,
```
python/sglang/srt/layers/vocab_parallel_embedding.py

```diff
@@ -13,6 +13,7 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
+from sglang.srt.layers.amx_utils import PackWeightMethod
 from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size
 from sglang.srt.layers.parameter import BasevLLMParameter
 from sglang.srt.layers.quantization.base_config import (
@@ -20,12 +21,7 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizeMethodBase,
     method_has_implemented_embedding,
 )
-from sglang.srt.utils import (
-    PackWeightMethod,
-    cpu_has_amx_support,
-    is_cpu,
-    set_weight_attrs,
-)
+from sglang.srt.utils import cpu_has_amx_support, is_cpu, set_weight_attrs
 
 DEFAULT_VOCAB_PADDING_SIZE = 64
```
python/sglang/srt/models/deepseek_v2.py

```diff
@@ -36,6 +36,7 @@ from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_r
 from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
 from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
 from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.amx_utils import PackWeightMethod
 from sglang.srt.layers.communicator import (
     LayerCommunicator,
     LayerScatterModes,
@@ -91,7 +92,6 @@ from sglang.srt.utils import (
     BumpAllocator,
     DeepEPMode,
     LazyValue,
-    PackWeightMethod,
     add_prefix,
     bind_or_assign,
     cpu_has_amx_support,
@@ -103,6 +103,7 @@ from sglang.srt.utils import (
     is_hip,
     is_non_idle_and_non_empty,
     log_info_on_rank0,
+    use_intel_amx_backend,
 )
 
 _is_hip = is_hip()
@@ -224,7 +225,7 @@ class MoEGate(nn.Module):
             self.quant_method = PackWeightMethod(weight_names=["weight"])
 
     def forward(self, hidden_states):
-        if getattr(self, "use_intel_amx_backend", False):
+        if use_intel_amx_backend(self):
             return torch.ops.sgl_kernel.weight_packed_linear(
                 hidden_states,
                 self.weight,
@@ -437,8 +438,8 @@ class DeepseekV2MoE(nn.Module):
         return final_hidden_states
 
     def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        if hasattr(self, "shared_experts") and getattr(
-            self.shared_experts.gate_up_proj, "use_intel_amx_backend", False
+        if hasattr(self, "shared_experts") and use_intel_amx_backend(
+            self.shared_experts.gate_up_proj
         ):
             return self.forward_cpu(hidden_states)
@@ -464,9 +465,9 @@ class DeepseekV2MoE(nn.Module):
             hidden_states=hidden_states, router_logits=router_logits
         )
-        assert getattr(
-            self.shared_experts.gate_up_proj, "use_intel_amx_backend", False
-        ) == getattr(self.shared_experts.down_proj, "use_intel_amx_backend", False)
+        assert use_intel_amx_backend(
+            self.shared_experts.gate_up_proj
+        ) == use_intel_amx_backend(self.shared_experts.down_proj)
         # [Note] inplace should be False in fused_experts.
         # If inplace is True in fused_experts (self.experts), hidden_states will be changed after fused_experts
         # While hidden_states is still needed in shared_expert.
@@ -928,15 +929,23 @@ class DeepseekV2AttentionMLA(nn.Module):
         )
 
         self.weight_block_size = None
-        if self.qkv_proj_with_rope_is_fp8:
-            assert (
-                self.fused_qkv_a_proj_with_mqa.quant_method.quant_config.weight_block_size
-                == self.q_b_proj.quant_method.quant_config.weight_block_size
-            )
-            self.weight_block_size = (
-                self.fused_qkv_a_proj_with_mqa.quant_method.quant_config.weight_block_size
-            )
+        if self.qkv_proj_with_rope_is_fp8 and _is_cpu and _is_cpu_amx_available:
+            assert getattr(
+                self.fused_qkv_a_proj_with_mqa.quant_method, "block_quant", False
+            ) == getattr(self.q_b_proj.quant_method, "block_quant", False)
+            use_block_quant = getattr(
+                self.fused_qkv_a_proj_with_mqa.quant_method, "block_quant", False
+            )
+
+            if use_block_quant:
+                assert (
+                    self.fused_qkv_a_proj_with_mqa.quant_method.quant_config.weight_block_size
+                    == self.q_b_proj.quant_method.quant_config.weight_block_size
+                )
+                self.weight_block_size = (
+                    self.fused_qkv_a_proj_with_mqa.quant_method.quant_config.weight_block_size
+                )
 
     def dispatch_attn_forward_method(
         self, forward_batch: ForwardBatch
     ) -> AttnForwardMethod:
@@ -950,8 +959,8 @@ class DeepseekV2AttentionMLA(nn.Module):
             else:
                 return AttnForwardMethod.MLA
         else:
-            if hasattr(self, "fused_qkv_a_proj_with_mqa") and getattr(
-                self, "use_intel_amx_backend", False
+            if hasattr(self, "fused_qkv_a_proj_with_mqa") and use_intel_amx_backend(
+                self
             ):
                 return AttnForwardMethod.MLA_FUSED_ROPE_CPU
             else:
@@ -1426,8 +1435,8 @@ class DeepseekV2AttentionMLA(nn.Module):
         forward_batch: ForwardBatch,
         zero_allocator: BumpAllocator,
     ):
-        assert self.q_lora_rank is not None and getattr(
-            self, "use_intel_amx_backend", False
+        assert self.q_lora_rank is not None and use_intel_amx_backend(
+            self
         ), "forward_absorb_fused_mla_rope_cpu_prepare requires q_lora_rank is not None and use_intel_amx_backend"
 
         q_input, k_input, v_input = (
@@ -1546,8 +1555,8 @@ class DeepseekV2AttentionMLA(nn.Module):
     def forward_absorb_fused_mla_rope_cpu_core(
         self, q_input, k_input, v_input, forward_batch, zero_allocator
     ):
-        assert self.q_lora_rank is not None and getattr(
-            self, "use_intel_amx_backend", False
+        assert self.q_lora_rank is not None and use_intel_amx_backend(
+            self
         ), "forward_absorb_fused_mla_rope_cpu_core requires q_lora_rank is not None and use_intel_amx_backend"
 
         attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch)
```
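The reworked guard in `DeepseekV2AttentionMLA.__init__` only consults `quant_config.weight_block_size` when both projections actually report `block_quant`, and only on an AMX-capable CPU. The snippet below is a stand-alone sketch of that decision logic (illustrative only, not code from the commit; the stand-in objects and the [128, 128] block size are hypothetical), showing how the `getattr`-based check avoids touching `quant_config` for non-block-quant methods.

```python
# Illustrative sketch of the guard logic above, using made-up stand-in objects
# instead of real sglang quant methods. The [128, 128] block size is hypothetical.
from types import SimpleNamespace


def derive_weight_block_size(fused_qkv_a_proj_with_mqa, q_b_proj):
    a = fused_qkv_a_proj_with_mqa.quant_method
    b = q_b_proj.quant_method
    # Both projections must agree on whether block quantization is used at all.
    assert getattr(a, "block_quant", False) == getattr(b, "block_quant", False)
    if getattr(a, "block_quant", False):
        # Only block-quant methods are expected to expose weight_block_size.
        assert a.quant_config.weight_block_size == b.quant_config.weight_block_size
        return a.quant_config.weight_block_size
    return None


block_fp8 = SimpleNamespace(
    quant_method=SimpleNamespace(
        block_quant=True,
        quant_config=SimpleNamespace(weight_block_size=[128, 128]),
    )
)
per_tensor = SimpleNamespace(quant_method=SimpleNamespace())

print(derive_weight_block_size(block_fp8, block_fp8))    # [128, 128]
print(derive_weight_block_size(per_tensor, per_tensor))  # None
```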
python/sglang/srt/utils.py

```diff
@@ -2416,75 +2416,8 @@ def cpu_has_amx_support():
     return torch._C._cpu._is_amx_tile_supported() and is_intel_amx_backend_available
 
 
-def prepack_weight_if_needed(weight):
-    if weight.device != torch.device("cpu"):
-        return weight
-    if not cpu_has_amx_support():
-        return weight
-
-    return torch.ops.sgl_kernel.convert_weight_packed(weight)
-
-
-# TODO: currently gemm kernel has the below requirements:
-# OC % TILE_N == 0, where TILE_N = 16
-# IC % TILE_K == 0, where TILE_K = 32
-def dim_is_supported(weight):
-    return weight.size(0) % 16 == 0 and weight.size(1) % 32 == 0
-
-
-def _process_weight_after_loading(module, weight_names, transpose_dims=None) -> None:
-    # Pack weight for get better performance on CPU
-    devices = {getattr(module, weight_name).device for weight_name in weight_names}
-    assert len(devices) == 1, f"Expects all weights to be on the same device"
-    device = devices.pop()
-
-    if transpose_dims:
-        assert len(weight_names) == len(
-            transpose_dims
-        ), "len(weight_names) should be equal to len(transpose_dims)"
-
-    for i, weight_name in enumerate(weight_names):
-        weight_tensor = getattr(module, weight_name)
-
-        # We don't pack weight or use intel amx backend if any weight of this module has unsupported dim.
-        if not dim_is_supported(weight_tensor):
-            logger.warning(
-                f"Expects weight.size(0) % 16 == 0 and weight.size(1) % 32 == 0 "
-                f"but {weight_tensor.size(0)=} and {weight_tensor.size(1)=} in {module}. "
-                f"{module} won't use intel amx backend."
-            )
-            module.use_intel_amx_backend = False
-            return
-
-        if transpose_dims and transpose_dims[i]:
-            weight_tensor = weight_tensor.transpose(*transpose_dims[i])
-
-        packed_weight = torch.nn.Parameter(
-            prepack_weight_if_needed(weight_tensor),
-            requires_grad=False,
-        )
-        packed_weight.__dict__ = weight_tensor.__dict__
-        setattr(module, weight_name, packed_weight)
-
-    module.use_intel_amx_backend = (
-        device == torch.device("cpu") and cpu_has_amx_support()
-    )
-
-    if (
-        module.use_intel_amx_backend
-        and hasattr(module, "bias")
-        and module.bias is not None
-    ):
-        module.bias = torch.nn.Parameter(module.bias.data.float(), requires_grad=False)
-
-
-class PackWeightMethod:
-    def __init__(self, weight_names, transpose_dims=None):
-        self.weight_names = weight_names
-        self.transpose_dims = transpose_dims
-
-    def process_weights_after_loading(self, module) -> None:
-        _process_weight_after_loading(module, self.weight_names, self.transpose_dims)
+def use_intel_amx_backend(layer):
+    return getattr(layer, "use_intel_amx_backend", False)
 
 
 class LazyValue:
```