Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
2339d59f
Unverified
Commit
2339d59f
authored
Dec 27, 2024
by
Robert Shaw
Committed by
GitHub
Dec 26, 2024
Browse files
[BugFix] Fix quantization for all other methods (#11547)
parent
1b875a0e
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
52 additions
and
22 deletions
+52
-22
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+15
-4
vllm/model_executor/layers/quantization/awq_marlin.py
vllm/model_executor/layers/quantization/awq_marlin.py
+7
-3
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
...quantization/compressed_tensors/compressed_tensors_moe.py
+15
-7
vllm/model_executor/layers/quantization/experts_int8.py
vllm/model_executor/layers/quantization/experts_int8.py
+7
-3
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+1
-2
vllm/model_executor/layers/quantization/gptq_marlin.py
vllm/model_executor/layers/quantization/gptq_marlin.py
+7
-3
No files found.
vllm/model_executor/layers/fused_moe/layer.py
View file @
2339d59f
...
...
@@ -41,9 +41,20 @@ class FusedMoEMethodBase(QuantizeMethodBase):
raise
NotImplementedError
@
abstractmethod
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
top_k
:
int
,
renormalize
:
bool
,
use_grouped_topk
:
bool
)
->
torch
.
Tensor
:
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
top_k
:
int
,
renormalize
:
bool
,
use_grouped_topk
:
bool
=
False
,
topk_group
:
Optional
[
int
]
=
None
,
num_expert_group
:
Optional
[
int
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
raise
NotImplementedError
...
...
@@ -79,7 +90,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
router_logits
:
torch
.
Tensor
,
top_k
:
int
,
renormalize
:
bool
,
use_grouped_topk
:
bool
,
use_grouped_topk
:
bool
=
False
,
topk_group
:
Optional
[
int
]
=
None
,
num_expert_group
:
Optional
[
int
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
...
...
vllm/model_executor/layers/quantization/awq_marlin.py
View file @
2339d59f
...
...
@@ -440,11 +440,13 @@ class AWQMoEMethod(FusedMoEMethodBase):
x
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
top_k
:
int
,
renormalize
:
bool
=
True
,
renormalize
:
bool
,
use_grouped_topk
:
bool
=
False
,
num_expert_group
:
Optional
[
int
]
=
None
,
topk_group
:
Optional
[
int
]
=
None
,
num_expert_group
:
Optional
[
int
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
hidden_states
=
x
,
...
...
@@ -454,7 +456,9 @@ class AWQMoEMethod(FusedMoEMethodBase):
renormalize
=
renormalize
,
topk_group
=
topk_group
,
num_expert_group
=
num_expert_group
,
custom_routing_function
=
custom_routing_function
)
custom_routing_function
=
custom_routing_function
,
scoring_func
=
scoring_func
,
e_score_correction_bias
=
e_score_correction_bias
)
return
torch
.
ops
.
vllm
.
fused_marlin_moe
(
x
,
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
View file @
2339d59f
...
...
@@ -203,13 +203,14 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
x
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
top_k
:
int
,
renormalize
:
bool
=
True
,
renormalize
:
bool
,
use_grouped_topk
:
bool
=
False
,
num_expert_group
:
Optional
[
int
]
=
None
,
topk_group
:
Optional
[
int
]
=
None
,
num_expert_group
:
Optional
[
int
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
from
vllm.model_executor.layers.fused_moe
import
fused_experts
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
...
...
@@ -220,7 +221,9 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
renormalize
=
renormalize
,
topk_group
=
topk_group
,
num_expert_group
=
num_expert_group
,
custom_routing_function
=
custom_routing_function
)
custom_routing_function
=
custom_routing_function
,
scoring_func
=
scoring_func
,
e_score_correction_bias
=
e_score_correction_bias
)
return
fused_experts
(
x
,
layer
.
w13_weight
,
...
...
@@ -476,12 +479,15 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
x
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
top_k
:
int
,
renormalize
:
bool
=
True
,
renormalize
:
bool
,
use_grouped_topk
:
bool
=
False
,
num_expert_group
:
Optional
[
int
]
=
None
,
topk_group
:
Optional
[
int
]
=
None
,
num_expert_group
:
Optional
[
int
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
hidden_states
=
x
,
router_logits
=
router_logits
,
...
...
@@ -490,7 +496,9 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
renormalize
=
renormalize
,
topk_group
=
topk_group
,
num_expert_group
=
num_expert_group
,
custom_routing_function
=
custom_routing_function
)
custom_routing_function
=
custom_routing_function
,
scoring_func
=
scoring_func
,
e_score_correction_bias
=
e_score_correction_bias
)
return
torch
.
ops
.
vllm
.
fused_marlin_moe
(
x
,
...
...
vllm/model_executor/layers/quantization/experts_int8.py
View file @
2339d59f
...
...
@@ -99,11 +99,13 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
x
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
top_k
:
int
,
renormalize
:
bool
=
True
,
renormalize
:
bool
,
use_grouped_topk
:
bool
=
False
,
num_expert_group
:
Optional
[
int
]
=
None
,
topk_group
:
Optional
[
int
]
=
None
,
num_expert_group
:
Optional
[
int
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
from
vllm.model_executor.layers.fused_moe
import
fused_experts
...
...
@@ -115,7 +117,9 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
renormalize
=
renormalize
,
topk_group
=
topk_group
,
num_expert_group
=
num_expert_group
,
custom_routing_function
=
custom_routing_function
)
custom_routing_function
=
custom_routing_function
,
scoring_func
=
scoring_func
,
e_score_correction_bias
=
e_score_correction_bias
)
return
fused_experts
(
x
,
layer
.
w13_weight
,
...
...
vllm/model_executor/layers/quantization/fp8.py
View file @
2339d59f
...
...
@@ -601,14 +601,13 @@ class Fp8MoEMethod(FusedMoEMethodBase):
router_logits
:
torch
.
Tensor
,
top_k
:
int
,
renormalize
:
bool
,
use_grouped_topk
:
bool
,
use_grouped_topk
:
bool
=
False
,
topk_group
:
Optional
[
int
]
=
None
,
num_expert_group
:
Optional
[
int
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
from
vllm.model_executor.layers.fused_moe
import
fused_experts
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
...
...
vllm/model_executor/layers/quantization/gptq_marlin.py
View file @
2339d59f
...
...
@@ -532,11 +532,13 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
x
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
top_k
:
int
,
renormalize
:
bool
=
True
,
renormalize
:
bool
,
use_grouped_topk
:
bool
=
False
,
num_expert_group
:
Optional
[
int
]
=
None
,
topk_group
:
Optional
[
int
]
=
None
,
num_expert_group
:
Optional
[
int
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
# The input must currently be float16
orig_dtype
=
x
.
dtype
...
...
@@ -550,7 +552,9 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
renormalize
=
renormalize
,
topk_group
=
topk_group
,
num_expert_group
=
num_expert_group
,
custom_routing_function
=
None
)
custom_routing_function
=
custom_routing_function
,
scoring_func
=
scoring_func
,
e_score_correction_bias
=
e_score_correction_bias
)
return
torch
.
ops
.
vllm
.
fused_marlin_moe
(
x
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment