Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
96ae75ad
"components/vscode:/vscode.git/clone" did not exist on "41d7d5490fc8e723fa1ef88ec946d9c7f4ec89b4"
Commit
96ae75ad
authored
Jan 04, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.6.6.post1' into v0.6.6.post1-dev
parents
f9f4a735
2339d59f
Changes
374
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1056 additions
and
261 deletions
+1056
-261
vllm/lora/models.py
vllm/lora/models.py
+8
-4
vllm/lora/utils.py
vllm/lora/utils.py
+23
-4
vllm/lora/worker_manager.py
vllm/lora/worker_manager.py
+12
-1
vllm/model_executor/guided_decoding/__init__.py
vllm/model_executor/guided_decoding/__init__.py
+39
-40
vllm/model_executor/guided_decoding/outlines_logits_processors.py
...el_executor/guided_decoding/outlines_logits_processors.py
+14
-9
vllm/model_executor/guided_decoding/utils.py
vllm/model_executor/guided_decoding/utils.py
+70
-0
vllm/model_executor/guided_decoding/xgrammar_decoding.py
vllm/model_executor/guided_decoding/xgrammar_decoding.py
+75
-42
vllm/model_executor/layers/fused_moe/fused_moe.py
vllm/model_executor/layers/fused_moe/fused_moe.py
+117
-31
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+86
-40
vllm/model_executor/layers/linear.py
vllm/model_executor/layers/linear.py
+21
-2
vllm/model_executor/layers/quantization/awq_marlin.py
vllm/model_executor/layers/quantization/awq_marlin.py
+7
-3
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
...ers/quantization/compressed_tensors/compressed_tensors.py
+160
-27
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
...quantization/compressed_tensors/compressed_tensors_moe.py
+15
-7
vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py
...ayers/quantization/compressed_tensors/schemes/__init__.py
+7
-8
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py
...ation/compressed_tensors/schemes/compressed_tensors_24.py
+208
-0
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py
...compressed_tensors/schemes/compressed_tensors_w4a16_24.py
+4
-0
vllm/model_executor/layers/quantization/compressed_tensors/utils.py
..._executor/layers/quantization/compressed_tensors/utils.py
+1
-1
vllm/model_executor/layers/quantization/experts_int8.py
vllm/model_executor/layers/quantization/experts_int8.py
+7
-3
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+175
-36
vllm/model_executor/layers/quantization/gptq_marlin.py
vllm/model_executor/layers/quantization/gptq_marlin.py
+7
-3
No files found.
vllm/lora/models.py
View file @
96ae75ad
...
@@ -28,7 +28,7 @@ from vllm.lora.utils import (from_layer, from_layer_logits_processor,
...
@@ -28,7 +28,7 @@ from vllm.lora.utils import (from_layer, from_layer_logits_processor,
parse_fine_tuned_lora_name
,
replace_submodule
)
parse_fine_tuned_lora_name
,
replace_submodule
)
from
vllm.model_executor.models
import
SupportsLoRA
,
supports_multimodal
from
vllm.model_executor.models
import
SupportsLoRA
,
supports_multimodal
from
vllm.model_executor.models.module_mapping
import
MultiModelKeys
from
vllm.model_executor.models.module_mapping
import
MultiModelKeys
from
vllm.model_executor.models.utils
import
PPMissingLayer
from
vllm.model_executor.models.utils
import
PPMissingLayer
,
WeightsMapper
from
vllm.utils
import
is_pin_memory_available
from
vllm.utils
import
is_pin_memory_available
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -113,13 +113,14 @@ class LoRAModel(AdapterModel):
...
@@ -113,13 +113,14 @@ class LoRAModel(AdapterModel):
target_embedding_padding
:
Optional
[
int
]
=
None
,
target_embedding_padding
:
Optional
[
int
]
=
None
,
embedding_modules
:
Optional
[
Dict
[
str
,
str
]]
=
None
,
embedding_modules
:
Optional
[
Dict
[
str
,
str
]]
=
None
,
embedding_padding_modules
:
Optional
[
List
[
str
]]
=
None
,
embedding_padding_modules
:
Optional
[
List
[
str
]]
=
None
,
weights_mapper
:
Optional
[
WeightsMapper
]
=
None
,
)
->
"LoRAModel"
:
)
->
"LoRAModel"
:
"""Create a LoRAModel from a dictionary of tensors."""
"""Create a LoRAModel from a dictionary of tensors."""
pin_memory
=
str
(
device
)
==
"cpu"
and
is_pin_memory_available
()
pin_memory
=
str
(
device
)
==
"cpu"
and
is_pin_memory_available
()
loras
:
Dict
[
str
,
LoRALayerWeights
]
=
{}
loras
:
Dict
[
str
,
LoRALayerWeights
]
=
{}
for
tensor_name
,
tensor
in
tensors
.
items
():
for
tensor_name
,
tensor
in
tensors
.
items
():
module_name
,
is_lora_a
,
is_bias
=
parse_fine_tuned_lora_name
(
module_name
,
is_lora_a
,
is_bias
=
parse_fine_tuned_lora_name
(
tensor_name
)
tensor_name
,
weights_mapper
)
if
module_name
not
in
loras
:
if
module_name
not
in
loras
:
lora_embeddings_tensor
=
None
lora_embeddings_tensor
=
None
if
embeddings
:
if
embeddings
:
...
@@ -187,6 +188,7 @@ class LoRAModel(AdapterModel):
...
@@ -187,6 +188,7 @@ class LoRAModel(AdapterModel):
target_embedding_padding
:
Optional
[
int
]
=
None
,
target_embedding_padding
:
Optional
[
int
]
=
None
,
embedding_modules
:
Optional
[
Dict
[
str
,
str
]]
=
None
,
embedding_modules
:
Optional
[
Dict
[
str
,
str
]]
=
None
,
embedding_padding_modules
:
Optional
[
List
[
str
]]
=
None
,
embedding_padding_modules
:
Optional
[
List
[
str
]]
=
None
,
weights_mapper
:
Optional
[
WeightsMapper
]
=
None
,
)
->
"LoRAModel"
:
)
->
"LoRAModel"
:
"""Create a LoRAModel from a local checkpoint.
"""Create a LoRAModel from a local checkpoint.
...
@@ -229,7 +231,8 @@ class LoRAModel(AdapterModel):
...
@@ -229,7 +231,8 @@ class LoRAModel(AdapterModel):
with
safetensors
.
safe_open
(
lora_tensor_path
,
with
safetensors
.
safe_open
(
lora_tensor_path
,
framework
=
"pt"
)
as
f
:
# type: ignore
framework
=
"pt"
)
as
f
:
# type: ignore
for
lora_module
in
f
.
keys
():
# noqa
for
lora_module
in
f
.
keys
():
# noqa
module_name
,
_
,
_
=
parse_fine_tuned_lora_name
(
lora_module
)
module_name
,
_
,
_
=
parse_fine_tuned_lora_name
(
lora_module
,
weights_mapper
)
part_name
=
module_name
.
split
(
"."
)[
-
1
]
part_name
=
module_name
.
split
(
"."
)[
-
1
]
if
part_name
not
in
expected_lora_modules
:
if
part_name
not
in
expected_lora_modules
:
unexpected_modules
.
append
(
module_name
)
unexpected_modules
.
append
(
module_name
)
...
@@ -289,7 +292,8 @@ class LoRAModel(AdapterModel):
...
@@ -289,7 +292,8 @@ class LoRAModel(AdapterModel):
embeddings
=
embeddings
,
embeddings
=
embeddings
,
target_embedding_padding
=
target_embedding_padding
,
target_embedding_padding
=
target_embedding_padding
,
embedding_modules
=
embedding_modules
,
embedding_modules
=
embedding_modules
,
embedding_padding_modules
=
embedding_padding_modules
)
embedding_padding_modules
=
embedding_padding_modules
,
weights_mapper
=
weights_mapper
)
class
LoRAModelManager
(
AdapterModelManager
):
class
LoRAModelManager
(
AdapterModelManager
):
...
...
vllm/lora/utils.py
View file @
96ae75ad
...
@@ -30,6 +30,7 @@ from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA,
...
@@ -30,6 +30,7 @@ from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA,
# yapf: enable
# yapf: enable
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
from
vllm.model_executor.models.utils
import
WeightsMapper
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -91,28 +92,46 @@ def replace_submodule(model: nn.Module, module_name: str,
...
@@ -91,28 +92,46 @@ def replace_submodule(model: nn.Module, module_name: str,
return
new_module
return
new_module
def
parse_fine_tuned_lora_name
(
name
:
str
)
->
Tuple
[
str
,
bool
,
bool
]:
def
parse_fine_tuned_lora_name
(
name
:
str
,
weights_mapper
:
Optional
[
WeightsMapper
]
=
None
)
->
Tuple
[
str
,
bool
,
bool
]:
"""Parse the name of lora weights.
"""Parse the name of lora weights.
args:
args:
name: the name of the fine-tuned LoRA, e.g.
name: the name of the fine-tuned LoRA, e.g.
base_model.model.dense1.weight
base_model.model.dense1.weight
weights_mapper: maps the name of weight, e.g.
`model.` -> `language_model.model.`,
return:
return:
Tuple(module_name, is_lora_a):
Tuple(module_name, is_lora_a):
module_name: the name of the module, e.g. model.dense1,
module_name: the name of the module, e.g. model.dense1,
is_lora_a whether the tensor is lora_a or lora_b.
is_lora_a whether the tensor is lora_a or lora_b.
is_bias whether the tensor is lora bias.
is_bias whether the tensor is lora bias.
"""
"""
# LoRA weight qualified name always starts with `base_model.model.`,
# so we remove the prefix `base_model.model.` to make the following
# mapping correctly.
if
"base_model.model."
in
name
:
name
=
name
.
replace
(
"base_model.model."
,
""
)
name
=
weights_mapper
.
_map_name
(
name
)
if
weights_mapper
else
name
# recover the prefix `base_model.model.`
name
=
"base_model.model."
+
name
parts
=
name
.
split
(
"."
)
parts
=
name
.
split
(
"."
)
if
parts
[
-
1
]
==
"weight"
and
(
parts
[
-
2
]
==
"lora_A"
if
parts
[
-
1
]
==
"weight"
and
(
parts
[
-
2
]
==
"lora_A"
or
parts
[
-
2
]
==
"lora_B"
):
or
parts
[
-
2
]
==
"lora_B"
):
return
"."
.
join
(
parts
[
2
:
-
2
]),
parts
[
-
2
]
==
"lora_A"
,
False
new_name
=
"."
.
join
(
parts
[
2
:
-
2
])
return
new_name
,
parts
[
-
2
]
==
"lora_A"
,
False
if
parts
[
-
1
]
==
"lora_embedding_A"
or
parts
[
-
1
]
==
"lora_embedding_B"
:
if
parts
[
-
1
]
==
"lora_embedding_A"
or
parts
[
-
1
]
==
"lora_embedding_B"
:
return
"."
.
join
(
parts
[
2
:
-
1
]),
parts
[
-
1
]
==
"lora_embedding_A"
,
False
new_name
=
"."
.
join
(
parts
[
2
:
-
1
])
return
new_name
,
parts
[
-
1
]
==
"lora_embedding_A"
,
False
if
parts
[
-
1
]
==
"bias"
:
if
parts
[
-
1
]
==
"bias"
:
return
"."
.
join
(
parts
[
2
:
-
2
]),
False
,
True
new_name
=
"."
.
join
(
parts
[
2
:
-
2
])
return
new_name
,
False
,
True
raise
ValueError
(
f
"
{
name
}
is unsupported LoRA weight"
)
raise
ValueError
(
f
"
{
name
}
is unsupported LoRA weight"
)
...
...
vllm/lora/worker_manager.py
View file @
96ae75ad
...
@@ -91,7 +91,17 @@ class WorkerLoRAManager(AbstractWorkerManager):
...
@@ -91,7 +91,17 @@ class WorkerLoRAManager(AbstractWorkerManager):
packed_modules_mapping
[
module
])
packed_modules_mapping
[
module
])
else
:
else
:
expected_lora_modules
.
append
(
module
)
expected_lora_modules
.
append
(
module
)
expected_lora_modules
=
list
(
set
(
expected_lora_modules
))
lora_path
=
get_adapter_absolute_path
(
lora_request
.
lora_path
)
lora_path
=
get_adapter_absolute_path
(
lora_request
.
lora_path
)
# For some models like Qwen2VL, we need to use hf_to_vllm_mapper
# to ensure correct loading of lora weights.
hf_to_vllm_mapper
=
None
if
(
hasattr
(
model
,
"hf_to_vllm_mapper"
)
and
model
.
hf_to_vllm_mapper
is
not
None
):
hf_to_vllm_mapper
=
model
.
hf_to_vllm_mapper
lora
=
self
.
_lora_model_cls
.
from_local_checkpoint
(
lora
=
self
.
_lora_model_cls
.
from_local_checkpoint
(
lora_path
,
lora_path
,
expected_lora_modules
,
expected_lora_modules
,
...
@@ -103,7 +113,8 @@ class WorkerLoRAManager(AbstractWorkerManager):
...
@@ -103,7 +113,8 @@ class WorkerLoRAManager(AbstractWorkerManager):
self
.
lora_config
.
lora_extra_vocab_size
,
self
.
lora_config
.
lora_extra_vocab_size
,
embedding_modules
=
self
.
embedding_modules
,
embedding_modules
=
self
.
embedding_modules
,
embedding_padding_modules
=
self
.
embedding_padding_modules
,
embedding_padding_modules
=
self
.
embedding_padding_modules
,
)
weights_mapper
=
hf_to_vllm_mapper
)
except
Exception
as
e
:
except
Exception
as
e
:
raise
RuntimeError
(
f
"Loading lora
{
lora_path
}
failed"
)
from
e
raise
RuntimeError
(
f
"Loading lora
{
lora_path
}
failed"
)
from
e
if
lora
.
rank
>
self
.
lora_config
.
max_lora_rank
:
if
lora
.
rank
>
self
.
lora_config
.
max_lora_rank
:
...
...
vllm/model_executor/guided_decoding/__init__.py
View file @
96ae75ad
...
@@ -3,6 +3,9 @@ from __future__ import annotations
...
@@ -3,6 +3,9 @@ from __future__ import annotations
from
typing
import
TYPE_CHECKING
from
typing
import
TYPE_CHECKING
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.guided_decoding.utils
import
(
convert_lark_to_gbnf
,
grammar_is_likely_lark
,
has_lmf_unsupported_json_features
,
has_xgrammar_unsupported_json_features
)
from
vllm.platforms
import
CpuArchEnum
,
current_platform
from
vllm.platforms
import
CpuArchEnum
,
current_platform
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
...
@@ -15,49 +18,24 @@ if TYPE_CHECKING:
...
@@ -15,49 +18,24 @@ if TYPE_CHECKING:
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
def
has_xgrammar_unsupported_json_features
(
schema
:
dict
)
->
bool
:
"""Check if JSON schema contains features unsupported by xgrammar."""
def
check_object
(
obj
:
dict
)
->
bool
:
if
not
isinstance
(
obj
,
dict
):
return
False
# Check for pattern restrictions
if
"pattern"
in
obj
:
return
True
# Check for numeric ranges
if
obj
.
get
(
"type"
)
in
(
"integer"
,
"number"
)
and
any
(
key
in
obj
for
key
in
[
"minimum"
,
"maximum"
,
"exclusiveMinimum"
,
"exclusiveMaximum"
,
"multipleOf"
]):
return
True
# Recursively check all nested objects and arrays
for
value
in
obj
.
values
():
if
isinstance
(
value
,
dict
):
if
check_object
(
value
):
return
True
elif
isinstance
(
value
,
list
):
for
item
in
value
:
if
isinstance
(
item
,
dict
)
and
check_object
(
item
):
return
True
return
False
return
check_object
(
schema
)
def
maybe_backend_fallback
(
def
maybe_backend_fallback
(
guided_params
:
GuidedDecodingParams
)
->
GuidedDecodingParams
:
guided_params
:
GuidedDecodingParams
)
->
GuidedDecodingParams
:
# lm-format-enforce doesn't support grammar, fallback to xgrammar
# lm-format-enforce doesn't support grammar, fallback to xgrammar
if
(
guided_params
.
backend
==
"lm-format-enforcer"
if
guided_params
.
backend
==
"lm-format-enforcer"
:
and
guided_params
.
grammar
is
not
None
):
if
guided_params
.
grammar
is
not
None
:
logger
.
warning
(
logger
.
warning
(
"lm-format-enforcer does not support grammar guided decoding. "
"lm-format-enforcer does not support grammar guided decoding. "
"Falling back to use xgrammar instead."
)
"Falling back to use xgrammar instead."
)
guided_params
.
backend
=
"xgrammar"
guided_params
.
backend
=
"xgrammar"
# lm-format-enforcer doesn't support some JSON schema features
elif
(
guided_params
.
json
is
not
None
and
has_lmf_unsupported_json_features
(
guided_params
.
json
)):
logger
.
warning
(
"lm-format-enforcer does not support advanced JSON schema "
"features like patterns or numeric ranges. "
"Falling back to use outlines instead."
)
guided_params
.
backend
=
"outlines"
if
guided_params
.
backend
==
"xgrammar"
:
if
guided_params
.
backend
==
"xgrammar"
:
# xgrammar only has x86 wheels for linux, fallback to outlines
# xgrammar only has x86 wheels for linux, fallback to outlines
...
@@ -82,6 +60,27 @@ def maybe_backend_fallback(
...
@@ -82,6 +60,27 @@ def maybe_backend_fallback(
"Falling back to use outlines instead."
)
"Falling back to use outlines instead."
)
guided_params
.
backend
=
"outlines"
guided_params
.
backend
=
"outlines"
# xgrammar only supports GBNF grammars, so we must convert Lark.
# We must check if the grammar is likely Lark and if that
# grammar is convertible to GBNF
elif
(
guided_params
.
grammar
is
not
None
and
grammar_is_likely_lark
(
guided_params
.
grammar
)):
try
:
convert_lark_to_gbnf
(
guided_params
.
grammar
)
except
Exception
:
logger
.
warning
(
"xgrammar does not support Lark grammars and the "
"grammar failed to convert to GBNF. "
"Falling back to use outlines instead."
)
guided_params
.
backend
=
"outlines"
if
(
guided_params
.
backend
==
"outlines"
and
guided_params
.
json_object
is
not
None
):
# outlines doesn't support json_object, fallback to xgrammar
logger
.
warning
(
"outlines does not support json_object. "
"Falling back to use xgrammar instead."
)
guided_params
.
backend
=
"xgrammar"
return
guided_params
return
guided_params
...
...
vllm/model_executor/guided_decoding/outlines_logits_processors.py
View file @
96ae75ad
...
@@ -21,10 +21,11 @@ from typing import Callable, DefaultDict, Dict, List, Union
...
@@ -21,10 +21,11 @@ from typing import Callable, DefaultDict, Dict, List, Union
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
from
lark
import
Lark
from
outlines
import
grammars
from
outlines
import
grammars
from
outlines.caching
import
cache
from
outlines.caching
import
cache
from
outlines.fsm.guide
import
CFGGuide
,
Generate
,
Guide
,
RegexGuide
,
Write
from
outlines.fsm.guide
import
(
CFGGuide
,
CFGState
,
Generate
,
Guide
,
RegexGuide
,
Write
)
from
outlines.fsm.parsing
import
PartialLark
from
outlines_core.fsm.json_schema
import
build_regex_from_schema
from
outlines_core.fsm.json_schema
import
build_regex_from_schema
from
pydantic
import
BaseModel
from
pydantic
import
BaseModel
from
transformers
import
PreTrainedTokenizerBase
from
transformers
import
PreTrainedTokenizerBase
...
@@ -34,7 +35,9 @@ class BaseLogitsProcessor:
...
@@ -34,7 +35,9 @@ class BaseLogitsProcessor:
def
__init__
(
self
,
guide
:
Guide
):
def
__init__
(
self
,
guide
:
Guide
):
self
.
_guide
:
Guide
=
guide
self
.
_guide
:
Guide
=
guide
self
.
_fsm_state
:
DefaultDict
[
int
,
int
]
=
defaultdict
(
int
)
# CFGState is used for the FSM state for CFGGuide
self
.
_fsm_state
:
DefaultDict
[
int
,
Union
[
int
,
CFGState
]]
=
defaultdict
(
int
)
def
__call__
(
self
,
input_ids
:
List
[
int
],
def
__call__
(
self
,
input_ids
:
List
[
int
],
scores
:
torch
.
Tensor
)
->
torch
.
Tensor
:
scores
:
torch
.
Tensor
)
->
torch
.
Tensor
:
...
@@ -54,15 +57,13 @@ class BaseLogitsProcessor:
...
@@ -54,15 +57,13 @@ class BaseLogitsProcessor:
# On the first time this is called, we simply re-create
# On the first time this is called, we simply re-create
# the Lark object.
# the Lark object.
if
isinstance
(
self
.
_guide
,
CFGGuide
):
if
isinstance
(
self
.
_guide
,
CFGGuide
):
self
.
_guide
.
parser
=
Lark
(
self
.
_guide
.
parser
=
Partial
Lark
(
self
.
_guide
.
cfg_string
,
self
.
_guide
.
cfg_string
,
parser
=
"lalr"
,
parser
=
"lalr"
,
lexer
=
"contextual"
,
propagate_positions
=
False
,
maybe_placeholders
=
False
,
regex
=
True
,
import_paths
=
[
grammars
.
GRAMMAR_PATH
],
import_paths
=
[
grammars
.
GRAMMAR_PATH
],
)
)
self
.
_fsm_state
[
seq_id
]
=
CFGState
(
parser_state
=
self
.
_guide
.
parser
.
parse
(
""
),
prev_token
=
None
)
instruction
=
self
.
_guide
.
get_next_instruction
(
instruction
=
self
.
_guide
.
get_next_instruction
(
state
=
self
.
_fsm_state
[
seq_id
])
state
=
self
.
_fsm_state
[
seq_id
])
...
@@ -200,7 +201,8 @@ def _adapt_tokenizer(tokenizer: PreTrainedTokenizerBase):
...
@@ -200,7 +201,8 @@ def _adapt_tokenizer(tokenizer: PreTrainedTokenizerBase):
string
=
tokenizer
.
convert_tokens_to_string
([
token
])
string
=
tokenizer
.
convert_tokens_to_string
([
token
])
# A hack to handle missing spaces to HF's Llama tokenizers
# A hack to handle missing spaces to HF's Llama tokenizers
if
token
.
startswith
(
SPIECE_UNDERLINE
)
or
token
==
"<0x20>"
:
if
(
type
(
token
)
is
str
and
token
.
startswith
(
SPIECE_UNDERLINE
)
or
token
==
"<0x20>"
):
return
" "
+
string
return
" "
+
string
return
string
return
string
...
@@ -211,6 +213,9 @@ def _adapt_tokenizer(tokenizer: PreTrainedTokenizerBase):
...
@@ -211,6 +213,9 @@ def _adapt_tokenizer(tokenizer: PreTrainedTokenizerBase):
"""Sync vLLM's decoder with the outlines by returning list."""
"""Sync vLLM's decoder with the outlines by returning list."""
def
new_decoder
(
inp_tokens
:
List
[
int
])
->
List
[
str
]:
def
new_decoder
(
inp_tokens
:
List
[
int
])
->
List
[
str
]:
if
(
isinstance
(
inp_tokens
,
list
)
and
len
(
inp_tokens
)
==
1
and
isinstance
(
inp_tokens
[
0
],
list
)):
inp_tokens
=
inp_tokens
[
0
]
return
[
decoder
(
inp_tokens
)]
return
[
decoder
(
inp_tokens
)]
return
new_decoder
return
new_decoder
...
...
vllm/model_executor/guided_decoding/
xgrammar_
utils.py
→
vllm/model_executor/guided_decoding/utils.py
View file @
96ae75ad
import
re
import
re
def
has_xgrammar_unsupported_json_features
(
schema
:
dict
)
->
bool
:
"""Check if JSON schema contains features unsupported by xgrammar."""
def
check_object
(
obj
:
dict
)
->
bool
:
if
not
isinstance
(
obj
,
dict
):
return
False
# Check for pattern restrictions
if
"pattern"
in
obj
:
return
True
# Check for numeric ranges
if
obj
.
get
(
"type"
)
in
(
"integer"
,
"number"
)
and
any
(
key
in
obj
for
key
in
[
"minimum"
,
"maximum"
,
"exclusiveMinimum"
,
"exclusiveMaximum"
,
"multipleOf"
]):
return
True
# Recursively check all nested objects and arrays
for
value
in
obj
.
values
():
if
isinstance
(
value
,
dict
):
if
check_object
(
value
):
return
True
elif
isinstance
(
value
,
list
):
for
item
in
value
:
if
isinstance
(
item
,
dict
)
and
check_object
(
item
):
return
True
return
False
return
check_object
(
schema
)
def
has_lmf_unsupported_json_features
(
schema
:
dict
)
->
bool
:
"""
Check if JSON schema contains features unsupported
by lm_format_enforcer.
Known issues:
- Regex patterns:
"grade": {
"type": "string",
"pattern": "^[A-D]$" # Regex pattern
},
"""
def
check_object
(
obj
:
dict
)
->
bool
:
if
not
isinstance
(
obj
,
dict
):
return
False
# Check for pattern restrictions
if
"pattern"
in
obj
:
return
True
# Recursively check all nested objects and arrays
for
value
in
obj
.
values
():
if
isinstance
(
value
,
dict
):
if
check_object
(
value
):
return
True
elif
isinstance
(
value
,
list
):
for
item
in
value
:
if
isinstance
(
item
,
dict
)
and
check_object
(
item
):
return
True
return
False
return
check_object
(
schema
)
def
grammar_is_likely_lark
(
grammar_str
:
str
)
->
bool
:
def
grammar_is_likely_lark
(
grammar_str
:
str
)
->
bool
:
"""
"""
Check if grammar appears to use Lark syntax.
Check if grammar appears to use Lark syntax.
...
...
vllm/model_executor/guided_decoding/xgrammar_decoding.py
View file @
96ae75ad
...
@@ -3,7 +3,7 @@ from __future__ import annotations
...
@@ -3,7 +3,7 @@ from __future__ import annotations
import
json
import
json
from
dataclasses
import
dataclass
,
field
from
dataclasses
import
dataclass
,
field
from
typing
import
TYPE_CHECKING
,
Any
,
NamedTuple
from
typing
import
TYPE_CHECKING
,
Any
import
torch
import
torch
from
transformers
import
PreTrainedTokenizerFast
from
transformers
import
PreTrainedTokenizerFast
...
@@ -14,8 +14,9 @@ try:
...
@@ -14,8 +14,9 @@ try:
except
ImportError
:
except
ImportError
:
pass
pass
from
vllm.model_executor.guided_decoding.xgrammar_utils
import
(
from
vllm.model_executor.guided_decoding.utils
import
(
convert_lark_to_gbnf
,
convert_lark_to_gbnf
,
grammar_is_likely_lark
)
grammar_is_likely_lark
)
from
vllm.transformers_utils.tokenizers.mistral
import
MistralTokenizer
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
transformers
import
PreTrainedTokenizer
from
transformers
import
PreTrainedTokenizer
...
@@ -37,11 +38,21 @@ def get_local_xgrammar_guided_decoding_logits_processor(
...
@@ -37,11 +38,21 @@ def get_local_xgrammar_guided_decoding_logits_processor(
return
XGrammarLogitsProcessor
(
config
)
return
XGrammarLogitsProcessor
(
config
)
class
TokenizerData
(
NamedTuple
):
@
dataclass
(
frozen
=
True
)
class
TokenizerData
:
"""Immutable container for cached tokenizer data."""
"""Immutable container for cached tokenizer data."""
encoded_vocab
:
list
[
str
]
encoded_vocab
:
list
[
str
]
=
field
(
default_factory
=
list
)
stop_token_ids
:
list
[
int
]
|
None
stop_token_ids
:
list
[
int
]
|
None
=
None
backend_str
:
str
# These fields are mutually exclusive: `backend_str` is used to create a
# TokenizeInfo with `TokenizerInfo.from_huggingface` while `vocab_type` is
# used within the constructor of TokenizeInfo
backend_str
:
str
|
None
=
None
vocab_type
:
xgr
.
VocabType
|
None
=
None
def
__post_init__
(
self
):
# Check for mutual exclusive
assert
not
(
self
.
backend_str
and
self
.
vocab_type
),
\
"backend_str and vocab_type are mutual exclusive"
class
TokenizerDataCache
:
class
TokenizerDataCache
:
...
@@ -68,18 +79,27 @@ class TokenizerDataCache:
...
@@ -68,18 +79,27 @@ class TokenizerDataCache:
"get_vocab method."
)
from
e
"get_vocab method."
)
from
e
stop_token_ids
=
None
stop_token_ids
=
None
backend_str
=
xgr
.
VocabType
.
RAW
backend_str
=
""
vocab_type
=
xgr
.
VocabType
.
RAW
if
stop_token_ids
is
None
and
hasattr
(
tokenizer
,
"eos_token_id"
)
and
tokenizer
.
eos_token_id
is
not
None
:
stop_token_ids
=
[
tokenizer
.
eos_token_id
]
if
isinstance
(
tokenizer
,
PreTrainedTokenizerFast
):
if
isinstance
(
tokenizer
,
PreTrainedTokenizerFast
):
backend_str
=
tokenizer
.
backend_tokenizer
.
to_str
()
backend_str
=
tokenizer
.
backend_tokenizer
.
to_str
()
if
stop_token_ids
is
None
and
hasattr
(
vocab_type
=
None
tokenizer
,
"eos_token_id"
)
and
tokenizer
.
eos_token_id
is
not
None
:
elif
isinstance
(
tokenizer
,
MistralTokenizer
):
stop_token_ids
=
[
tokenizer
.
eos_token_id
]
# REF: https://github.com/mlc-ai/xgrammar/blob/5e141f6ff1ca02bc31f9e512e68b61f2a8ae88e5/tests/python/test_tokenizer_info.py#L43 # noqa: E501
vocab_type
=
xgr
.
VocabType
.
BYTE_FALLBACK
cls
.
_cache
[
tokenizer_hash
]
=
TokenizerData
(
cls
.
_cache
[
tokenizer_hash
]
=
TokenizerData
(
encoded_vocab
=
encoded_vocab
,
encoded_vocab
=
encoded_vocab
,
stop_token_ids
=
stop_token_ids
,
stop_token_ids
=
stop_token_ids
,
backend_str
=
backend_str
)
backend_str
=
backend_str
,
vocab_type
=
vocab_type
)
return
cls
.
_cache
[
tokenizer_hash
]
return
cls
.
_cache
[
tokenizer_hash
]
...
@@ -98,11 +118,30 @@ class GrammarCompilerCache:
...
@@ -98,11 +118,30 @@ class GrammarCompilerCache:
cache_key
=
str
(
config
.
tokenizer_hash
)
cache_key
=
str
(
config
.
tokenizer_hash
)
if
cache_key
not
in
cls
.
_cache
:
if
cache_key
not
in
cls
.
_cache
:
assert
config
.
encoded_vocab
is
not
None
assert
config
.
tokenizer_data
is
not
None
tokenizer_info
=
xgr
.
TokenizerInfo
.
_create_from_handle
(
assert
config
.
tokenizer_data
.
encoded_vocab
is
not
None
xgr_core
.
TokenizerInfo
.
from_huggingface
(
config
.
encoded_vocab
,
config
.
backend_str
,
config_data
=
config
.
tokenizer_data
config
.
vocab_size
,
config
.
stop_token_ids
))
# In TokenizerDataCache.get_tokenizer_data, a serializable
# tokenizer_data is created and cached. This data is used to build
# a tokenizer_info and create an xgrammar compiler.
# - If tokenizer_data has backend_str set, use
# xgr_core.TokenizerInfo.from_huggingface (a C++ bind).
# - Otherwise, use the default constructor with vocab_type.
# - xgr_core.TokenizerInfo.from_huggingface !=
# xgr.TokenizerInfo.from_huggingface.
if
config_data
.
backend_str
:
tokenizer_info
=
xgr
.
TokenizerInfo
.
_create_from_handle
(
xgr_core
.
TokenizerInfo
.
from_huggingface
(
config_data
.
encoded_vocab
,
config_data
.
backend_str
,
config
.
vocab_size
,
config_data
.
stop_token_ids
))
else
:
tokenizer_info
=
xgr
.
TokenizerInfo
(
config_data
.
encoded_vocab
,
config_data
.
vocab_type
,
vocab_size
=
config
.
vocab_size
,
stop_token_ids
=
config_data
.
stop_token_ids
)
cls
.
_cache
[
cache_key
]
=
xgr
.
GrammarCompiler
(
cls
.
_cache
[
cache_key
]
=
xgr
.
GrammarCompiler
(
tokenizer_info
,
max_threads
=
config
.
max_threads
)
tokenizer_info
,
max_threads
=
config
.
max_threads
)
...
@@ -118,10 +157,7 @@ class GrammarConfig:
...
@@ -118,10 +157,7 @@ class GrammarConfig:
grammar_str
:
str
|
None
=
None
grammar_str
:
str
|
None
=
None
json_object
:
bool
|
None
=
None
json_object
:
bool
|
None
=
None
max_threads
:
int
=
8
max_threads
:
int
=
8
# Only populated if tokenizer_hash not in cache
tokenizer_data
:
TokenizerData
|
None
=
None
encoded_vocab
:
list
[
str
]
|
None
=
None
stop_token_ids
:
list
[
int
]
|
None
=
None
backend_str
:
str
|
None
=
None
@
classmethod
@
classmethod
def
from_guided_params
(
cls
,
def
from_guided_params
(
cls
,
...
@@ -132,9 +168,6 @@ class GrammarConfig:
...
@@ -132,9 +168,6 @@ class GrammarConfig:
tokenizer_hash
=
hash
(
tokenizer
)
tokenizer_hash
=
hash
(
tokenizer
)
tokenizer_data
=
TokenizerDataCache
.
get_tokenizer_data
(
tokenizer
)
tokenizer_data
=
TokenizerDataCache
.
get_tokenizer_data
(
tokenizer
)
encoded_vocab
=
tokenizer_data
.
encoded_vocab
stop_token_ids
=
tokenizer_data
.
stop_token_ids
backend_str
=
tokenizer_data
.
backend_str
if
guided_params
.
json
:
if
guided_params
.
json
:
if
not
isinstance
(
guided_params
.
json
,
str
):
if
not
isinstance
(
guided_params
.
json
,
str
):
...
@@ -152,11 +185,9 @@ class GrammarConfig:
...
@@ -152,11 +185,9 @@ class GrammarConfig:
return
cls
(
json_str
=
json_str
,
return
cls
(
json_str
=
json_str
,
vocab_size
=
model_config
.
hf_text_config
.
vocab_size
,
vocab_size
=
model_config
.
hf_text_config
.
vocab_size
,
encoded_vocab
=
encoded_vocab
,
stop_token_ids
=
stop_token_ids
,
backend_str
=
backend_str
,
tokenizer_hash
=
tokenizer_hash
,
tokenizer_hash
=
tokenizer_hash
,
max_threads
=
max_threads
)
max_threads
=
max_threads
,
tokenizer_data
=
tokenizer_data
)
elif
guided_params
.
grammar
:
elif
guided_params
.
grammar
:
# XGrammar only supports GBNF grammars, so we must convert Lark
# XGrammar only supports GBNF grammars, so we must convert Lark
if
grammar_is_likely_lark
(
guided_params
.
grammar
):
if
grammar_is_likely_lark
(
guided_params
.
grammar
):
...
@@ -181,19 +212,17 @@ class GrammarConfig:
...
@@ -181,19 +212,17 @@ class GrammarConfig:
return
cls
(
grammar_str
=
grammar_str
,
return
cls
(
grammar_str
=
grammar_str
,
vocab_size
=
model_config
.
hf_text_config
.
vocab_size
,
vocab_size
=
model_config
.
hf_text_config
.
vocab_size
,
encoded_vocab
=
encoded_vocab
,
stop_token_ids
=
stop_token_ids
,
backend_str
=
backend_str
,
tokenizer_hash
=
tokenizer_hash
,
tokenizer_hash
=
tokenizer_hash
,
max_threads
=
max_threads
)
max_threads
=
max_threads
,
tokenizer_data
=
tokenizer_data
)
elif
guided_params
.
json_object
:
elif
guided_params
.
json_object
:
return
cls
(
json_object
=
True
,
return
cls
(
vocab_size
=
model_config
.
hf_text_config
.
vocab_siz
e
,
json_object
=
Tru
e
,
encoded_vocab
=
encoded_vocab
,
vocab_size
=
model_config
.
hf_text_config
.
vocab_size
,
stop_token_ids
=
stop_token_ids
,
tokenizer_hash
=
tokenizer_hash
,
backend_str
=
backend_str
,
max_threads
=
max_threads
,
tokenizer_
hash
=
tokenizer_
hash
,
tokenizer_
data
=
tokenizer_
data
,
max_threads
=
max_threads
)
)
else
:
else
:
raise
ValueError
(
raise
ValueError
(
"Currently only support JSON and EBNF grammar mode for xgrammar"
"Currently only support JSON and EBNF grammar mode for xgrammar"
...
@@ -269,10 +298,14 @@ class XGrammarLogitsProcessor:
...
@@ -269,10 +298,14 @@ class XGrammarLogitsProcessor:
# fill_next_token_bitmask so we move it to the device of scores
# fill_next_token_bitmask so we move it to the device of scores
device_type
=
scores
.
device
.
type
device_type
=
scores
.
device
.
type
if
device_type
!=
"cuda"
:
if
device_type
!=
"cuda"
:
scores
=
scores
.
to
(
"cpu"
)
scores
=
scores
.
to
(
"cpu"
).
unsqueeze
(
0
)
# Note: In this method, if the tensors have different dimensions
# on CPU device fails, but on GPU it runs without error. Hence the
# unsqueeze above for scores, to match the token bitmask shape
xgr
.
apply_token_bitmask_inplace
(
scores
,
xgr
.
apply_token_bitmask_inplace
(
scores
,
self
.
token_bitmask
.
to
(
scores
.
device
))
self
.
token_bitmask
.
to
(
scores
.
device
))
if
device_type
!=
"cuda"
:
if
device_type
!=
"cuda"
:
scores
=
scores
.
to
(
device_type
)
scores
=
scores
.
to
(
device_type
)
.
squeeze
()
return
scores
return
scores
vllm/model_executor/layers/fused_moe/fused_moe.py
View file @
96ae75ad
...
@@ -2,7 +2,7 @@
...
@@ -2,7 +2,7 @@
import
functools
import
functools
import
json
import
json
import
os
import
os
from
typing
import
Any
,
Callable
,
Dict
,
Optional
,
Tuple
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
,
Tuple
import
torch
import
torch
import
triton
import
triton
...
@@ -11,6 +11,8 @@ import triton.language as tl
...
@@ -11,6 +11,8 @@ import triton.language as tl
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
per_token_group_quant_fp8
)
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils
import
direct_register_custom_op
from
vllm.utils
import
direct_register_custom_op
...
@@ -45,8 +47,14 @@ def fused_moe_kernel(
...
@@ -45,8 +47,14 @@ def fused_moe_kernel(
stride_bn
,
stride_bn
,
stride_cm
,
stride_cm
,
stride_cn
,
stride_cn
,
stride_asm
,
stride_ask
,
stride_bse
,
stride_bse
,
stride_bsk
,
stride_bsn
,
stride_bsn
,
# Block size for block-wise quantization
group_n
:
tl
.
constexpr
,
group_k
:
tl
.
constexpr
,
# Meta-parameters
# Meta-parameters
BLOCK_SIZE_M
:
tl
.
constexpr
,
BLOCK_SIZE_M
:
tl
.
constexpr
,
BLOCK_SIZE_N
:
tl
.
constexpr
,
BLOCK_SIZE_N
:
tl
.
constexpr
,
...
@@ -125,8 +133,14 @@ def fused_moe_kernel(
...
@@ -125,8 +133,14 @@ def fused_moe_kernel(
b_scale
=
tl
.
load
(
b_scale_ptrs
)
b_scale
=
tl
.
load
(
b_scale_ptrs
)
if
use_fp8_w8a8
:
if
use_fp8_w8a8
:
a_scale
=
tl
.
load
(
a_scale_ptr
)
if
group_k
>
0
and
group_n
>
0
:
b_scale
=
tl
.
load
(
b_scale_ptr
+
off_experts
)
a_scale_ptrs
=
a_scale_ptr
+
(
offs_token
//
top_k
)
*
stride_asm
offs_bsn
=
offs_bn
//
group_n
b_scale_ptrs
=
(
b_scale_ptr
+
off_experts
*
stride_bse
+
offs_bsn
*
stride_bsn
)
else
:
a_scale
=
tl
.
load
(
a_scale_ptr
)
b_scale
=
tl
.
load
(
b_scale_ptr
+
off_experts
)
# -----------------------------------------------------------
# -----------------------------------------------------------
# Iterate to compute a block of the C matrix.
# Iterate to compute a block of the C matrix.
...
@@ -149,7 +163,18 @@ def fused_moe_kernel(
...
@@ -149,7 +163,18 @@ def fused_moe_kernel(
if
use_int8_w8a16
:
if
use_int8_w8a16
:
accumulator
=
tl
.
dot
(
a
,
b
.
to
(
compute_type
),
acc
=
accumulator
)
accumulator
=
tl
.
dot
(
a
,
b
.
to
(
compute_type
),
acc
=
accumulator
)
elif
use_fp8_w8a8
:
elif
use_fp8_w8a8
:
accumulator
=
tl
.
dot
(
a
,
b
,
acc
=
accumulator
)
if
group_k
>
0
and
group_n
>
0
:
k_start
=
k
*
BLOCK_SIZE_K
offs_ks
=
k_start
//
group_k
a_scale
=
tl
.
load
(
a_scale_ptrs
+
offs_ks
*
stride_ask
,
mask
=
token_mask
,
other
=
0.0
)
b_scale
=
tl
.
load
(
b_scale_ptrs
+
offs_ks
*
stride_bsk
)
accumulator
+=
tl
.
dot
(
a
,
b
)
*
a_scale
[:,
None
]
*
b_scale
[
None
,
:]
else
:
accumulator
=
tl
.
dot
(
a
,
b
,
acc
=
accumulator
)
else
:
else
:
accumulator
+=
tl
.
dot
(
a
,
b
)
accumulator
+=
tl
.
dot
(
a
,
b
)
# Advance the ptrs to the next K block.
# Advance the ptrs to the next K block.
...
@@ -164,7 +189,10 @@ def fused_moe_kernel(
...
@@ -164,7 +189,10 @@ def fused_moe_kernel(
if
use_int8_w8a16
:
if
use_int8_w8a16
:
accumulator
=
(
accumulator
*
b_scale
).
to
(
compute_type
)
accumulator
=
(
accumulator
*
b_scale
).
to
(
compute_type
)
elif
use_fp8_w8a8
:
elif
use_fp8_w8a8
:
accumulator
=
(
accumulator
*
a_scale
*
b_scale
).
to
(
compute_type
)
if
group_k
>
0
and
group_n
>
0
:
accumulator
=
accumulator
.
to
(
compute_type
)
else
:
accumulator
=
(
accumulator
*
a_scale
*
b_scale
).
to
(
compute_type
)
else
:
else
:
accumulator
=
accumulator
.
to
(
compute_type
)
accumulator
=
accumulator
.
to
(
compute_type
)
# -----------------------------------------------------------
# -----------------------------------------------------------
...
@@ -233,22 +261,37 @@ def moe_align_block_size(
...
@@ -233,22 +261,37 @@ def moe_align_block_size(
return
sorted_ids
,
expert_ids
,
num_tokens_post_pad
return
sorted_ids
,
expert_ids
,
num_tokens_post_pad
def
invoke_fused_moe_kernel
(
A
:
torch
.
Tensor
,
B
:
torch
.
Tensor
,
C
:
torch
.
Tensor
,
def
invoke_fused_moe_kernel
(
A
:
torch
.
Tensor
,
B
:
torch
.
Tensor
,
C
:
torch
.
Tensor
,
A_scale
:
Optional
[
torch
.
Tensor
],
A_scale
:
Optional
[
torch
.
Tensor
],
B_scale
:
Optional
[
torch
.
Tensor
],
B_scale
:
Optional
[
torch
.
Tensor
],
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
sorted_token_ids
:
torch
.
Tensor
,
sorted_token_ids
:
torch
.
Tensor
,
expert_ids
:
torch
.
Tensor
,
expert_ids
:
torch
.
Tensor
,
num_tokens_post_padded
:
torch
.
Tensor
,
num_tokens_post_padded
:
torch
.
Tensor
,
mul_routed_weight
:
bool
,
top_k
:
int
,
mul_routed_weight
:
bool
,
config
:
Dict
[
str
,
Any
],
compute_type
:
tl
.
dtype
,
top_k
:
int
,
use_fp8_w8a8
:
bool
,
use_int8_w8a16
:
bool
)
->
None
:
config
:
Dict
[
str
,
Any
],
compute_type
:
tl
.
dtype
,
use_fp8_w8a8
:
bool
,
use_int8_w8a16
:
bool
,
block_shape
:
Optional
[
List
[
int
]]
=
None
)
->
None
:
assert
topk_weights
.
stride
(
1
)
==
1
assert
topk_weights
.
stride
(
1
)
==
1
assert
sorted_token_ids
.
stride
(
0
)
==
1
assert
sorted_token_ids
.
stride
(
0
)
==
1
if
use_fp8_w8a8
:
if
use_fp8_w8a8
:
A
,
A_scale
=
ops
.
scaled_fp8_quant
(
A
,
A_scale
)
assert
B_scale
is
not
None
assert
B_scale
is
not
None
if
block_shape
is
None
:
A
,
A_scale
=
ops
.
scaled_fp8_quant
(
A
,
A_scale
)
else
:
assert
len
(
block_shape
)
==
2
block_n
,
block_k
=
block_shape
[
0
],
block_shape
[
1
]
A
,
A_scale
=
per_token_group_quant_fp8
(
A
,
block_k
)
assert
triton
.
cdiv
(
A
.
shape
[
-
1
],
block_k
)
==
A_scale
.
shape
[
-
1
]
assert
triton
.
cdiv
(
B
.
shape
[
-
2
],
block_n
)
==
B_scale
.
shape
[
-
2
]
assert
triton
.
cdiv
(
B
.
shape
[
-
1
],
block_k
)
==
B_scale
.
shape
[
-
1
]
elif
use_int8_w8a16
:
elif
use_int8_w8a16
:
assert
B_scale
is
not
None
assert
B_scale
is
not
None
else
:
else
:
...
@@ -279,8 +322,13 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
...
@@ -279,8 +322,13 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
B
.
stride
(
1
),
B
.
stride
(
1
),
C
.
stride
(
1
),
C
.
stride
(
1
),
C
.
stride
(
2
),
C
.
stride
(
2
),
B_scale
.
stride
(
0
)
if
B_scale
is
not
None
and
use_int8_w8a16
else
0
,
A_scale
.
stride
(
0
)
if
A_scale
is
not
None
and
A_scale
.
ndim
==
2
else
0
,
B_scale
.
stride
(
1
)
if
B_scale
is
not
None
and
use_int8_w8a16
else
0
,
A_scale
.
stride
(
1
)
if
A_scale
is
not
None
and
A_scale
.
ndim
==
2
else
0
,
B_scale
.
stride
(
0
)
if
B_scale
is
not
None
and
B_scale
.
ndim
>=
2
else
0
,
B_scale
.
stride
(
2
)
if
B_scale
is
not
None
and
B_scale
.
ndim
==
3
else
0
,
B_scale
.
stride
(
1
)
if
B_scale
is
not
None
and
B_scale
.
ndim
>=
2
else
0
,
0
if
block_shape
is
None
else
block_shape
[
0
],
0
if
block_shape
is
None
else
block_shape
[
1
],
MUL_ROUTED_WEIGHT
=
mul_routed_weight
,
MUL_ROUTED_WEIGHT
=
mul_routed_weight
,
top_k
=
top_k
,
top_k
=
top_k
,
compute_type
=
compute_type
,
compute_type
=
compute_type
,
...
@@ -362,6 +410,7 @@ def try_get_optimal_moe_config(
...
@@ -362,6 +410,7 @@ def try_get_optimal_moe_config(
dtype
:
Optional
[
str
],
dtype
:
Optional
[
str
],
M
:
int
,
M
:
int
,
is_marlin
:
bool
=
False
,
is_marlin
:
bool
=
False
,
block_shape
:
Optional
[
List
[
int
]]
=
None
,
):
):
from
vllm.model_executor.layers.fused_moe
import
get_config
from
vllm.model_executor.layers.fused_moe
import
get_config
override_config
=
get_config
()
override_config
=
get_config
()
...
@@ -380,6 +429,12 @@ def try_get_optimal_moe_config(
...
@@ -380,6 +429,12 @@ def try_get_optimal_moe_config(
# Else use the default config
# Else use the default config
config
=
get_default_config
(
M
,
E
,
N
,
w1_shape
[
2
],
top_k
,
dtype
,
config
=
get_default_config
(
M
,
E
,
N
,
w1_shape
[
2
],
top_k
,
dtype
,
is_marlin
)
is_marlin
)
# NOTE: For block-wise quant,
# BLOCK_K must be divisible by block_shape[1]
# BLOCK_N and BLOCK_M has no requirements
if
block_shape
is
not
None
:
config
[
"BLOCK_SIZE_N"
]
=
block_shape
[
0
]
config
[
"BLOCK_SIZE_K"
]
=
block_shape
[
1
]
return
config
return
config
...
@@ -421,18 +476,29 @@ def fused_topk(
...
@@ -421,18 +476,29 @@ def fused_topk(
return
topk_weights
,
topk_ids
return
topk_weights
,
topk_ids
# This is used by the Deepseek-V2 model
# This is used by the Deepseek-V2
and Deepseek-V3
model
def
grouped_topk
(
hidden_states
:
torch
.
Tensor
,
def
grouped_topk
(
hidden_states
:
torch
.
Tensor
,
gating_output
:
torch
.
Tensor
,
gating_output
:
torch
.
Tensor
,
topk
:
int
,
topk
:
int
,
renormalize
:
bool
,
renormalize
:
bool
,
num_expert_group
:
int
=
0
,
num_expert_group
:
int
=
0
,
topk_group
:
int
=
0
):
topk_group
:
int
=
0
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
):
assert
hidden_states
.
shape
[
0
]
==
gating_output
.
shape
[
0
],
(
assert
hidden_states
.
shape
[
0
]
==
gating_output
.
shape
[
0
],
(
"Number of tokens mismatch"
)
"Number of tokens mismatch"
)
scores
=
torch
.
softmax
(
gating_output
,
dim
=-
1
)
if
scoring_func
==
"softmax"
:
scores
=
torch
.
softmax
(
gating_output
,
dim
=-
1
)
elif
scoring_func
==
"sigmoid"
:
scores
=
gating_output
.
sigmoid
()
else
:
raise
ValueError
(
f
"Unsupported scoring function:
{
scoring_func
}
"
)
if
e_score_correction_bias
is
not
None
:
scores
.
add_
(
e_score_correction_bias
.
unsqueeze
(
0
))
num_token
=
scores
.
shape
[
0
]
num_token
=
scores
.
shape
[
0
]
group_scores
=
scores
.
view
(
num_token
,
num_expert_group
,
group_scores
=
scores
.
view
(
num_token
,
num_expert_group
,
-
1
).
max
(
dim
=-
1
).
values
# [n, n_group]
-
1
).
max
(
dim
=-
1
).
values
# [n, n_group]
...
@@ -479,10 +545,11 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
...
@@ -479,10 +545,11 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
)
->
None
:
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
block_shape
:
Optional
[
List
[
int
]]
=
None
)
->
None
:
fused_experts_impl
(
hidden_states
,
w1
,
w2
,
topk_weights
,
topk_ids
,
True
,
fused_experts_impl
(
hidden_states
,
w1
,
w2
,
topk_weights
,
topk_ids
,
True
,
use_fp8_w8a8
,
use_int8_w8a16
,
w1_scale
,
w2_scale
,
use_fp8_w8a8
,
use_int8_w8a16
,
w1_scale
,
w2_scale
,
a1_scale
,
a2_scale
)
a1_scale
,
a2_scale
,
block_shape
)
def
inplace_fused_experts_fake
(
def
inplace_fused_experts_fake
(
...
@@ -496,7 +563,8 @@ def inplace_fused_experts_fake(
...
@@ -496,7 +563,8 @@ def inplace_fused_experts_fake(
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
)
->
None
:
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
block_shape
:
Optional
[
List
[
int
]]
=
None
)
->
None
:
pass
pass
...
@@ -519,10 +587,11 @@ def outplace_fused_experts(
...
@@ -519,10 +587,11 @@ def outplace_fused_experts(
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
block_shape
:
Optional
[
List
[
int
]]
=
None
)
->
torch
.
Tensor
:
return
fused_experts_impl
(
hidden_states
,
w1
,
w2
,
topk_weights
,
topk_ids
,
return
fused_experts_impl
(
hidden_states
,
w1
,
w2
,
topk_weights
,
topk_ids
,
False
,
use_fp8_w8a8
,
use_int8_w8a16
,
w1_scale
,
False
,
use_fp8_w8a8
,
use_int8_w8a16
,
w1_scale
,
w2_scale
,
a1_scale
,
a2_scale
)
w2_scale
,
a1_scale
,
a2_scale
,
block_shape
)
def
outplace_fused_experts_fake
(
def
outplace_fused_experts_fake
(
...
@@ -536,7 +605,8 @@ def outplace_fused_experts_fake(
...
@@ -536,7 +605,8 @@ def outplace_fused_experts_fake(
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
block_shape
:
Optional
[
List
[
int
]]
=
None
)
->
torch
.
Tensor
:
return
torch
.
empty_like
(
hidden_states
)
return
torch
.
empty_like
(
hidden_states
)
...
@@ -559,18 +629,22 @@ def fused_experts(hidden_states: torch.Tensor,
...
@@ -559,18 +629,22 @@ def fused_experts(hidden_states: torch.Tensor,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
):
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
block_shape
:
Optional
[
List
[
int
]]
=
None
):
if
inplace
:
if
inplace
:
torch
.
ops
.
vllm
.
inplace_fused_experts
(
hidden_states
,
w1
,
w2
,
torch
.
ops
.
vllm
.
inplace_fused_experts
(
hidden_states
,
w1
,
w2
,
topk_weights
,
topk_ids
,
topk_weights
,
topk_ids
,
use_fp8_w8a8
,
use_int8_w8a16
,
use_fp8_w8a8
,
use_int8_w8a16
,
w1_scale
,
w2_scale
,
a1_scale
,
w1_scale
,
w2_scale
,
a1_scale
,
a2_scale
)
a2_scale
,
block_shape
)
return
hidden_states
return
hidden_states
else
:
else
:
return
torch
.
ops
.
vllm
.
outplace_fused_experts
(
return
torch
.
ops
.
vllm
.
outplace_fused_experts
(
hidden_states
,
w1
,
w2
,
hidden_states
,
w1
,
w2
,
topk_weights
,
topk_ids
,
use_fp8_w8a8
,
topk_weights
,
topk_ids
,
use_int8_w8a16
,
w1_scale
,
w2_scale
,
a1_scale
,
a2_scale
)
use_fp8_w8a8
,
use_int8_w8a16
,
w1_scale
,
w2_scale
,
a1_scale
,
a2_scale
,
block_shape
)
def
fused_experts_impl
(
hidden_states
:
torch
.
Tensor
,
def
fused_experts_impl
(
hidden_states
:
torch
.
Tensor
,
...
@@ -584,7 +658,8 @@ def fused_experts_impl(hidden_states: torch.Tensor,
...
@@ -584,7 +658,8 @@ def fused_experts_impl(hidden_states: torch.Tensor,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
):
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
block_shape
:
Optional
[
List
[
int
]]
=
None
):
# Check constraints.
# Check constraints.
assert
hidden_states
.
shape
[
1
]
==
w1
.
shape
[
2
],
"Hidden size mismatch"
assert
hidden_states
.
shape
[
1
]
==
w1
.
shape
[
2
],
"Hidden size mismatch"
assert
topk_weights
.
shape
==
topk_ids
.
shape
,
"topk shape mismatch"
assert
topk_weights
.
shape
==
topk_ids
.
shape
,
"topk shape mismatch"
...
@@ -611,6 +686,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
...
@@ -611,6 +686,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
w2
.
shape
,
w2
.
shape
,
topk_ids
.
shape
[
1
],
topk_ids
.
shape
[
1
],
config_dtype
,
config_dtype
,
block_shape
=
block_shape
,
)
)
config
=
get_config_func
(
M
)
config
=
get_config_func
(
M
)
...
@@ -674,7 +750,8 @@ def fused_experts_impl(hidden_states: torch.Tensor,
...
@@ -674,7 +750,8 @@ def fused_experts_impl(hidden_states: torch.Tensor,
config
,
config
,
compute_type
=
compute_type
,
compute_type
=
compute_type
,
use_fp8_w8a8
=
use_fp8_w8a8
,
use_fp8_w8a8
=
use_fp8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
)
use_int8_w8a16
=
use_int8_w8a16
,
block_shape
=
block_shape
)
ops
.
silu_and_mul
(
intermediate_cache2
,
intermediate_cache1
.
view
(
-
1
,
N
))
ops
.
silu_and_mul
(
intermediate_cache2
,
intermediate_cache1
.
view
(
-
1
,
N
))
...
@@ -693,7 +770,8 @@ def fused_experts_impl(hidden_states: torch.Tensor,
...
@@ -693,7 +770,8 @@ def fused_experts_impl(hidden_states: torch.Tensor,
config
,
config
,
compute_type
=
compute_type
,
compute_type
=
compute_type
,
use_fp8_w8a8
=
use_fp8_w8a8
,
use_fp8_w8a8
=
use_fp8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
)
use_int8_w8a16
=
use_int8_w8a16
,
block_shape
=
block_shape
)
ops
.
moe_sum
(
intermediate_cache3
.
view
(
*
intermediate_cache3
.
shape
),
ops
.
moe_sum
(
intermediate_cache3
.
view
(
*
intermediate_cache3
.
shape
),
out_hidden_states
[
begin_chunk_idx
:
end_chunk_idx
])
out_hidden_states
[
begin_chunk_idx
:
end_chunk_idx
])
...
@@ -718,6 +796,7 @@ def fused_moe(
...
@@ -718,6 +796,7 @@ def fused_moe(
w2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
block_shape
:
Optional
[
List
[
int
]]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""
"""
This function computes a Mixture of Experts (MoE) layer using two sets of
This function computes a Mixture of Experts (MoE) layer using two sets of
...
@@ -745,6 +824,12 @@ def fused_moe(
...
@@ -745,6 +824,12 @@ def fused_moe(
w1.
w1.
- w2_scale (Optional[torch.Tensor]): Optional scale to be used for
- w2_scale (Optional[torch.Tensor]): Optional scale to be used for
w2.
w2.
- a1_scale (Optional[torch.Tensor]): Optional scale to be used for
a1.
- a2_scale (Optional[torch.Tensor]): Optional scale to be used for
a2.
- block_shape: (Optional[List[int]]): Optional block size for block-wise
quantization.
Returns:
Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
- torch.Tensor: The output tensor after applying the MoE layer.
...
@@ -775,4 +860,5 @@ def fused_moe(
...
@@ -775,4 +860,5 @@ def fused_moe(
w1_scale
=
w1_scale
,
w1_scale
=
w1_scale
,
w2_scale
=
w2_scale
,
w2_scale
=
w2_scale
,
a1_scale
=
a1_scale
,
a1_scale
=
a1_scale
,
a2_scale
=
a2_scale
)
a2_scale
=
a2_scale
,
\ No newline at end of file
block_shape
=
block_shape
)
vllm/model_executor/layers/fused_moe/layer.py
View file @
96ae75ad
...
@@ -29,6 +29,7 @@ class FusedMoeWeightScaleSupported(Enum):
...
@@ -29,6 +29,7 @@ class FusedMoeWeightScaleSupported(Enum):
TENSOR
=
"tensor"
TENSOR
=
"tensor"
CHANNEL
=
"channel"
CHANNEL
=
"channel"
GROUP
=
"group"
GROUP
=
"group"
BLOCK
=
"block"
class
FusedMoEMethodBase
(
QuantizeMethodBase
):
class
FusedMoEMethodBase
(
QuantizeMethodBase
):
...
@@ -40,9 +41,20 @@ class FusedMoEMethodBase(QuantizeMethodBase):
...
@@ -40,9 +41,20 @@ class FusedMoEMethodBase(QuantizeMethodBase):
raise
NotImplementedError
raise
NotImplementedError
@
abstractmethod
@
abstractmethod
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
def
apply
(
router_logits
:
torch
.
Tensor
,
top_k
:
int
,
renormalize
:
bool
,
self
,
use_grouped_topk
:
bool
)
->
torch
.
Tensor
:
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
top_k
:
int
,
renormalize
:
bool
,
use_grouped_topk
:
bool
=
False
,
topk_group
:
Optional
[
int
]
=
None
,
num_expert_group
:
Optional
[
int
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
raise
NotImplementedError
raise
NotImplementedError
...
@@ -72,16 +84,18 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
...
@@ -72,16 +84,18 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
set_weight_attrs
(
w2_weight
,
extra_weight_attrs
)
set_weight_attrs
(
w2_weight
,
extra_weight_attrs
)
def
apply
(
def
apply
(
self
,
self
,
layer
:
torch
.
nn
.
Module
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
top_k
:
int
,
top_k
:
int
,
renormalize
:
bool
,
renormalize
:
bool
,
use_grouped_topk
:
bool
,
use_grouped_topk
:
bool
=
False
,
topk_group
:
Optional
[
int
]
=
None
,
topk_group
:
Optional
[
int
]
=
None
,
num_expert_group
:
Optional
[
int
]
=
None
,
num_expert_group
:
Optional
[
int
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
return
self
.
forward
(
x
=
x
,
return
self
.
forward
(
x
=
x
,
layer
=
layer
,
layer
=
layer
,
...
@@ -91,19 +105,23 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
...
@@ -91,19 +105,23 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
use_grouped_topk
=
use_grouped_topk
,
use_grouped_topk
=
use_grouped_topk
,
topk_group
=
topk_group
,
topk_group
=
topk_group
,
num_expert_group
=
num_expert_group
,
num_expert_group
=
num_expert_group
,
custom_routing_function
=
custom_routing_function
)
custom_routing_function
=
custom_routing_function
,
scoring_func
=
scoring_func
,
e_score_correction_bias
=
e_score_correction_bias
)
def
forward_cuda
(
def
forward_cuda
(
self
,
self
,
layer
:
torch
.
nn
.
Module
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
use_grouped_topk
:
bool
,
use_grouped_topk
:
bool
,
top_k
:
int
,
top_k
:
int
,
router_logits
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
renormalize
:
bool
,
renormalize
:
bool
,
topk_group
:
Optional
[
int
]
=
None
,
topk_group
:
Optional
[
int
]
=
None
,
num_expert_group
:
Optional
[
int
]
=
None
,
num_expert_group
:
Optional
[
int
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
hidden_states
=
x
,
hidden_states
=
x
,
...
@@ -113,7 +131,9 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
...
@@ -113,7 +131,9 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
renormalize
=
renormalize
,
renormalize
=
renormalize
,
topk_group
=
topk_group
,
topk_group
=
topk_group
,
num_expert_group
=
num_expert_group
,
num_expert_group
=
num_expert_group
,
custom_routing_function
=
custom_routing_function
)
custom_routing_function
=
custom_routing_function
,
scoring_func
=
scoring_func
,
e_score_correction_bias
=
e_score_correction_bias
)
return
fused_experts
(
hidden_states
=
x
,
return
fused_experts
(
hidden_states
=
x
,
w1
=
layer
.
w13_weight
,
w1
=
layer
.
w13_weight
,
...
@@ -127,21 +147,29 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
...
@@ -127,21 +147,29 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
"The CPU backend currently does not support MoE."
)
"The CPU backend currently does not support MoE."
)
def
forward_tpu
(
def
forward_tpu
(
self
,
self
,
layer
:
torch
.
nn
.
Module
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
use_grouped_topk
:
bool
,
use_grouped_topk
:
bool
,
top_k
:
int
,
top_k
:
int
,
router_logits
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
renormalize
:
bool
,
renormalize
:
bool
,
topk_group
:
Optional
[
int
]
=
None
,
topk_group
:
Optional
[
int
]
=
None
,
num_expert_group
:
Optional
[
int
]
=
None
,
num_expert_group
:
Optional
[
int
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
assert
not
use_grouped_topk
assert
not
use_grouped_topk
assert
num_expert_group
is
None
assert
num_expert_group
is
None
assert
topk_group
is
None
assert
topk_group
is
None
assert
custom_routing_function
is
None
assert
custom_routing_function
is
None
if
scoring_func
!=
"softmax"
:
raise
NotImplementedError
(
"Only softmax scoring function is supported for TPU."
)
if
e_score_correction_bias
is
not
None
:
raise
NotImplementedError
(
"Expert score correction bias is not supported for TPU."
)
return
fused_moe_pallas
(
hidden_states
=
x
,
return
fused_moe_pallas
(
hidden_states
=
x
,
w1
=
layer
.
w13_weight
,
w1
=
layer
.
w13_weight
,
w2
=
layer
.
w2_weight
,
w2
=
layer
.
w2_weight
,
...
@@ -155,7 +183,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
...
@@ -155,7 +183,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
class
FusedMoE
(
torch
.
nn
.
Module
):
class
FusedMoE
(
torch
.
nn
.
Module
):
"""FusedMoE layer for MoE models.
"""FusedMoE layer for MoE models.
This layer contains both MergedColumnParallel weights (gate_up_proj /
This layer contains both MergedColumnParallel weights (gate_up_proj /
w13) and RowParallelLinear weights (down_proj/ w2).
w13) and RowParallelLinear weights (down_proj/ w2).
Note: Mixtral uses w1, w2, and w3 for gate, up, and down_proj. We
Note: Mixtral uses w1, w2, and w3 for gate, up, and down_proj. We
...
@@ -189,6 +217,8 @@ class FusedMoE(torch.nn.Module):
...
@@ -189,6 +217,8 @@ class FusedMoE(torch.nn.Module):
tp_size
:
Optional
[
int
]
=
None
,
tp_size
:
Optional
[
int
]
=
None
,
prefix
:
str
=
""
,
prefix
:
str
=
""
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
):
):
super
().
__init__
()
super
().
__init__
()
...
@@ -199,6 +229,7 @@ class FusedMoE(torch.nn.Module):
...
@@ -199,6 +229,7 @@ class FusedMoE(torch.nn.Module):
get_tensor_model_parallel_world_size
())
get_tensor_model_parallel_world_size
())
self
.
top_k
=
top_k
self
.
top_k
=
top_k
self
.
num_experts
=
num_experts
self
.
num_experts
=
num_experts
assert
intermediate_size
%
self
.
tp_size
==
0
self
.
intermediate_size_per_partition
=
intermediate_size
//
self
.
tp_size
self
.
intermediate_size_per_partition
=
intermediate_size
//
self
.
tp_size
self
.
reduce_results
=
reduce_results
self
.
reduce_results
=
reduce_results
self
.
renormalize
=
renormalize
self
.
renormalize
=
renormalize
...
@@ -208,6 +239,12 @@ class FusedMoE(torch.nn.Module):
...
@@ -208,6 +239,12 @@ class FusedMoE(torch.nn.Module):
self
.
num_expert_group
=
num_expert_group
self
.
num_expert_group
=
num_expert_group
self
.
topk_group
=
topk_group
self
.
topk_group
=
topk_group
self
.
custom_routing_function
=
custom_routing_function
self
.
custom_routing_function
=
custom_routing_function
self
.
scoring_func
=
scoring_func
self
.
e_score_correction_bias
=
e_score_correction_bias
if
self
.
scoring_func
!=
"softmax"
and
not
self
.
use_grouped_topk
:
raise
ValueError
(
"Only softmax scoring function is supported for "
"non-grouped topk."
)
if
quant_config
is
None
:
if
quant_config
is
None
:
self
.
quant_method
:
Optional
[
QuantizeMethodBase
]
=
(
self
.
quant_method
:
Optional
[
QuantizeMethodBase
]
=
(
...
@@ -398,7 +435,10 @@ class FusedMoE(torch.nn.Module):
...
@@ -398,7 +435,10 @@ class FusedMoE(torch.nn.Module):
loaded_weight
=
loaded_weight
,
loaded_weight
=
loaded_weight
,
expert_data
=
expert_data
,
expert_data
=
expert_data
,
tp_rank
=
tp_rank
)
tp_rank
=
tp_rank
)
elif
quant_method
==
FusedMoeWeightScaleSupported
.
GROUP
.
value
:
elif
quant_method
in
[
FusedMoeWeightScaleSupported
.
GROUP
.
value
,
FusedMoeWeightScaleSupported
.
BLOCK
.
value
,
]:
self
.
_load_model_weight_or_group_weight_scale
(
self
.
_load_model_weight_or_group_weight_scale
(
shard_id
=
shard_id
,
shard_id
=
shard_id
,
shard_dim
=
shard_dim
,
shard_dim
=
shard_dim
,
...
@@ -441,7 +481,9 @@ class FusedMoE(torch.nn.Module):
...
@@ -441,7 +481,9 @@ class FusedMoE(torch.nn.Module):
renormalize
:
bool
,
renormalize
:
bool
,
topk_group
:
Optional
[
int
]
=
None
,
topk_group
:
Optional
[
int
]
=
None
,
num_expert_group
:
Optional
[
int
]
=
None
,
num_expert_group
:
Optional
[
int
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
):
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
):
from
vllm.model_executor.layers.fused_moe.fused_moe
import
(
from
vllm.model_executor.layers.fused_moe.fused_moe
import
(
fused_topk
,
grouped_topk
)
fused_topk
,
grouped_topk
)
...
@@ -455,7 +497,9 @@ class FusedMoE(torch.nn.Module):
...
@@ -455,7 +497,9 @@ class FusedMoE(torch.nn.Module):
topk
=
top_k
,
topk
=
top_k
,
renormalize
=
renormalize
,
renormalize
=
renormalize
,
num_expert_group
=
num_expert_group
,
num_expert_group
=
num_expert_group
,
topk_group
=
topk_group
)
topk_group
=
topk_group
,
scoring_func
=
scoring_func
,
e_score_correction_bias
=
e_score_correction_bias
)
elif
custom_routing_function
is
None
:
elif
custom_routing_function
is
None
:
topk_weights
,
topk_ids
=
fused_topk
(
hidden_states
=
hidden_states
,
topk_weights
,
topk_ids
=
fused_topk
(
hidden_states
=
hidden_states
,
gating_output
=
router_logits
,
gating_output
=
router_logits
,
...
@@ -484,7 +528,9 @@ class FusedMoE(torch.nn.Module):
...
@@ -484,7 +528,9 @@ class FusedMoE(torch.nn.Module):
use_grouped_topk
=
self
.
use_grouped_topk
,
use_grouped_topk
=
self
.
use_grouped_topk
,
topk_group
=
self
.
topk_group
,
topk_group
=
self
.
topk_group
,
num_expert_group
=
self
.
num_expert_group
,
num_expert_group
=
self
.
num_expert_group
,
custom_routing_function
=
self
.
custom_routing_function
)
custom_routing_function
=
self
.
custom_routing_function
,
scoring_func
=
self
.
scoring_func
,
e_score_correction_bias
=
self
.
e_score_correction_bias
)
if
self
.
reduce_results
and
self
.
tp_size
>
1
:
if
self
.
reduce_results
and
self
.
tp_size
>
1
:
final_hidden_states
=
tensor_model_parallel_all_reduce
(
final_hidden_states
=
tensor_model_parallel_all_reduce
(
...
...
vllm/model_executor/layers/linear.py
View file @
96ae75ad
...
@@ -14,11 +14,14 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
...
@@ -14,11 +14,14 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
,
QuantizeMethodBase
)
QuantizationConfig
,
QuantizeMethodBase
)
# yapf: disable
from
vllm.model_executor.parameter
import
(
BasevLLMParameter
,
from
vllm.model_executor.parameter
import
(
BasevLLMParameter
,
BlockQuantScaleParameter
,
PackedColumnParameter
,
PackedColumnParameter
,
PackedvLLMParameter
,
PackedvLLMParameter
,
PerTensorScaleParameter
,
PerTensorScaleParameter
,
RowvLLMParameter
)
RowvLLMParameter
)
# yapf: enable
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.model_executor.utils
import
set_weight_attrs
import
os
import
os
...
@@ -642,8 +645,24 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
...
@@ -642,8 +645,24 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
assert
loaded_shard_id
<
len
(
self
.
output_sizes
)
assert
loaded_shard_id
<
len
(
self
.
output_sizes
)
tp_size
=
get_tensor_model_parallel_world_size
()
tp_size
=
get_tensor_model_parallel_world_size
()
shard_offset
=
sum
(
self
.
output_sizes
[:
loaded_shard_id
])
//
tp_size
shard_size
=
self
.
output_sizes
[
loaded_shard_id
]
//
tp_size
if
isinstance
(
param
,
BlockQuantScaleParameter
):
from
vllm.model_executor.layers.quantization.fp8
import
(
Fp8LinearMethod
,
Fp8MoEMethod
)
assert
self
.
quant_method
is
not
None
assert
isinstance
(
self
.
quant_method
,
(
Fp8LinearMethod
,
Fp8MoEMethod
))
weight_block_size
=
self
.
quant_method
.
quant_config
.
weight_block_size
assert
weight_block_size
is
not
None
block_n
,
_
=
weight_block_size
[
0
],
weight_block_size
[
1
]
shard_offset
=
(
(
sum
(
self
.
output_sizes
[:
loaded_shard_id
])
+
block_n
-
1
)
//
block_n
)
//
tp_size
shard_size
=
((
self
.
output_sizes
[
loaded_shard_id
]
+
block_n
-
1
)
//
block_n
//
tp_size
)
else
:
shard_offset
=
sum
(
self
.
output_sizes
[:
loaded_shard_id
])
//
tp_size
shard_size
=
self
.
output_sizes
[
loaded_shard_id
]
//
tp_size
param
.
load_merged_column_weight
(
loaded_weight
=
loaded_weight
,
param
.
load_merged_column_weight
(
loaded_weight
=
loaded_weight
,
shard_id
=
loaded_shard_id
,
shard_id
=
loaded_shard_id
,
...
...
vllm/model_executor/layers/quantization/awq_marlin.py
View file @
96ae75ad
...
@@ -440,11 +440,13 @@ class AWQMoEMethod(FusedMoEMethodBase):
...
@@ -440,11 +440,13 @@ class AWQMoEMethod(FusedMoEMethodBase):
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
top_k
:
int
,
top_k
:
int
,
renormalize
:
bool
=
True
,
renormalize
:
bool
,
use_grouped_topk
:
bool
=
False
,
use_grouped_topk
:
bool
=
False
,
num_expert_group
:
Optional
[
int
]
=
None
,
topk_group
:
Optional
[
int
]
=
None
,
topk_group
:
Optional
[
int
]
=
None
,
num_expert_group
:
Optional
[
int
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
hidden_states
=
x
,
hidden_states
=
x
,
...
@@ -454,7 +456,9 @@ class AWQMoEMethod(FusedMoEMethodBase):
...
@@ -454,7 +456,9 @@ class AWQMoEMethod(FusedMoEMethodBase):
renormalize
=
renormalize
,
renormalize
=
renormalize
,
topk_group
=
topk_group
,
topk_group
=
topk_group
,
num_expert_group
=
num_expert_group
,
num_expert_group
=
num_expert_group
,
custom_routing_function
=
custom_routing_function
)
custom_routing_function
=
custom_routing_function
,
scoring_func
=
scoring_func
,
e_score_correction_bias
=
e_score_correction_bias
)
return
torch
.
ops
.
vllm
.
fused_marlin_moe
(
return
torch
.
ops
.
vllm
.
fused_marlin_moe
(
x
,
x
,
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
View file @
96ae75ad
from
typing
import
Any
,
Dict
,
List
,
Optional
,
cast
from
typing
import
Any
,
Dict
,
List
,
Literal
,
Optional
,
cast
import
torch
import
torch
from
compressed_tensors.config
import
CompressionFormat
from
compressed_tensors.config
import
(
CompressionFormat
,
SparsityCompressionConfig
,
SparsityStructure
)
from
compressed_tensors.quantization
import
(
QuantizationArgs
,
from
compressed_tensors.quantization
import
(
QuantizationArgs
,
QuantizationStrategy
,
QuantizationStrategy
,
QuantizationType
)
QuantizationType
)
...
@@ -15,7 +17,7 @@ from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501
...
@@ -15,7 +17,7 @@ from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501
from
vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe
import
(
# noqa: E501
from
vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe
import
(
# noqa: E501
CompressedTensorsMoEMethod
)
CompressedTensorsMoEMethod
)
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
import
(
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
import
(
W4A16SPARSE24_SUPPORTED_BITS
,
WNA16_SUPPORTED_BITS
,
W4A16SPARSE24_SUPPORTED_BITS
,
WNA16_SUPPORTED_BITS
,
CompressedTensors24
,
CompressedTensorsScheme
,
CompressedTensorsW4A16Sparse24
,
CompressedTensorsScheme
,
CompressedTensorsW4A16Sparse24
,
CompressedTensorsW8A8Fp8
,
CompressedTensorsW8A8Int8
,
CompressedTensorsW8A8Fp8
,
CompressedTensorsW8A8Int8
,
CompressedTensorsW8A16Fp8
,
CompressedTensorsWNA16
)
CompressedTensorsW8A16Fp8
,
CompressedTensorsWNA16
)
...
@@ -27,20 +29,29 @@ from vllm.platforms import current_platform
...
@@ -27,20 +29,29 @@ from vllm.platforms import current_platform
__all__
=
[
"CompressedTensorsLinearMethod"
]
__all__
=
[
"CompressedTensorsLinearMethod"
]
SPARSITY_CONFIG_NAME
:
Literal
[
"sparsity_config"
]
=
"sparsity_config"
QUANTIZATION_SCHEME_MAP_TYPE
=
Dict
[
str
,
Optional
[
Dict
[
str
,
QuantizationArgs
]]]
class
CompressedTensorsConfig
(
QuantizationConfig
):
class
CompressedTensorsConfig
(
QuantizationConfig
):
def
__init__
(
self
,
def
__init__
(
target_scheme_map
:
Dict
[
str
,
Any
],
self
,
ignore
:
List
[
str
],
target_scheme_map
:
Dict
[
str
,
Any
],
quant_format
:
str
,
ignore
:
List
[
str
],
kv_cache_scheme
:
Optional
[
Dict
[
str
,
Any
]]
=
None
):
quant_format
:
str
,
sparsity_scheme_map
:
Dict
[
str
,
SparsityCompressionConfig
],
kv_cache_scheme
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
config
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
):
self
.
ignore
=
ignore
self
.
ignore
=
ignore
self
.
quant_format
=
quant_format
self
.
quant_format
=
quant_format
# Map from [target -> scheme]
# Map from [target -> scheme]
self
.
target_scheme_map
=
target_scheme_map
self
.
target_scheme_map
=
target_scheme_map
self
.
kv_cache_scheme
=
kv_cache_scheme
self
.
kv_cache_scheme
=
kv_cache_scheme
self
.
sparsity_scheme_map
=
sparsity_scheme_map
self
.
config
=
config
def
get_linear_method
(
self
)
->
"CompressedTensorsLinearMethod"
:
def
get_linear_method
(
self
)
->
"CompressedTensorsLinearMethod"
:
return
CompressedTensorsLinearMethod
(
self
)
return
CompressedTensorsLinearMethod
(
self
)
...
@@ -78,8 +89,50 @@ class CompressedTensorsConfig(QuantizationConfig):
...
@@ -78,8 +89,50 @@ class CompressedTensorsConfig(QuantizationConfig):
@
classmethod
@
classmethod
def
from_config
(
cls
,
config
:
Dict
[
str
,
Any
])
->
"CompressedTensorsConfig"
:
def
from_config
(
cls
,
config
:
Dict
[
str
,
Any
])
->
"CompressedTensorsConfig"
:
ignore
:
List
[
str
]
=
cast
(
List
[
str
],
config
.
get
(
"ignore"
,
[]))
quant_format
=
cast
(
str
,
config
.
get
(
"format"
))
target_scheme_map
=
cls
.
_quantization_scheme_map_from_config
(
config
=
config
)
sparsity_scheme_map
=
cls
.
_sparsity_scheme_map_from_config
(
config
=
config
)
return
cls
(
target_scheme_map
=
target_scheme_map
,
ignore
=
ignore
,
quant_format
=
quant_format
,
sparsity_scheme_map
=
sparsity_scheme_map
,
config
=
config
,
)
@
classmethod
def
_sparsity_scheme_map_from_config
(
cls
,
config
:
Dict
[
str
,
Any
])
->
Dict
[
str
,
SparsityCompressionConfig
]:
"""
:param config: The `quantization_config` dictionary from config.json
:return: A dictionary mapping target layer names to their corresponding
sparsity compression configurations
"""
if
(
sparsity_config
:
=
config
.
get
(
SPARSITY_CONFIG_NAME
))
is
None
:
return
dict
()
sparsity_config
=
SparsityCompressionConfig
.
model_validate
(
sparsity_config
)
sparse_scheme_map
:
Dict
[
str
,
SparsityCompressionConfig
]
=
{
target
:
sparsity_config
for
target
in
sparsity_config
.
targets
or
list
()
}
return
sparse_scheme_map
@
classmethod
def
_quantization_scheme_map_from_config
(
cls
,
config
:
Dict
[
str
,
Any
])
->
QUANTIZATION_SCHEME_MAP_TYPE
:
"""
:param config: The `quantization_config` dictionary from config.json
:return: A dictionary mapping target layer names to their corresponding
quantization_args for weights and input activations
"""
target_scheme_map
:
Dict
[
str
,
Any
]
=
dict
()
target_scheme_map
:
Dict
[
str
,
Any
]
=
dict
()
ignore
=
cast
(
List
[
str
],
config
.
get
(
"ignore"
))
quant_format
=
cast
(
str
,
config
.
get
(
"format"
))
quant_format
=
cast
(
str
,
config
.
get
(
"format"
))
# The quant_config has multiple config_groups, each containing
# The quant_config has multiple config_groups, each containing
...
@@ -90,12 +143,14 @@ class CompressedTensorsConfig(QuantizationConfig):
...
@@ -90,12 +143,14 @@ class CompressedTensorsConfig(QuantizationConfig):
# details follow the structure defined by the QuantizationArgs
# details follow the structure defined by the QuantizationArgs
# pydantic model, which is used to verify the structure of the
# pydantic model, which is used to verify the structure of the
# quant_config and also store the details for later use.
# quant_config and also store the details for later use.
for
_
,
quant_config
in
config
[
"config_groups"
].
items
():
config_groups
=
config
.
get
(
"config_groups"
,
dict
())
for
_
,
quant_config
in
config_groups
.
items
():
targets
=
quant_config
.
get
(
"targets"
)
targets
=
quant_config
.
get
(
"targets"
)
for
target
in
targets
:
for
target
in
targets
:
target_scheme_map
[
target
]
=
{}
target_scheme_map
[
target
]
=
{}
target_scheme_map
[
target
][
target_scheme_map
[
target
][
"weights"
]
=
QuantizationArgs
.
parse_obj
(
"weights"
]
=
QuantizationArgs
.
model_validate
(
quant_config
.
get
(
"weights"
))
quant_config
.
get
(
"weights"
))
target_scheme_map
[
target
][
"input_activations"
]
=
None
target_scheme_map
[
target
][
"input_activations"
]
=
None
...
@@ -110,13 +165,9 @@ class CompressedTensorsConfig(QuantizationConfig):
...
@@ -110,13 +165,9 @@ class CompressedTensorsConfig(QuantizationConfig):
"weights"
].
type
==
QuantizationType
.
FLOAT
"weights"
].
type
==
QuantizationType
.
FLOAT
else
:
else
:
target_scheme_map
[
target
][
target_scheme_map
[
target
][
"input_activations"
]
=
QuantizationArgs
.
parse_obj
(
"input_activations"
]
=
QuantizationArgs
.
model_validate
(
# noqa: E501
quant_config
.
get
(
"input_activations"
))
quant_config
.
get
(
"input_activations"
))
return
target_scheme_map
return
cls
(
target_scheme_map
=
target_scheme_map
,
ignore
=
ignore
,
quant_format
=
quant_format
,
kv_cache_scheme
=
config
.
get
(
"kv_cache_scheme"
))
@
classmethod
@
classmethod
def
get_config_filenames
(
cls
)
->
List
[
str
]:
def
get_config_filenames
(
cls
)
->
List
[
str
]:
...
@@ -315,23 +366,105 @@ class CompressedTensorsConfig(QuantizationConfig):
...
@@ -315,23 +366,105 @@ class CompressedTensorsConfig(QuantizationConfig):
# TODO (@robertgshaw): add compressed-tensors as dep
# TODO (@robertgshaw): add compressed-tensors as dep
# so we do not have to re-write these functions
# so we do not have to re-write these functions
# need to make accelerate optional in ct to do this
# need to make accelerate optional in ct to do this
matched_target
=
find_matched_target
(
layer_name
=
layer_name
,
module
=
layer
,
targets
=
self
.
target_scheme_map
.
keys
())
# Find the quant_scheme
# Will be empty for models with only sparsity
scheme_dict
=
self
.
target_scheme_map
[
matched_target
]
if
self
.
target_scheme_map
:
scheme
=
self
.
_get_scheme_from_parts
(
matched_target
=
find_matched_target
(
weight_quant
=
scheme_dict
[
"weights"
],
layer_name
=
layer_name
,
input_quant
=
scheme_dict
[
"input_activations"
])
module
=
layer
,
targets
=
self
.
target_scheme_map
.
keys
())
scheme_dict
=
self
.
target_scheme_map
[
matched_target
]
weight_quant
=
scheme_dict
.
get
(
"weights"
)
input_quant
=
scheme_dict
.
get
(
"input_activations"
)
elif
self
.
sparsity_scheme_map
:
matched_target
=
find_matched_target
(
layer_name
=
layer_name
,
module
=
layer
,
targets
=
self
.
sparsity_scheme_map
.
keys
())
weight_quant
=
None
input_quant
=
None
# For models with sparsity, assumes that the sparse layers are also
# quantized for cutlass 2:4 support
sparsity_scheme
:
Optional
[
SparsityCompressionConfig
]
=
self
.
sparsity_scheme_map
.
get
(
matched_target
)
if
self
.
supports_cutlass_24
(
weight_quant
=
weight_quant
,
input_quant
=
input_quant
,
sparsity_scheme
=
sparsity_scheme
):
# Have a valid sparsity scheme
# Validate layer is supported by Cutlass 2:4 Kernel
scheme
=
CompressedTensors24
(
quantized
=
weight_quant
is
not
None
or
input_quant
is
not
None
,
weight_quant
=
weight_quant
,
input_quant
=
input_quant
)
else
:
# Find the quant_scheme
scheme
=
self
.
_get_scheme_from_parts
(
# type: ignore
weight_quant
=
weight_quant
,
input_quant
=
input_quant
,
)
# Raise error if device does not support the scheme
# Raise error if device does not support the scheme
# (e.g. fp8 needs ada lovelace)
# (e.g. fp8 needs ada lovelace)
self
.
_check_scheme_supported
(
scheme
.
get_min_capability
())
self
.
_check_scheme_supported
(
scheme
.
get_min_capability
())
return
scheme
return
scheme
@
staticmethod
def
supports_cutlass_24
(
weight_quant
:
Optional
[
QuantizationArgs
],
input_quant
:
Optional
[
QuantizationArgs
],
sparsity_scheme
:
Optional
[
SparsityCompressionConfig
]
=
None
)
->
bool
:
"""
Check if the layer is supported by the Cutlass 2:4 Kernel
Conditions:
- Overarching condition: Sparsity Structure is 2:4
- Unquantized cases are supported
- Weight only quantization is not-supported
- Supported weight quantization strategies are TENSOR and CHANNEL
- Supported input quantization strategies are TENSOR and TOKEN
- Only 8 bit quantization is supported
:return: True if the layer is supported by the Cutlass 2:4 Kernel
False otherwise
"""
is_valid_sparsity
=
(
sparsity_scheme
is
not
None
and
sparsity_scheme
.
sparsity_structure
==
SparsityStructure
.
TWO_FOUR
.
value
and
sparsity_scheme
.
format
==
"dense"
)
if
not
is_valid_sparsity
:
return
False
# Unquantized cases are supported
if
weight_quant
is
None
and
input_quant
is
None
:
return
True
# Weight only quantization is not-supported
if
weight_quant
is
not
None
and
input_quant
is
None
:
return
False
supported_weight_quant_strategies
=
[
QuantizationStrategy
.
TENSOR
.
value
,
QuantizationStrategy
.
CHANNEL
.
value
]
assert
weight_quant
is
not
None
assert
input_quant
is
not
None
if
weight_quant
.
strategy
not
in
supported_weight_quant_strategies
:
return
False
supported_input_quant_strategies
=
[
QuantizationStrategy
.
TENSOR
.
value
,
QuantizationStrategy
.
TOKEN
.
value
]
if
input_quant
.
strategy
not
in
supported_input_quant_strategies
:
return
False
return
weight_quant
.
num_bits
==
input_quant
.
num_bits
==
8
class
CompressedTensorsLinearMethod
(
LinearMethodBase
):
class
CompressedTensorsLinearMethod
(
LinearMethodBase
):
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
View file @
96ae75ad
...
@@ -203,13 +203,14 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
...
@@ -203,13 +203,14 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
top_k
:
int
,
top_k
:
int
,
renormalize
:
bool
=
True
,
renormalize
:
bool
,
use_grouped_topk
:
bool
=
False
,
use_grouped_topk
:
bool
=
False
,
num_expert_group
:
Optional
[
int
]
=
None
,
topk_group
:
Optional
[
int
]
=
None
,
topk_group
:
Optional
[
int
]
=
None
,
num_expert_group
:
Optional
[
int
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
from
vllm.model_executor.layers.fused_moe
import
fused_experts
from
vllm.model_executor.layers.fused_moe
import
fused_experts
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
...
@@ -220,7 +221,9 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
...
@@ -220,7 +221,9 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
renormalize
=
renormalize
,
renormalize
=
renormalize
,
topk_group
=
topk_group
,
topk_group
=
topk_group
,
num_expert_group
=
num_expert_group
,
num_expert_group
=
num_expert_group
,
custom_routing_function
=
custom_routing_function
)
custom_routing_function
=
custom_routing_function
,
scoring_func
=
scoring_func
,
e_score_correction_bias
=
e_score_correction_bias
)
return
fused_experts
(
x
,
return
fused_experts
(
x
,
layer
.
w13_weight
,
layer
.
w13_weight
,
...
@@ -476,12 +479,15 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
...
@@ -476,12 +479,15 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
top_k
:
int
,
top_k
:
int
,
renormalize
:
bool
=
True
,
renormalize
:
bool
,
use_grouped_topk
:
bool
=
False
,
use_grouped_topk
:
bool
=
False
,
num_expert_group
:
Optional
[
int
]
=
None
,
topk_group
:
Optional
[
int
]
=
None
,
topk_group
:
Optional
[
int
]
=
None
,
num_expert_group
:
Optional
[
int
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
hidden_states
=
x
,
hidden_states
=
x
,
router_logits
=
router_logits
,
router_logits
=
router_logits
,
...
@@ -490,7 +496,9 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
...
@@ -490,7 +496,9 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
renormalize
=
renormalize
,
renormalize
=
renormalize
,
topk_group
=
topk_group
,
topk_group
=
topk_group
,
num_expert_group
=
num_expert_group
,
num_expert_group
=
num_expert_group
,
custom_routing_function
=
custom_routing_function
)
custom_routing_function
=
custom_routing_function
,
scoring_func
=
scoring_func
,
e_score_correction_bias
=
e_score_correction_bias
)
return
torch
.
ops
.
vllm
.
fused_marlin_moe
(
return
torch
.
ops
.
vllm
.
fused_marlin_moe
(
x
,
x
,
...
...
vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py
View file @
96ae75ad
...
@@ -7,13 +7,12 @@ from .compressed_tensors_w8a16_fp8 import CompressedTensorsW8A16Fp8
...
@@ -7,13 +7,12 @@ from .compressed_tensors_w8a16_fp8 import CompressedTensorsW8A16Fp8
from
.compressed_tensors_wNa16
import
(
WNA16_SUPPORTED_BITS
,
from
.compressed_tensors_wNa16
import
(
WNA16_SUPPORTED_BITS
,
CompressedTensorsWNA16
)
CompressedTensorsWNA16
)
from
.compressed_tensors_24
import
CompressedTensors24
# isort: skip
__all__
=
[
__all__
=
[
"CompressedTensorsScheme"
,
"CompressedTensorsScheme"
,
"CompressedTensorsWNA16"
,
"CompressedTensorsWNA16"
,
"CompressedTensorsW8A16Fp8"
,
"CompressedTensorsW4A16Sparse24"
,
"CompressedTensorsW8A16Fp8"
,
"CompressedTensorsW8A8Int8"
,
"CompressedTensorsW8A8Fp8"
,
"CompressedTensorsW4A16Sparse24"
,
"WNA16_SUPPORTED_BITS"
,
"W4A16SPARSE24_SUPPORTED_BITS"
,
"CompressedTensorsW8A8Int8"
,
"CompressedTensors24"
"CompressedTensorsW8A8Fp8"
,
"WNA16_SUPPORTED_BITS"
,
"W4A16SPARSE24_SUPPORTED_BITS"
,
]
]
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py
0 → 100644
View file @
96ae75ad
from
typing
import
Callable
,
List
,
Optional
import
torch
from
compressed_tensors.quantization
import
(
QuantizationArgs
,
QuantizationStrategy
,
QuantizationType
)
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
import
(
CompressedTensorsScheme
)
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
convert_to_channelwise
,
sparse_cutlass_supported
)
from
vllm.model_executor.parameter
import
(
BasevLLMParameter
,
ChannelQuantScaleParameter
,
ModelWeightParameter
,
PerTensorScaleParameter
)
__all__
=
[
"CompressedTensors24"
]
class
CompressedTensors24
(
CompressedTensorsScheme
):
def
__init__
(
self
,
quantized
:
bool
=
False
,
weight_quant
:
Optional
[
QuantizationArgs
]
=
None
,
input_quant
:
Optional
[
QuantizationArgs
]
=
None
):
self
.
quantized
=
quantized
self
.
weight_quant
=
weight_quant
self
.
input_quant
=
input_quant
@
classmethod
def
get_min_capability
(
cls
)
->
int
:
# Only cutlass 3.x kernels are implemented so far
return
90
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
input_size
:
int
,
output_partition_sizes
:
List
[
int
],
input_size_per_partition
:
int
,
params_dtype
:
torch
.
dtype
,
weight_loader
:
Callable
,
**
kwargs
):
if
not
sparse_cutlass_supported
():
raise
ValueError
(
"Sparse CUTLASS not supported. vLLM must be built with"
"CUDA 12.2 or later to use this feature"
)
self
.
output_dtype
=
params_dtype
layer
.
logical_widths
=
output_partition_sizes
self
.
weights_dtype
:
torch
.
dtype
=
self
.
_get_params_dtype
(
params_dtype
)
# parameter to store uncompressed weight
weight
=
ModelWeightParameter
(
data
=
torch
.
empty
(
sum
(
output_partition_sizes
),
input_size_per_partition
,
dtype
=
self
.
weights_dtype
),
input_dim
=
1
,
output_dim
=
0
,
weight_loader
=
weight_loader
)
# Check if quantized, not just 2:4 Sparse
if
self
.
quantized
:
if
(
self
.
weight_quant
and
self
.
weight_quant
.
strategy
==
QuantizationStrategy
.
CHANNEL
.
value
):
weight_scale
=
ChannelQuantScaleParameter
(
data
=
torch
.
empty
((
sum
(
output_partition_sizes
),
1
),
dtype
=
torch
.
float32
),
output_dim
=
0
,
weight_loader
=
weight_loader
)
else
:
assert
(
self
.
weight_quant
and
self
.
weight_quant
.
strategy
==
QuantizationStrategy
.
TENSOR
.
value
)
weight_scale
=
PerTensorScaleParameter
(
data
=
torch
.
empty
(
len
(
output_partition_sizes
),
dtype
=
torch
.
float32
),
weight_loader
=
weight_loader
)
layer
.
register_parameter
(
"weight_scale"
,
weight_scale
)
# input quant will be non-none
if
self
.
input_quant
and
not
self
.
input_quant
.
dynamic
:
# register input quant scale
assert
(
self
.
input_quant
.
strategy
==
QuantizationStrategy
.
TENSOR
.
value
)
input_scale
=
BasevLLMParameter
(
data
=
torch
.
empty
(
1
,
dtype
=
torch
.
float32
),
weight_loader
=
weight_loader
)
layer
.
register_parameter
(
"input_scale"
,
input_scale
)
else
:
# for sparse-only, pass in 1 for weight/input scales
weight_scale
=
torch
.
nn
.
Parameter
(
data
=
torch
.
ones
(
1
,
dtype
=
torch
.
float32
),
requires_grad
=
False
)
input_scale
=
torch
.
nn
.
Parameter
(
data
=
torch
.
ones
(
1
,
dtype
=
torch
.
float32
),
requires_grad
=
False
)
layer
.
register_parameter
(
"input_scale"
,
input_scale
)
layer
.
register_parameter
(
"weight_scale"
,
weight_scale
)
layer
.
register_parameter
(
"weight"
,
weight
)
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
"""
Compress weights after loading. Store compressed weight and meta
tensor
:post-condition: layer.w_compressed and layer.meta are
set to the compressed weight and meta tensor in the
format expected by the Cutlass kernels
:param layer: The layer with the weights to be processed
"""
# torch.compile workaround
if
hasattr
(
layer
,
"input_scale"
):
layer
.
input_scale
=
torch
.
nn
.
Parameter
(
layer
.
input_scale
.
data
,
requires_grad
=
False
)
if
self
.
weight_quant
:
if
self
.
weight_quant
.
strategy
==
QuantizationStrategy
.
TENSOR
.
value
:
layer
.
weight_scale
=
torch
.
nn
.
Parameter
(
convert_to_channelwise
(
weight_scale
=
layer
.
weight_scale
,
logical_widths
=
layer
.
logical_widths
),
requires_grad
=
False
)
else
:
# torch.compile workaround
layer
.
weight_scale
=
torch
.
nn
.
Parameter
(
layer
.
weight_scale
.
data
,
requires_grad
=
False
)
w_compressed
,
meta
=
ops
.
cutlass_sparse_compress
(
layer
.
weight
.
data
)
layer
.
weight
=
torch
.
nn
.
Parameter
(
w_compressed
,
requires_grad
=
False
)
layer
.
meta
=
torch
.
nn
.
Parameter
(
meta
,
requires_grad
=
False
)
def
apply_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
"""
Returns the output tensor for the layer with 2:4
sparse compressed weights, given the input tensor
and bias
:param layer: The layer with 2:4 sparse compressed
weights to be used for the computation
:param x: The input tensor to the layer
:param bias: The bias to be added to the output tensor
:return: The output tensor of the layer
"""
if
self
.
quantized
:
scale
=
None
if
hasattr
(
layer
,
"input_scale"
):
scale
=
layer
.
input_scale
if
self
.
weights_dtype
==
torch
.
int8
:
ops_output
=
ops
.
scaled_int8_quant
(
x
,
scale
=
scale
)
q_input
=
ops_output
[
0
]
input_scale
=
ops_output
[
1
]
else
:
assert
self
.
weights_dtype
==
torch
.
float8_e4m3fn
if
scale
is
not
None
:
q_input
,
input_scale
=
ops
.
scaled_fp8_quant
(
x
,
scale
=
scale
)
else
:
q_input
,
input_scale
=
ops
.
scaled_fp8_quant
(
x
,
use_per_token_if_dynamic
=
True
)
else
:
# Not quantized, nothing to do with the input_scales, use as is
input_scale
=
layer
.
input_scale
q_input
=
x
out
=
ops
.
cutlass_scaled_sparse_mm
(
a
=
q_input
,
bt_nzs
=
layer
.
weight
,
bt_meta
=
layer
.
meta
,
scale_a
=
input_scale
,
scale_b
=
layer
.
weight_scale
,
out_dtype
=
self
.
output_dtype
,
bias
=
bias
)
assert
out
.
is_contiguous
()
return
out
def
_get_params_dtype
(
self
,
params_dtype
:
torch
.
dtype
)
->
torch
.
dtype
:
if
not
self
.
quantized
:
return
params_dtype
assert
self
.
weight_quant
is
not
None
assert
self
.
input_quant
is
not
None
is_8_bits
=
self
.
weight_quant
.
num_bits
==
self
.
input_quant
.
num_bits
==
8
if
not
is_8_bits
:
raise
ValueError
(
"Cutlass only supports 8-bit quantization"
)
if
(
self
.
weight_quant
.
type
==
QuantizationType
.
FLOAT
and
self
.
input_quant
.
type
==
QuantizationType
.
FLOAT
):
return
torch
.
float8_e4m3fn
if
(
self
.
weight_quant
.
type
==
QuantizationType
.
INT
and
self
.
input_quant
.
type
==
QuantizationType
.
INT
):
return
torch
.
int8
raise
ValueError
(
"Quantization type not supported by Cutlass"
)
def
check_24
(
tensor
):
new_tensor
=
tensor
.
view
(
-
1
,
4
)
zero_counts
=
(
new_tensor
==
0
).
sum
(
dim
=
1
)
return
(
zero_counts
>=
2
).
all
().
item
()
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py
View file @
96ae75ad
...
@@ -61,6 +61,10 @@ class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme):
...
@@ -61,6 +61,10 @@ class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme):
params_dtype
:
torch
.
dtype
,
weight_loader
:
Callable
,
params_dtype
:
torch
.
dtype
,
weight_loader
:
Callable
,
**
kwargs
):
**
kwargs
):
assert
params_dtype
==
torch
.
float16
,
(
"float16 is required for marlin24 compressd models. Set dtype=torch.float16"
# noqa: E501
)
pack_factor
=
32
//
self
.
quant_type
.
size_bits
pack_factor
=
32
//
self
.
quant_type
.
size_bits
output_size_per_partition
=
sum
(
output_partition_sizes
)
output_size_per_partition
=
sum
(
output_partition_sizes
)
...
...
vllm/model_executor/layers/quantization/compressed_tensors/utils.py
View file @
96ae75ad
...
@@ -30,7 +30,7 @@ def should_ignore_layer(layer_name: Optional[str],
...
@@ -30,7 +30,7 @@ def should_ignore_layer(layer_name: Optional[str],
# in the safetensors checkpoint. So, we convert the name
# in the safetensors checkpoint. So, we convert the name
# from the fused version to unfused + check to make sure that
# from the fused version to unfused + check to make sure that
# each shard of the fused layer has the same scheme.
# each shard of the fused layer has the same scheme.
if
proj_name
in
FUSED_LAYER_NAME_MAPPING
:
if
proj_name
in
FUSED_LAYER_NAME_MAPPING
and
layer_name
not
in
ignore
:
shard_proj_names
=
FUSED_LAYER_NAME_MAPPING
[
proj_name
]
shard_proj_names
=
FUSED_LAYER_NAME_MAPPING
[
proj_name
]
# Convert fused_name --> [shard_names]
# Convert fused_name --> [shard_names]
...
...
vllm/model_executor/layers/quantization/experts_int8.py
View file @
96ae75ad
...
@@ -99,11 +99,13 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
...
@@ -99,11 +99,13 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
top_k
:
int
,
top_k
:
int
,
renormalize
:
bool
=
True
,
renormalize
:
bool
,
use_grouped_topk
:
bool
=
False
,
use_grouped_topk
:
bool
=
False
,
num_expert_group
:
Optional
[
int
]
=
None
,
topk_group
:
Optional
[
int
]
=
None
,
topk_group
:
Optional
[
int
]
=
None
,
num_expert_group
:
Optional
[
int
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
from
vllm.model_executor.layers.fused_moe
import
fused_experts
from
vllm.model_executor.layers.fused_moe
import
fused_experts
...
@@ -115,7 +117,9 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
...
@@ -115,7 +117,9 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
renormalize
=
renormalize
,
renormalize
=
renormalize
,
topk_group
=
topk_group
,
topk_group
=
topk_group
,
num_expert_group
=
num_expert_group
,
num_expert_group
=
num_expert_group
,
custom_routing_function
=
custom_routing_function
)
custom_routing_function
=
custom_routing_function
,
scoring_func
=
scoring_func
,
e_score_correction_bias
=
e_score_correction_bias
)
return
fused_experts
(
x
,
return
fused_experts
(
x
,
layer
.
w13_weight
,
layer
.
w13_weight
,
...
...
vllm/model_executor/layers/quantization/fp8.py
View file @
96ae75ad
...
@@ -6,6 +6,7 @@ from torch.nn.parameter import Parameter
...
@@ -6,6 +6,7 @@ from torch.nn.parameter import Parameter
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.fused_moe
import
(
FusedMoE
,
FusedMoEMethodBase
,
from
vllm.model_executor.layers.fused_moe
import
(
FusedMoE
,
FusedMoEMethodBase
,
FusedMoeWeightScaleSupported
)
FusedMoeWeightScaleSupported
)
...
@@ -14,6 +15,8 @@ from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
...
@@ -14,6 +15,8 @@ from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
,
QuantizeMethodBase
)
QuantizationConfig
,
QuantizeMethodBase
)
from
vllm.model_executor.layers.quantization.kv_cache
import
BaseKVCacheMethod
from
vllm.model_executor.layers.quantization.kv_cache
import
BaseKVCacheMethod
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
apply_w8a8_block_fp8_linear
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils_fp8
import
(
from
vllm.model_executor.layers.quantization.utils.marlin_utils_fp8
import
(
apply_fp8_marlin_linear
,
prepare_fp8_layer_for_marlin
)
apply_fp8_marlin_linear
,
prepare_fp8_layer_for_marlin
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
...
@@ -22,7 +25,8 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
...
@@ -22,7 +25,8 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
all_close_1d
,
apply_fp8_linear
,
convert_to_channelwise
,
all_close_1d
,
apply_fp8_linear
,
convert_to_channelwise
,
cutlass_fp8_supported
,
normalize_e4m3fn_to_e4m3fnuz
,
per_tensor_dequantize
,
cutlass_fp8_supported
,
normalize_e4m3fn_to_e4m3fnuz
,
per_tensor_dequantize
,
requantize_with_max_scale
)
requantize_with_max_scale
)
from
vllm.model_executor.parameter
import
(
ModelWeightParameter
,
from
vllm.model_executor.parameter
import
(
BlockQuantScaleParameter
,
ModelWeightParameter
,
PerTensorScaleParameter
)
PerTensorScaleParameter
)
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
...
@@ -41,6 +45,7 @@ class Fp8Config(QuantizationConfig):
...
@@ -41,6 +45,7 @@ class Fp8Config(QuantizationConfig):
is_checkpoint_fp8_serialized
:
bool
=
False
,
is_checkpoint_fp8_serialized
:
bool
=
False
,
activation_scheme
:
str
=
"dynamic"
,
activation_scheme
:
str
=
"dynamic"
,
ignored_layers
:
Optional
[
List
[
str
]]
=
None
,
ignored_layers
:
Optional
[
List
[
str
]]
=
None
,
weight_block_size
:
Optional
[
List
[
int
]]
=
None
,
)
->
None
:
)
->
None
:
self
.
is_checkpoint_fp8_serialized
=
is_checkpoint_fp8_serialized
self
.
is_checkpoint_fp8_serialized
=
is_checkpoint_fp8_serialized
if
is_checkpoint_fp8_serialized
:
if
is_checkpoint_fp8_serialized
:
...
@@ -51,6 +56,20 @@ class Fp8Config(QuantizationConfig):
...
@@ -51,6 +56,20 @@ class Fp8Config(QuantizationConfig):
f
"Unsupported activation scheme
{
activation_scheme
}
"
)
f
"Unsupported activation scheme
{
activation_scheme
}
"
)
self
.
activation_scheme
=
activation_scheme
self
.
activation_scheme
=
activation_scheme
self
.
ignored_layers
=
ignored_layers
or
[]
self
.
ignored_layers
=
ignored_layers
or
[]
if
weight_block_size
is
not
None
:
if
not
is_checkpoint_fp8_serialized
:
raise
ValueError
(
"The block-wise quantization only supports fp8-serialized "
"checkpoint for now."
)
if
len
(
weight_block_size
)
!=
2
:
raise
ValueError
(
"The quantization block size of weight must have 2 "
f
"dimensions, but got
{
len
(
weight_block_size
)
}
dimensions"
)
if
activation_scheme
!=
"dynamic"
:
raise
ValueError
(
"The block-wise quantization only supports "
"dynamic activation scheme for now, but got "
f
"
{
activation_scheme
}
activation scheme."
)
self
.
weight_block_size
=
weight_block_size
@
classmethod
@
classmethod
def
get_name
(
cls
)
->
str
:
def
get_name
(
cls
)
->
str
:
...
@@ -74,9 +93,12 @@ class Fp8Config(QuantizationConfig):
...
@@ -74,9 +93,12 @@ class Fp8Config(QuantizationConfig):
is_checkpoint_fp8_serialized
=
(
"fp8"
in
quant_method
)
is_checkpoint_fp8_serialized
=
(
"fp8"
in
quant_method
)
activation_scheme
=
cls
.
get_from_keys
(
config
,
[
"activation_scheme"
])
activation_scheme
=
cls
.
get_from_keys
(
config
,
[
"activation_scheme"
])
ignored_layers
=
cls
.
get_from_keys_or
(
config
,
[
"ignored_layers"
],
None
)
ignored_layers
=
cls
.
get_from_keys_or
(
config
,
[
"ignored_layers"
],
None
)
weight_block_size
=
cls
.
get_from_keys_or
(
config
,
[
"weight_block_size"
],
None
)
return
cls
(
is_checkpoint_fp8_serialized
=
is_checkpoint_fp8_serialized
,
return
cls
(
is_checkpoint_fp8_serialized
=
is_checkpoint_fp8_serialized
,
activation_scheme
=
activation_scheme
,
activation_scheme
=
activation_scheme
,
ignored_layers
=
ignored_layers
)
ignored_layers
=
ignored_layers
,
weight_block_size
=
weight_block_size
)
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
,
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
,
prefix
:
str
)
->
Optional
[
"QuantizeMethodBase"
]:
prefix
:
str
)
->
Optional
[
"QuantizeMethodBase"
]:
...
@@ -123,6 +145,11 @@ class Fp8LinearMethod(LinearMethodBase):
...
@@ -123,6 +145,11 @@ class Fp8LinearMethod(LinearMethodBase):
if
current_platform
.
is_rocm
():
if
current_platform
.
is_rocm
():
self
.
use_marlin
=
False
self
.
use_marlin
=
False
self
.
block_quant
=
self
.
quant_config
.
weight_block_size
is
not
None
if
self
.
block_quant
:
# Marlin doesn't support block-wise fp8
self
.
use_marlin
=
False
def
create_weights
(
def
create_weights
(
self
,
self
,
layer
:
torch
.
nn
.
Module
,
layer
:
torch
.
nn
.
Module
,
...
@@ -133,10 +160,34 @@ class Fp8LinearMethod(LinearMethodBase):
...
@@ -133,10 +160,34 @@ class Fp8LinearMethod(LinearMethodBase):
params_dtype
:
torch
.
dtype
,
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
,
**
extra_weight_attrs
,
):
):
del
input_size
,
output_size
output_size_per_partition
=
sum
(
output_partition_sizes
)
output_size_per_partition
=
sum
(
output_partition_sizes
)
weight_loader
=
extra_weight_attrs
.
get
(
"weight_loader"
)
weight_loader
=
extra_weight_attrs
.
get
(
"weight_loader"
)
if
self
.
block_quant
:
tp_size
=
get_tensor_model_parallel_world_size
()
assert
self
.
quant_config
.
weight_block_size
is
not
None
block_n
,
block_k
=
(
self
.
quant_config
.
weight_block_size
[
0
],
self
.
quant_config
.
weight_block_size
[
1
],
)
# Required by row parallel
if
(
tp_size
>
1
and
input_size
//
input_size_per_partition
==
tp_size
and
input_size_per_partition
%
block_k
!=
0
):
raise
ValueError
(
f
"Weight input_size_per_partition = "
f
"
{
input_size_per_partition
}
is not divisible by "
f
"weight quantization block_k =
{
block_k
}
."
)
# Required by column parallel or enabling merged weights
if
(
tp_size
>
1
and
output_size
//
output_size_per_partition
==
tp_size
)
or
len
(
output_partition_sizes
)
>
1
:
for
output_partition_size
in
output_partition_sizes
:
if
output_partition_size
%
block_n
!=
0
:
raise
ValueError
(
f
"Weight output_partition_size = "
f
"
{
output_partition_size
}
is not divisible by "
f
"weight quantization block_n =
{
block_n
}
."
)
layer
.
logical_widths
=
output_partition_sizes
layer
.
logical_widths
=
output_partition_sizes
layer
.
input_size_per_partition
=
input_size_per_partition
layer
.
input_size_per_partition
=
input_size_per_partition
...
@@ -161,12 +212,29 @@ class Fp8LinearMethod(LinearMethodBase):
...
@@ -161,12 +212,29 @@ class Fp8LinearMethod(LinearMethodBase):
# Otherwise, wait until process_weights_after_loading.
# Otherwise, wait until process_weights_after_loading.
if
self
.
quant_config
.
is_checkpoint_fp8_serialized
:
if
self
.
quant_config
.
is_checkpoint_fp8_serialized
:
# WEIGHT SCALE
# WEIGHT SCALE
scale
=
PerTensorScaleParameter
(
data
=
torch
.
empty
(
if
not
self
.
block_quant
:
len
(
output_partition_sizes
),
dtype
=
torch
.
float32
),
scale
=
PerTensorScaleParameter
(
weight_loader
=
weight_loader
)
data
=
torch
.
empty
(
len
(
output_partition_sizes
),
dtype
=
torch
.
float32
),
scale
[:]
=
torch
.
finfo
(
torch
.
float32
).
min
weight_loader
=
weight_loader
,
layer
.
register_parameter
(
"weight_scale"
,
scale
)
)
scale
[:]
=
torch
.
finfo
(
torch
.
float32
).
min
layer
.
register_parameter
(
"weight_scale"
,
scale
)
else
:
assert
self
.
quant_config
.
activation_scheme
==
"dynamic"
scale
=
BlockQuantScaleParameter
(
data
=
torch
.
empty
(
(
output_size_per_partition
+
block_n
-
1
)
//
block_n
,
(
input_size_per_partition
+
block_k
-
1
)
//
block_k
,
dtype
=
torch
.
float32
,
),
input_dim
=
1
,
output_dim
=
0
,
weight_loader
=
weight_loader
,
)
scale
[:]
=
torch
.
finfo
(
torch
.
float32
).
min
# The weight_scale_inv name is intentional for deepseekv3
layer
.
register_parameter
(
"weight_scale_inv"
,
scale
)
# INPUT ACTIVATION SCALE
# INPUT ACTIVATION SCALE
if
self
.
quant_config
.
activation_scheme
==
"static"
:
if
self
.
quant_config
.
activation_scheme
==
"static"
:
...
@@ -180,6 +248,9 @@ class Fp8LinearMethod(LinearMethodBase):
...
@@ -180,6 +248,9 @@ class Fp8LinearMethod(LinearMethodBase):
layer
.
register_parameter
(
"input_scale"
,
None
)
layer
.
register_parameter
(
"input_scale"
,
None
)
def
process_weights_after_loading
(
self
,
layer
:
Module
)
->
None
:
def
process_weights_after_loading
(
self
,
layer
:
Module
)
->
None
:
# Block quant doesn't need to process weights after loading
if
self
.
block_quant
:
return
layer
.
weight
=
torch
.
nn
.
Parameter
(
layer
.
weight
.
data
,
layer
.
weight
=
torch
.
nn
.
Parameter
(
layer
.
weight
.
data
,
requires_grad
=
False
)
requires_grad
=
False
)
# If checkpoint not serialized fp8, quantize the weights.
# If checkpoint not serialized fp8, quantize the weights.
...
@@ -266,6 +337,17 @@ class Fp8LinearMethod(LinearMethodBase):
...
@@ -266,6 +337,17 @@ class Fp8LinearMethod(LinearMethodBase):
size_k
=
layer
.
input_size_per_partition
,
size_k
=
layer
.
input_size_per_partition
,
bias
=
bias
)
bias
=
bias
)
if
self
.
block_quant
:
assert
self
.
quant_config
.
weight_block_size
is
not
None
return
apply_w8a8_block_fp8_linear
(
input
=
x
,
weight
=
layer
.
weight
,
block_size
=
self
.
quant_config
.
weight_block_size
,
weight_scale
=
layer
.
weight_scale_inv
,
input_scale
=
layer
.
input_scale
,
bias
=
bias
,
)
return
apply_fp8_linear
(
return
apply_fp8_linear
(
input
=
x
,
input
=
x
,
weight
=
layer
.
weight
,
weight
=
layer
.
weight
,
...
@@ -291,6 +373,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
...
@@ -291,6 +373,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
def
__init__
(
self
,
quant_config
:
Fp8Config
):
def
__init__
(
self
,
quant_config
:
Fp8Config
):
self
.
quant_config
=
quant_config
self
.
quant_config
=
quant_config
self
.
block_quant
=
self
.
quant_config
.
weight_block_size
is
not
None
def
create_weights
(
self
,
layer
:
Module
,
num_experts
:
int
,
hidden_size
:
int
,
def
create_weights
(
self
,
layer
:
Module
,
num_experts
:
int
,
hidden_size
:
int
,
intermediate_size
:
int
,
params_dtype
:
torch
.
dtype
,
intermediate_size
:
int
,
params_dtype
:
torch
.
dtype
,
...
@@ -298,6 +381,27 @@ class Fp8MoEMethod(FusedMoEMethodBase):
...
@@ -298,6 +381,27 @@ class Fp8MoEMethod(FusedMoEMethodBase):
if
self
.
quant_config
.
is_checkpoint_fp8_serialized
:
if
self
.
quant_config
.
is_checkpoint_fp8_serialized
:
params_dtype
=
torch
.
float8_e4m3fn
params_dtype
=
torch
.
float8_e4m3fn
if
self
.
block_quant
:
assert
self
.
quant_config
.
weight_block_size
is
not
None
tp_size
=
get_tensor_model_parallel_world_size
()
block_n
,
block_k
=
(
self
.
quant_config
.
weight_block_size
[
0
],
self
.
quant_config
.
weight_block_size
[
1
],
)
# NOTE: To ensure proper alignment of the block-wise quantization
# scales, the output_size of the weights for both the gate and up
# layers must be divisible by block_n.
# Required by column parallel or enabling merged weights
if
intermediate_size
%
block_n
!=
0
:
raise
ValueError
(
f
"The output_size of gate's and up's weight = "
f
"
{
intermediate_size
}
is not divisible by "
f
"weight quantization block_n =
{
block_n
}
."
)
if
(
tp_size
>
1
and
intermediate_size
%
block_k
!=
0
):
# Required by row parallel
raise
ValueError
(
f
"The input_size of down's weight = "
f
"
{
intermediate_size
}
is not divisible by "
f
"weight quantization block_k =
{
block_k
}
."
)
# WEIGHTS
# WEIGHTS
w13_weight
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
num_experts
,
w13_weight
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
num_experts
,
...
@@ -317,21 +421,45 @@ class Fp8MoEMethod(FusedMoEMethodBase):
...
@@ -317,21 +421,45 @@ class Fp8MoEMethod(FusedMoEMethodBase):
set_weight_attrs
(
w2_weight
,
extra_weight_attrs
)
set_weight_attrs
(
w2_weight
,
extra_weight_attrs
)
# WEIGHT_SCALES
# WEIGHT_SCALES
# Allocate 2 scales for w1 and w3 respectively.
if
not
self
.
block_quant
:
# They will be combined to a single scale after weight loading.
# Allocate 2 scales for w1 and w3 respectively.
w13_weight_scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
num_experts
,
# They will be combined to a single scale after weight loading.
2
,
w13_weight_scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
dtype
=
torch
.
float32
),
num_experts
,
2
,
dtype
=
torch
.
float32
),
requires_grad
=
False
)
requires_grad
=
False
)
layer
.
register_parameter
(
"w13_weight_scale"
,
w13_weight_scale
)
w2_weight_scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
num_experts
,
dtype
=
torch
.
float32
),
requires_grad
=
False
)
layer
.
register_parameter
(
"w13_weight_scale"
,
w13_weight_scale
)
layer
.
register_parameter
(
"w2_weight_scale"
,
w2_weight_scale
)
else
:
w13_weight_scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
num_experts
,
2
*
((
intermediate_size
+
block_n
-
1
)
//
block_n
),
(
hidden_size
+
block_k
-
1
)
//
block_k
,
dtype
=
torch
.
float32
,
),
requires_grad
=
False
,
)
w2_weight_scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
num_experts
,
(
hidden_size
+
block_n
-
1
)
//
block_n
,
(
intermediate_size
+
block_k
-
1
)
//
block_k
,
dtype
=
torch
.
float32
,
),
requires_grad
=
False
,
)
layer
.
register_parameter
(
"w13_weight_scale_inv"
,
w13_weight_scale
)
layer
.
register_parameter
(
"w2_weight_scale_inv"
,
w2_weight_scale
)
assert
self
.
quant_config
.
activation_scheme
==
"dynamic"
w2_weight_scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
num_experts
,
dtype
=
torch
.
float32
),
requires_grad
=
False
)
layer
.
register_parameter
(
"w2_weight_scale"
,
w2_weight_scale
)
# Add the quantization method used (per tensor/grouped/channel)
# Add the quantization method used (per tensor/grouped/channel)
# to ensure the weight scales are loaded in properly
# to ensure the weight scales are loaded in properly
extra_weight_attrs
.
update
(
extra_weight_attrs
.
update
(
{
"quant_method"
:
FusedMoeWeightScaleSupported
.
BLOCK
.
value
}
if
self
.
block_quant
else
{
"quant_method"
:
FusedMoeWeightScaleSupported
.
TENSOR
.
value
})
{
"quant_method"
:
FusedMoeWeightScaleSupported
.
TENSOR
.
value
})
# If loading fp8 checkpoint, pass the weight loaders.
# If loading fp8 checkpoint, pass the weight loaders.
# If loading an fp16 checkpoint, do not (we will quantize in
# If loading an fp16 checkpoint, do not (we will quantize in
...
@@ -364,7 +492,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
...
@@ -364,7 +492,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
layer
.
w2_input_scale
=
None
layer
.
w2_input_scale
=
None
def
process_weights_after_loading
(
self
,
layer
:
Module
)
->
None
:
def
process_weights_after_loading
(
self
,
layer
:
Module
)
->
None
:
# Block quant doesn't need to process weights after loading
if
self
.
block_quant
:
return
# If checkpoint is fp16, quantize in place.
# If checkpoint is fp16, quantize in place.
if
not
self
.
quant_config
.
is_checkpoint_fp8_serialized
:
if
not
self
.
quant_config
.
is_checkpoint_fp8_serialized
:
# If rocm, use float8_e4m3fnuz as dtype
# If rocm, use float8_e4m3fnuz as dtype
...
@@ -471,12 +601,13 @@ class Fp8MoEMethod(FusedMoEMethodBase):
...
@@ -471,12 +601,13 @@ class Fp8MoEMethod(FusedMoEMethodBase):
router_logits
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
top_k
:
int
,
top_k
:
int
,
renormalize
:
bool
,
renormalize
:
bool
,
use_grouped_topk
:
bool
,
use_grouped_topk
:
bool
=
False
,
topk_group
:
Optional
[
int
]
=
None
,
topk_group
:
Optional
[
int
]
=
None
,
num_expert_group
:
Optional
[
int
]
=
None
,
num_expert_group
:
Optional
[
int
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
from
vllm.model_executor.layers.fused_moe
import
fused_experts
from
vllm.model_executor.layers.fused_moe
import
fused_experts
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
...
@@ -487,19 +618,27 @@ class Fp8MoEMethod(FusedMoEMethodBase):
...
@@ -487,19 +618,27 @@ class Fp8MoEMethod(FusedMoEMethodBase):
renormalize
=
renormalize
,
renormalize
=
renormalize
,
topk_group
=
topk_group
,
topk_group
=
topk_group
,
num_expert_group
=
num_expert_group
,
num_expert_group
=
num_expert_group
,
custom_routing_function
=
custom_routing_function
)
custom_routing_function
=
custom_routing_function
,
scoring_func
=
scoring_func
,
return
fused_experts
(
x
,
e_score_correction_bias
=
e_score_correction_bias
,
layer
.
w13_weight
,
)
layer
.
w2_weight
,
topk_weights
=
topk_weights
,
return
fused_experts
(
topk_ids
=
topk_ids
,
x
,
inplace
=
True
,
layer
.
w13_weight
,
use_fp8_w8a8
=
True
,
layer
.
w2_weight
,
w1_scale
=
layer
.
w13_weight_scale
,
topk_weights
=
topk_weights
,
w2_scale
=
layer
.
w2_weight_scale
,
topk_ids
=
topk_ids
,
a1_scale
=
layer
.
w13_input_scale
,
inplace
=
True
,
a2_scale
=
layer
.
w2_input_scale
)
use_fp8_w8a8
=
True
,
w1_scale
=
(
layer
.
w13_weight_scale_inv
if
self
.
block_quant
else
layer
.
w13_weight_scale
),
w2_scale
=
(
layer
.
w2_weight_scale_inv
if
self
.
block_quant
else
layer
.
w2_weight_scale
),
a1_scale
=
layer
.
w13_input_scale
,
a2_scale
=
layer
.
w2_input_scale
,
block_shape
=
self
.
quant_config
.
weight_block_size
,
)
class
Fp8KVCacheMethod
(
BaseKVCacheMethod
):
class
Fp8KVCacheMethod
(
BaseKVCacheMethod
):
...
...
vllm/model_executor/layers/quantization/gptq_marlin.py
View file @
96ae75ad
...
@@ -532,11 +532,13 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
...
@@ -532,11 +532,13 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
top_k
:
int
,
top_k
:
int
,
renormalize
:
bool
=
True
,
renormalize
:
bool
,
use_grouped_topk
:
bool
=
False
,
use_grouped_topk
:
bool
=
False
,
num_expert_group
:
Optional
[
int
]
=
None
,
topk_group
:
Optional
[
int
]
=
None
,
topk_group
:
Optional
[
int
]
=
None
,
num_expert_group
:
Optional
[
int
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
# The input must currently be float16
# The input must currently be float16
orig_dtype
=
x
.
dtype
orig_dtype
=
x
.
dtype
...
@@ -550,7 +552,9 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
...
@@ -550,7 +552,9 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
renormalize
=
renormalize
,
renormalize
=
renormalize
,
topk_group
=
topk_group
,
topk_group
=
topk_group
,
num_expert_group
=
num_expert_group
,
num_expert_group
=
num_expert_group
,
custom_routing_function
=
None
)
custom_routing_function
=
custom_routing_function
,
scoring_func
=
scoring_func
,
e_score_correction_bias
=
e_score_correction_bias
)
return
torch
.
ops
.
vllm
.
fused_marlin_moe
(
return
torch
.
ops
.
vllm
.
fused_marlin_moe
(
x
,
x
,
...
...
Prev
1
…
11
12
13
14
15
16
17
18
19
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment