Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
32176fee
Unverified
Commit
32176fee
authored
Oct 27, 2024
by
youkaichao
Committed by
GitHub
Oct 27, 2024
Browse files
[torch.compile] support moe models (#9632)
Signed-off-by:
youkaichao
<
youkaichao@gmail.com
>
parent
4e2d95e3
Changes
12
Show whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
217 additions
and
78 deletions
+217
-78
benchmarks/kernels/benchmark_moe.py
benchmarks/kernels/benchmark_moe.py
+17
-16
tests/compile/test_basic_correctness.py
tests/compile/test_basic_correctness.py
+2
-2
tests/kernels/test_awq_marlin.py
tests/kernels/test_awq_marlin.py
+10
-11
tests/kernels/test_moe.py
tests/kernels/test_moe.py
+3
-4
vllm/model_executor/layers/fused_moe/__init__.py
vllm/model_executor/layers/fused_moe/__init__.py
+24
-4
vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
+42
-9
vllm/model_executor/layers/fused_moe/fused_moe.py
vllm/model_executor/layers/fused_moe/fused_moe.py
+93
-7
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+18
-11
vllm/model_executor/layers/quantization/awq_marlin.py
vllm/model_executor/layers/quantization/awq_marlin.py
+2
-5
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
...quantization/compressed_tensors/compressed_tensors_moe.py
+2
-5
vllm/model_executor/layers/quantization/gptq_marlin.py
vllm/model_executor/layers/quantization/gptq_marlin.py
+2
-4
vllm/model_executor/models/granitemoe.py
vllm/model_executor/models/granitemoe.py
+2
-0
No files found.
benchmarks/kernels/benchmark_moe.py
View file @
32176fee
...
...
@@ -88,6 +88,8 @@ def benchmark_config(
input_gating
.
copy_
(
gating_output
[
i
])
def
run
():
from
vllm.model_executor.layers.fused_moe
import
override_config
with
override_config
(
config
):
fused_moe
(
x
,
w1
,
...
...
@@ -96,7 +98,6 @@ def benchmark_config(
topk
,
renormalize
=
True
,
inplace
=
True
,
override_config
=
config
,
use_fp8_w8a8
=
use_fp8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
w1_scale
=
w1_scale
,
...
...
tests/compile/test_basic_correctness.py
View file @
32176fee
...
...
@@ -13,11 +13,11 @@ from ..utils import compare_all_settings
@
pytest
.
mark
.
parametrize
(
"model, model_args, pp_size, tp_size, attn_backend, method, fullgraph"
,
[
(
"meta-llama/Llama-3.2-1B"
,
[],
2
,
2
,
"FLASH
_ATTN
"
,
"generate"
,
True
),
(
"meta-llama/Llama-3.2-1B"
,
[],
2
,
2
,
"FLASH
INFER
"
,
"generate"
,
True
),
(
"nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Dyn-Per-Token-2048-Samples"
,
[
"--quantization"
,
"compressed-tensors"
],
1
,
1
,
"FLASH_ATTN"
,
"generate"
,
True
),
(
"
google/gemma-2-2b-it
"
,
[],
1
,
2
,
"FLASH
INFER
"
,
"generate"
,
True
),
(
"
ibm/PowerMoE-3b
"
,
[],
1
,
2
,
"FLASH
_ATTN
"
,
"generate"
,
True
),
# TODO: add multi-modality test for llava
(
"llava-hf/llava-1.5-7b-hf"
,
[],
2
,
1
,
"FLASHINFER"
,
"generate"
,
False
)
])
...
...
tests/kernels/test_awq_marlin.py
View file @
32176fee
...
...
@@ -5,11 +5,10 @@ Run `pytest tests/kernels/test_awq_marlin.py`.
import
pytest
import
torch
import
vllm.model_executor.layers.fused_moe
# noqa
from
tests.kernels.utils
import
(
compute_max_diff
,
stack_and_dev
,
torch_moe
,
torch_moe_single
)
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.fused_moe.fused_marlin_moe
import
(
fused_marlin_moe
,
single_marlin_moe
)
from
vllm.model_executor.layers.fused_moe.fused_moe
import
fused_topk
from
vllm.model_executor.layers.quantization.utils.marlin_utils_test
import
(
awq_marlin_quantize
)
...
...
@@ -81,7 +80,7 @@ def test_fused_marlin_moe_awq(
score
=
torch
.
randn
((
m
,
e
),
device
=
"cuda"
,
dtype
=
dtype
)
topk_weights
,
topk_ids
=
fused_topk
(
a
,
score
,
topk
,
False
)
marlin_output
=
fused_marlin_moe
(
marlin_output
=
torch
.
ops
.
vllm
.
fused_marlin_moe
(
a
,
qweight1
,
qweight2
,
...
...
@@ -150,7 +149,7 @@ def test_single_marlin_moe_multiply_awq(
score
=
torch
.
randn
((
m
,
e
),
device
=
"cuda"
,
dtype
=
dtype
)
marlin_output
=
single_marlin_moe
(
a
,
marlin_output
=
torch
.
ops
.
vllm
.
single_marlin_moe
(
a
,
qweight
,
scales
,
score
,
...
...
tests/kernels/test_moe.py
View file @
32176fee
...
...
@@ -7,12 +7,11 @@ import torch
from
transformers
import
MixtralConfig
from
transformers.models.mixtral.modeling_mixtral
import
MixtralSparseMoeBlock
import
vllm.model_executor.layers.fused_moe
# noqa
from
tests.kernels.utils
import
(
compute_max_diff
,
opcheck
,
stack_and_dev
,
torch_moe
,
torch_moe_single
)
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.fused_moe
import
fused_moe
from
vllm.model_executor.layers.fused_moe.fused_marlin_moe
import
(
fused_marlin_moe
,
single_marlin_moe
)
from
vllm.model_executor.layers.fused_moe.fused_moe
import
(
fused_topk
,
moe_align_block_size
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils_test
import
(
...
...
@@ -193,7 +192,7 @@ def test_fused_marlin_moe(
topk
,
renormalize
=
False
,
)
marlin_output
=
fused_marlin_moe
(
marlin_output
=
torch
.
ops
.
vllm
.
fused_marlin_moe
(
a
,
qweight1
,
qweight2
,
...
...
@@ -309,7 +308,7 @@ def test_single_marlin_moe_multiply(
sort_indices
=
stack_and_dev
(
sort_indices_l
)
score
=
torch
.
randn
((
m
,
e
),
device
=
"cuda"
,
dtype
=
dtype
)
marlin_output
=
single_marlin_moe
(
marlin_output
=
torch
.
ops
.
vllm
.
single_marlin_moe
(
a
,
qweight
,
scales
,
...
...
vllm/model_executor/layers/fused_moe/__init__.py
View file @
32176fee
from
contextlib
import
contextmanager
from
typing
import
Any
,
Dict
,
Optional
from
vllm.model_executor.layers.fused_moe.layer
import
(
FusedMoE
,
FusedMoEMethodBase
,
FusedMoeWeightScaleSupported
)
from
vllm.triton_utils
import
HAS_TRITON
# Module-level override for the fused-MoE kernel configuration.
# ``None`` means no override is installed. This is a plain module global
# (not thread-local); it is written by override_config() and read through
# get_config().
_config: Optional[Dict[str, Any]] = None


@contextmanager
def override_config(config):
    """Temporarily install ``config`` as the module-level kernel config.

    The previously installed value is saved on entry and put back on exit,
    so nested uses unwind in order.

    Args:
        config: The value to expose through ``get_config()`` while the
            ``with`` body runs.

    Yields:
        None. The override is observed via ``get_config()``.
    """
    global _config
    old_config = _config
    _config = config
    try:
        yield
    finally:
        # An exception raised inside the ``with`` body is re-raised at the
        # ``yield``; without ``finally`` the override would stay installed
        # and leak into every later get_config() call.
        _config = old_config


def get_config() -> Optional[Dict[str, Any]]:
    """Return the currently installed override config, or ``None``."""
    return _config
__all__
=
[
"FusedMoE"
,
"FusedMoEMethodBase"
,
"FusedMoeWeightScaleSupported"
,
"override_config"
,
"get_config"
,
]
if
HAS_TRITON
:
from
vllm.model_executor.layers.fused_moe.fused_marlin_moe
import
(
fused_marlin_moe
,
single_marlin_moe
)
# import to register the custom ops
import
vllm.model_executor.layers.fused_moe.fused_marlin_moe
# noqa
import
vllm.model_executor.layers.fused_moe.fused_moe
# noqa
from
vllm.model_executor.layers.fused_moe.fused_moe
import
(
fused_experts
,
fused_moe
,
fused_topk
,
get_config_file_name
,
grouped_topk
)
__all__
+=
[
"fused_marlin_moe"
,
"single_marlin_moe"
,
"fused_moe"
,
"fused_topk"
,
"fused_experts"
,
...
...
vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
View file @
32176fee
"""Fused MoE utilities for GPTQ."""
import
functools
from
typing
import
Any
,
Dict
,
Optional
from
typing
import
Optional
import
torch
...
...
@@ -18,6 +18,7 @@ def get_scalar_type(num_bits: int, has_zp: bool):
return
scalar_types
.
uint4b8
if
num_bits
==
4
else
scalar_types
.
uint8b128
@
torch
.
library
.
custom_op
(
"vllm::single_marlin_moe"
,
mutates_args
=
[])
def
single_marlin_moe
(
hidden_states
:
torch
.
Tensor
,
w
:
torch
.
Tensor
,
...
...
@@ -28,7 +29,6 @@ def single_marlin_moe(
g_idx
:
Optional
[
torch
.
Tensor
]
=
None
,
sort_indices
:
Optional
[
torch
.
Tensor
]
=
None
,
w_zeros
:
Optional
[
torch
.
Tensor
]
=
None
,
override_config
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
num_bits
:
int
=
8
,
is_k_full
:
bool
=
True
,
)
->
torch
.
Tensor
:
...
...
@@ -49,8 +49,6 @@ def single_marlin_moe(
- topk (int): The number of top-k experts to select.
- renormalize (bool): If True, renormalize the top-k weights to sum to 1.
- w_zeros (Optional[torch.Tensor]): Optional zero points to be used for w.
- override_config (Optional[Dict[str, Any]]): Optional override
for the kernel configuration.
- num_bits (int): The number of bits in expert weights quantization.
Returns:
...
...
@@ -79,7 +77,6 @@ def single_marlin_moe(
w
.
shape
,
topk_ids
.
shape
[
1
],
None
,
override_config
=
override_config
,
is_marlin
=
True
)
config
=
get_config_func
(
M
)
...
...
@@ -122,6 +119,24 @@ def single_marlin_moe(
return
torch
.
sum
(
intermediate_cache
.
view
(
*
intermediate_cache
.
shape
),
dim
=
1
)
# Fake implementation registered on the ``single_marlin_moe`` custom op.
# It runs no kernel: it only hands back an uninitialised tensor built from
# ``hidden_states``. The parameter list mirrors the one visible on the op.
# NOTE(review): this presumes the real op's output has the same shape and
# dtype as ``hidden_states`` -- confirm against the real implementation.
@single_marlin_moe.register_fake
def _(
    hidden_states: torch.Tensor,
    w: torch.Tensor,
    scales: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    g_idx: Optional[torch.Tensor] = None,
    sort_indices: Optional[torch.Tensor] = None,
    w_zeros: Optional[torch.Tensor] = None,
    num_bits: int = 8,
    is_k_full: bool = True,
) -> torch.Tensor:
    # Every argument other than ``hidden_states`` is ignored here.
    return torch.empty_like(hidden_states)
@
torch
.
library
.
custom_op
(
"vllm::fused_marlin_moe"
,
mutates_args
=
[])
def
fused_marlin_moe
(
hidden_states
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
...
...
@@ -137,7 +152,6 @@ def fused_marlin_moe(
sort_indices2
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_zeros
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_zeros
:
Optional
[
torch
.
Tensor
]
=
None
,
override_config
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
num_bits
:
int
=
8
,
is_k_full
:
bool
=
True
,
)
->
torch
.
Tensor
:
...
...
@@ -161,8 +175,6 @@ def fused_marlin_moe(
permutation.
- topk_weights (torch.Tensor): Top-k weights.
- topk_ids (torch.Tensor): Indices of top-k elements.
- override_config (Optional[Dict[str, Any]]): Optional override
for the kernel configuration.
- w1_zeros (Optional[torch.Tensor]): Optional zero points to be used for w1.
- w2_zeros (Optional[torch.Tensor]): Optional zero points to be used for w2.
- num_bits (int): The number of bits in expert weights quantization.
...
...
@@ -209,7 +221,6 @@ def fused_marlin_moe(
w2
.
shape
,
topk_ids
.
shape
[
1
],
None
,
override_config
=
override_config
,
is_marlin
=
True
,
)
config
=
get_config_func
(
M
)
...
...
@@ -311,3 +322,25 @@ def fused_marlin_moe(
return
torch
.
sum
(
intermediate_cache3
.
view
(
*
intermediate_cache3
.
shape
),
dim
=
1
)
# Fake implementation registered on the ``fused_marlin_moe`` custom op.
# It runs no kernel: it only hands back an uninitialised tensor built from
# ``hidden_states``.
# NOTE(review): this presumes the real op's output has the same shape and
# dtype as ``hidden_states`` -- confirm against the real implementation.
@fused_marlin_moe.register_fake
def _(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    gating_output: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    g_idx1: Optional[torch.Tensor] = None,
    g_idx2: Optional[torch.Tensor] = None,
    sort_indices1: Optional[torch.Tensor] = None,
    sort_indices2: Optional[torch.Tensor] = None,
    w1_zeros: Optional[torch.Tensor] = None,
    w2_zeros: Optional[torch.Tensor] = None,
    num_bits: int = 8,
    is_k_full: bool = True,
) -> torch.Tensor:
    # Every argument other than ``hidden_states`` is ignored here.
    return torch.empty_like(hidden_states)
vllm/model_executor/layers/fused_moe/fused_moe.py
View file @
32176fee
...
...
@@ -358,9 +358,10 @@ def try_get_optimal_moe_config(
top_k
:
int
,
dtype
:
Optional
[
str
],
M
:
int
,
override_config
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
is_marlin
:
bool
=
False
,
):
from
vllm.model_executor.layers.fused_moe
import
get_config
override_config
=
get_config
()
if
override_config
:
config
=
override_config
else
:
...
...
@@ -465,13 +466,103 @@ def get_config_dtype_str(dtype: torch.dtype,
return
None
# Registered as the custom op ``vllm::inplace_fused_experts``;
# ``mutates_args`` declares ``hidden_states`` as the argument the op
# writes to.
@torch.library.custom_op("vllm::inplace_fused_experts",
                         mutates_args=["hidden_states"])
def inplace_fused_experts(
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        use_fp8_w8a8: bool = False,
        use_int8_w8a16: bool = False,
        w1_scale: Optional[torch.Tensor] = None,
        w2_scale: Optional[torch.Tensor] = None,
        a1_scale: Optional[torch.Tensor] = None,
        a2_scale: Optional[torch.Tensor] = None) -> None:
    """In-place variant of the fused-experts op.

    Forwards all arguments positionally to ``fused_experts_impl`` with its
    ``inplace`` slot (the sixth positional argument) set to ``True``. The
    callee's return value is discarded and this op returns ``None``.
    """
    fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True,
                       use_fp8_w8a8, use_int8_w8a16, w1_scale, w2_scale,
                       a1_scale, a2_scale)
# Fake implementation registered on ``inplace_fused_experts``. The op
# returns ``None``, so there is nothing to allocate and the body is a no-op.
@inplace_fused_experts.register_fake
def _(hidden_states: torch.Tensor,
      w1: torch.Tensor,
      w2: torch.Tensor,
      topk_weights: torch.Tensor,
      topk_ids: torch.Tensor,
      use_fp8_w8a8: bool = False,
      use_int8_w8a16: bool = False,
      w1_scale: Optional[torch.Tensor] = None,
      w2_scale: Optional[torch.Tensor] = None,
      a1_scale: Optional[torch.Tensor] = None,
      a2_scale: Optional[torch.Tensor] = None) -> None:
    pass
# Registered as the custom op ``vllm::outplace_fused_experts``; the empty
# ``mutates_args`` list declares that no argument is written to.
@torch.library.custom_op("vllm::outplace_fused_experts", mutates_args=[])
def outplace_fused_experts(
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        use_fp8_w8a8: bool = False,
        use_int8_w8a16: bool = False,
        w1_scale: Optional[torch.Tensor] = None,
        w2_scale: Optional[torch.Tensor] = None,
        a1_scale: Optional[torch.Tensor] = None,
        a2_scale: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Out-of-place variant of the fused-experts op.

    Forwards all arguments positionally to ``fused_experts_impl`` with its
    ``inplace`` slot (the sixth positional argument) set to ``False`` and
    returns whatever that call returns.
    """
    return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids,
                              False, use_fp8_w8a8, use_int8_w8a16, w1_scale,
                              w2_scale, a1_scale, a2_scale)
# Fake implementation registered on ``outplace_fused_experts``. It runs no
# kernel: it only hands back an uninitialised tensor built from
# ``hidden_states``.
# NOTE(review): this presumes the real op's output has the same shape and
# dtype as ``hidden_states`` -- confirm against fused_experts_impl.
@outplace_fused_experts.register_fake
def _(hidden_states: torch.Tensor,
      w1: torch.Tensor,
      w2: torch.Tensor,
      topk_weights: torch.Tensor,
      topk_ids: torch.Tensor,
      use_fp8_w8a8: bool = False,
      use_int8_w8a16: bool = False,
      w1_scale: Optional[torch.Tensor] = None,
      w2_scale: Optional[torch.Tensor] = None,
      a1_scale: Optional[torch.Tensor] = None,
      a2_scale: Optional[torch.Tensor] = None) -> torch.Tensor:
    # Every argument other than ``hidden_states`` is ignored here.
    return torch.empty_like(hidden_states)
def fused_experts(hidden_states: torch.Tensor,
                  w1: torch.Tensor,
                  w2: torch.Tensor,
                  topk_weights: torch.Tensor,
                  topk_ids: torch.Tensor,
                  inplace: bool = False,
                  override_config: Optional[Dict[str, Any]] = None,
                  use_fp8_w8a8: bool = False,
                  use_int8_w8a16: bool = False,
                  w1_scale: Optional[torch.Tensor] = None,
                  w2_scale: Optional[torch.Tensor] = None,
                  a1_scale: Optional[torch.Tensor] = None,
                  a2_scale: Optional[torch.Tensor] = None):
    """Dispatch to the in-place or out-of-place fused-experts custom op.

    Args:
        inplace: Selects which registered op is called (see Returns).
        override_config: Accepted but never referenced in this body; it is
            not forwarded to either op.
            NOTE(review): the kernel config override appears to be read
            through ``get_config()`` elsewhere -- confirm, and confirm
            whether this parameter is still meant to exist.

    The remaining arguments are forwarded positionally, unchanged, to the
    selected op.

    Returns:
        When ``inplace`` is true, ``hidden_states`` itself (after calling
        ``torch.ops.vllm.inplace_fused_experts``, which returns ``None``);
        otherwise the result of ``torch.ops.vllm.outplace_fused_experts``.
    """
    if inplace:
        torch.ops.vllm.inplace_fused_experts(hidden_states, w1, w2,
                                             topk_weights, topk_ids,
                                             use_fp8_w8a8, use_int8_w8a16,
                                             w1_scale, w2_scale, a1_scale,
                                             a2_scale)
        # The in-place op has no return value; hand the caller the tensor
        # that was passed in.
        return hidden_states
    else:
        return torch.ops.vllm.outplace_fused_experts(
            hidden_states, w1, w2, topk_weights, topk_ids, use_fp8_w8a8,
            use_int8_w8a16, w1_scale, w2_scale, a1_scale, a2_scale)
def
fused_experts_impl
(
hidden_states
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
inplace
:
bool
=
False
,
use_fp8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
...
...
@@ -504,7 +595,6 @@ def fused_experts(hidden_states: torch.Tensor,
w2
.
shape
,
topk_ids
.
shape
[
1
],
config_dtype
,
override_config
=
override_config
,
)
config
=
get_config_func
(
M
)
...
...
@@ -602,7 +692,6 @@ def fused_moe(
topk
:
int
,
renormalize
:
bool
,
inplace
:
bool
=
False
,
override_config
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
use_grouped_topk
:
bool
=
False
,
num_expert_group
:
Optional
[
int
]
=
None
,
topk_group
:
Optional
[
int
]
=
None
,
...
...
@@ -628,8 +717,6 @@ def fused_moe(
- renormalize (bool): If True, renormalize the top-k weights to sum to 1.
- inplace (bool): If True, perform the operation in-place.
Defaults to False.
- override_config (Optional[Dict[str, Any]]): Optional override
for the kernel configuration.
- num_expert_group: Optional[int]: additional parameter for grouped_topk
- topk_group: Optional[int]: additional parameter for grouped_topk
- use_grouped_topk: If True, use grouped_topk instead of fused_topk
...
...
@@ -667,7 +754,6 @@ def fused_moe(
topk_weights
,
topk_ids
,
inplace
=
inplace
,
override_config
=
override_config
,
use_fp8_w8a8
=
use_fp8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
w1_scale
=
w1_scale
,
...
...
vllm/model_executor/layers/fused_moe/layer.py
View file @
32176fee
...
...
@@ -12,7 +12,16 @@ from vllm.model_executor.custom_op import CustomOp
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
,
QuantizeMethodBase
)
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.platforms
import
current_platform
if
current_platform
.
is_cuda_alike
():
from
.fused_moe
import
fused_experts
else
:
fused_experts
=
None
# type: ignore
if
current_platform
.
is_tpu
():
from
.moe_pallas
import
fused_moe
as
fused_moe_pallas
else
:
fused_moe_pallas
=
None
# type: ignore
logger
=
init_logger
(
__name__
)
...
...
@@ -96,9 +105,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
num_expert_group
:
Optional
[
int
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
)
->
torch
.
Tensor
:
from
vllm.model_executor.layers.fused_moe.fused_moe
import
(
fused_experts
)
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
hidden_states
=
x
,
router_logits
=
router_logits
,
...
...
@@ -132,18 +138,19 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
num_expert_group
:
Optional
[
int
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
)
->
torch
.
Tensor
:
from
vllm.model_executor.layers.fused_moe.moe_pallas
import
fused_moe
assert
not
use_grouped_topk
assert
num_expert_group
is
None
assert
topk_group
is
None
assert
custom_routing_function
is
None
return
fused_moe
(
hidden_states
=
x
,
return
fused_moe
_pallas
(
hidden_states
=
x
,
w1
=
layer
.
w13_weight
,
w2
=
layer
.
w2_weight
,
topk
=
top_k
,
gating_output
=
router_logits
,
renormalize
=
renormalize
)
forward_native
=
forward_cuda
class
FusedMoE
(
torch
.
nn
.
Module
):
"""FusedMoE layer for MoE models.
...
...
vllm/model_executor/layers/quantization/awq_marlin.py
View file @
32176fee
...
...
@@ -3,6 +3,7 @@ from typing import Any, Callable, Dict, List, Optional
import
torch
from
torch.nn
import
Parameter
import
vllm.model_executor.layers.fused_moe
# noqa
from
vllm
import
_custom_ops
as
ops
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.fused_moe.layer
import
(
...
...
@@ -435,10 +436,6 @@ class AWQMoEMethod(FusedMoEMethodBase):
topk_group
:
Optional
[
int
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
)
->
torch
.
Tensor
:
from
vllm.model_executor.layers.fused_moe.fused_marlin_moe
import
(
fused_marlin_moe
)
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
hidden_states
=
x
,
router_logits
=
router_logits
,
...
...
@@ -449,7 +446,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
num_expert_group
=
num_expert_group
,
custom_routing_function
=
custom_routing_function
)
return
fused_marlin_moe
(
return
torch
.
ops
.
vllm
.
fused_marlin_moe
(
x
,
layer
.
w13_qweight
,
layer
.
w2_qweight
,
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
View file @
32176fee
...
...
@@ -6,6 +6,7 @@ import torch
from
compressed_tensors
import
CompressionFormat
from
compressed_tensors.quantization
import
QuantizationStrategy
import
vllm.model_executor.layers.fused_moe
# noqa
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.fused_moe
import
(
FusedMoE
,
FusedMoEMethodBase
,
FusedMoeWeightScaleSupported
)
...
...
@@ -481,10 +482,6 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
topk_group
:
Optional
[
int
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
)
->
torch
.
Tensor
:
from
vllm.model_executor.layers.fused_moe.fused_marlin_moe
import
(
fused_marlin_moe
)
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
hidden_states
=
x
,
router_logits
=
router_logits
,
...
...
@@ -495,7 +492,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
num_expert_group
=
num_expert_group
,
custom_routing_function
=
custom_routing_function
)
return
fused_marlin_moe
(
return
torch
.
ops
.
vllm
.
fused_marlin_moe
(
x
,
layer
.
w13_weight_packed
,
layer
.
w2_weight_packed
,
...
...
vllm/model_executor/layers/quantization/gptq_marlin.py
View file @
32176fee
...
...
@@ -2,6 +2,7 @@ from typing import Any, Callable, Dict, List, Optional, Set, Union
import
torch
import
vllm.model_executor.layers.fused_moe
# noqa
from
vllm
import
_custom_ops
as
ops
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.fused_moe.layer
import
(
...
...
@@ -536,9 +537,6 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
topk_group
:
Optional
[
int
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
)
->
torch
.
Tensor
:
from
vllm.model_executor.layers.fused_moe.fused_marlin_moe
import
(
fused_marlin_moe
)
# The input must currently be float16
orig_dtype
=
x
.
dtype
x
=
x
.
half
()
...
...
@@ -553,7 +551,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
num_expert_group
=
num_expert_group
,
custom_routing_function
=
None
)
return
fused_marlin_moe
(
return
torch
.
ops
.
vllm
.
fused_marlin_moe
(
x
,
layer
.
w13_qweight
,
layer
.
w2_qweight
,
...
...
vllm/model_executor/models/granitemoe.py
View file @
32176fee
...
...
@@ -28,6 +28,7 @@ from torch import nn
from
transformers.models.granitemoe
import
GraniteMoeConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
LoRAConfig
from
vllm.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
...
...
@@ -244,6 +245,7 @@ class GraniteMoeDecoderLayer(nn.Module):
return
hidden_states
@
support_torch_compile
class
GraniteMoeModel
(
nn
.
Module
):
def
__init__
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment