Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
aa1e77a1
Unverified
Commit
aa1e77a1
authored
Jan 11, 2025
by
Li, Jiang
Committed by
GitHub
Jan 10, 2025
Browse files
[Hardware][CPU] Support MOE models on x86 CPU (#11831)
Signed-off-by:
jiang1.li
<
jiang1.li@intel.com
>
parent
5959564f
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
43 additions
and
4 deletions
+43
-4
docs/source/getting_started/installation/cpu-x86.md
docs/source/getting_started/installation/cpu-x86.md
+1
-1
tests/models/decoder_only/language/test_models.py
tests/models/decoder_only/language/test_models.py
+4
-0
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+38
-3
No files found.
docs/source/getting_started/installation/cpu-x86.md
View file @
aa1e77a1
...
...
@@ -5,7 +5,7 @@
vLLM initially supports basic model inference and serving on the x86 CPU platform, with data types FP32, FP16 and BF16. The vLLM CPU backend supports the following vLLM features:
-
Tensor Parallel
-
Model Quantization (
`INT8 W8A8, AWQ`
)
-
Model Quantization (
`INT8 W8A8, AWQ
, GPTQ
`
)
-
Chunked-prefill
-
Prefix-caching
-
FP8-E5M2 KV-Caching (TODO)
...
...
tests/models/decoder_only/language/test_models.py
View file @
aa1e77a1
...
...
@@ -48,6 +48,10 @@ from ...utils import check_logprobs_close
),
pytest
.
param
(
"stabilityai/stablelm-3b-4e1t"
),
# stablelm
pytest
.
param
(
"bigcode/starcoder2-3b"
),
# starcoder2
pytest
.
param
(
"ehristoforu/Falcon3-MoE-2x7B-Insruct"
,
# mixtral
marks
=
[
pytest
.
mark
.
cpu_model
],
)
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"half"
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
32
])
...
...
vllm/model_executor/layers/fused_moe/layer.py
View file @
aa1e77a1
...
...
@@ -13,6 +13,7 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig
,
QuantizeMethodBase
)
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.platforms
import
current_platform
from
vllm.platforms.interface
import
CpuArchEnum
if
current_platform
.
is_cuda_alike
():
from
.fused_moe
import
fused_experts
...
...
@@ -83,6 +84,20 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
layer
.
register_parameter
(
"w2_weight"
,
w2_weight
)
set_weight_attrs
(
w2_weight
,
extra_weight_attrs
)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
    """Prepare the fused MoE kernel for ``layer`` once its weights are loaded.

    The parent implementation always runs first. On a CPU platform with an
    x86 architecture, an IPEX ``GatedMLPMOE`` module is built from
    ``layer.w13_weight`` and ``layer.w2_weight`` (with ``use_prepack=True``)
    and stored on ``layer.ipex_fusion``. On non-CPU platforms nothing else
    is done.

    Args:
        layer: Module that owns ``w13_weight`` and ``w2_weight``.

    Raises:
        NotImplementedError: If the platform is a CPU but not x86.
    """
    super().process_weights_after_loading(layer)

    # Only the CPU platform needs extra preparation here.
    if not current_platform.is_cpu():
        return

    is_x86 = current_platform.get_cpu_architecture() == CpuArchEnum.X86
    if not is_x86:
        raise NotImplementedError("CPU MOE only supports x86 arch.")

    # Imported lazily so the dependency is only needed on the x86 CPU path.
    import intel_extension_for_pytorch as ipex

    fused_moe = ipex.llm.modules.GatedMLPMOE(
        layer.w13_weight,
        layer.w2_weight,
        use_prepack=True,
    )
    layer.ipex_fusion = fused_moe
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
...
...
@@ -142,9 +157,29 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
topk_ids
=
topk_ids
,
inplace
=
True
)
def forward_cpu(self, *args, **kwargs):
    """Reject every call: this implementation has no CPU path.

    Raises:
        NotImplementedError: Always, regardless of the arguments passed.
    """
    message = "The CPU backend currently does not support MoE."
    raise NotImplementedError(message)
def forward_cpu(
    self,
    layer: torch.nn.Module,
    x: torch.Tensor,
    use_grouped_topk: bool,
    top_k: int,
    router_logits: torch.Tensor,
    renormalize: bool,
    topk_group: Optional[int] = None,
    num_expert_group: Optional[int] = None,
    custom_routing_function: Optional[Callable] = None,
    **kwargs,
):
    """Run the MoE layer on CPU by delegating to ``layer.ipex_fusion``.

    Args:
        layer: Module carrying the ``ipex_fusion`` callable.
        x: Input tensor, forwarded unchanged.
        use_grouped_topk: Forwarded unchanged to ``ipex_fusion``.
        top_k: Forwarded unchanged to ``ipex_fusion``.
        router_logits: Forwarded unchanged to ``ipex_fusion``.
        renormalize: Forwarded unchanged to ``ipex_fusion``.
        topk_group: Forwarded unchanged to ``ipex_fusion``.
        num_expert_group: Forwarded unchanged to ``ipex_fusion``.
        custom_routing_function: Must be ``None``; custom routing is not
            passed on to ``ipex_fusion``.
        **kwargs: Accepted for signature compatibility and ignored.

    Returns:
        Whatever ``layer.ipex_fusion`` returns.

    Raises:
        NotImplementedError: If ``custom_routing_function`` is not ``None``.
    """
    # An explicit raise rather than ``assert``: asserts are stripped under
    # ``python -O``, which would silently drop the caller's routing function.
    if custom_routing_function is not None:
        raise NotImplementedError(
            "The CPU backend does not support custom_routing_function "
            "for MoE.")

    return layer.ipex_fusion(
        x,
        use_grouped_topk,
        top_k,
        router_logits,
        renormalize,
        topk_group,
        num_expert_group,
    )
def
forward_tpu
(
self
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment