Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
96ae75ad
Commit
96ae75ad
authored
Jan 04, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.6.6.post1' into v0.6.6.post1-dev
parents
f9f4a735
2339d59f
Changes
374
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1056 additions
and
261 deletions
+1056
-261
vllm/lora/models.py
vllm/lora/models.py
+8
-4
vllm/lora/utils.py
vllm/lora/utils.py
+23
-4
vllm/lora/worker_manager.py
vllm/lora/worker_manager.py
+12
-1
vllm/model_executor/guided_decoding/__init__.py
vllm/model_executor/guided_decoding/__init__.py
+39
-40
vllm/model_executor/guided_decoding/outlines_logits_processors.py
...el_executor/guided_decoding/outlines_logits_processors.py
+14
-9
vllm/model_executor/guided_decoding/utils.py
vllm/model_executor/guided_decoding/utils.py
+70
-0
vllm/model_executor/guided_decoding/xgrammar_decoding.py
vllm/model_executor/guided_decoding/xgrammar_decoding.py
+75
-42
vllm/model_executor/layers/fused_moe/fused_moe.py
vllm/model_executor/layers/fused_moe/fused_moe.py
+117
-31
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+86
-40
vllm/model_executor/layers/linear.py
vllm/model_executor/layers/linear.py
+21
-2
vllm/model_executor/layers/quantization/awq_marlin.py
vllm/model_executor/layers/quantization/awq_marlin.py
+7
-3
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
...ers/quantization/compressed_tensors/compressed_tensors.py
+160
-27
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
...quantization/compressed_tensors/compressed_tensors_moe.py
+15
-7
vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py
...ayers/quantization/compressed_tensors/schemes/__init__.py
+7
-8
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py
...ation/compressed_tensors/schemes/compressed_tensors_24.py
+208
-0
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py
...compressed_tensors/schemes/compressed_tensors_w4a16_24.py
+4
-0
vllm/model_executor/layers/quantization/compressed_tensors/utils.py
..._executor/layers/quantization/compressed_tensors/utils.py
+1
-1
vllm/model_executor/layers/quantization/experts_int8.py
vllm/model_executor/layers/quantization/experts_int8.py
+7
-3
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+175
-36
vllm/model_executor/layers/quantization/gptq_marlin.py
vllm/model_executor/layers/quantization/gptq_marlin.py
+7
-3
No files found.
vllm/lora/models.py
View file @
96ae75ad
...
...
@@ -28,7 +28,7 @@ from vllm.lora.utils import (from_layer, from_layer_logits_processor,
parse_fine_tuned_lora_name
,
replace_submodule
)
from
vllm.model_executor.models
import
SupportsLoRA
,
supports_multimodal
from
vllm.model_executor.models.module_mapping
import
MultiModelKeys
from
vllm.model_executor.models.utils
import
PPMissingLayer
from
vllm.model_executor.models.utils
import
PPMissingLayer
,
WeightsMapper
from
vllm.utils
import
is_pin_memory_available
logger
=
init_logger
(
__name__
)
...
...
@@ -113,13 +113,14 @@ class LoRAModel(AdapterModel):
target_embedding_padding
:
Optional
[
int
]
=
None
,
embedding_modules
:
Optional
[
Dict
[
str
,
str
]]
=
None
,
embedding_padding_modules
:
Optional
[
List
[
str
]]
=
None
,
weights_mapper
:
Optional
[
WeightsMapper
]
=
None
,
)
->
"LoRAModel"
:
"""Create a LoRAModel from a dictionary of tensors."""
pin_memory
=
str
(
device
)
==
"cpu"
and
is_pin_memory_available
()
loras
:
Dict
[
str
,
LoRALayerWeights
]
=
{}
for
tensor_name
,
tensor
in
tensors
.
items
():
module_name
,
is_lora_a
,
is_bias
=
parse_fine_tuned_lora_name
(
tensor_name
)
tensor_name
,
weights_mapper
)
if
module_name
not
in
loras
:
lora_embeddings_tensor
=
None
if
embeddings
:
...
...
@@ -187,6 +188,7 @@ class LoRAModel(AdapterModel):
target_embedding_padding
:
Optional
[
int
]
=
None
,
embedding_modules
:
Optional
[
Dict
[
str
,
str
]]
=
None
,
embedding_padding_modules
:
Optional
[
List
[
str
]]
=
None
,
weights_mapper
:
Optional
[
WeightsMapper
]
=
None
,
)
->
"LoRAModel"
:
"""Create a LoRAModel from a local checkpoint.
...
...
@@ -229,7 +231,8 @@ class LoRAModel(AdapterModel):
with
safetensors
.
safe_open
(
lora_tensor_path
,
framework
=
"pt"
)
as
f
:
# type: ignore
for
lora_module
in
f
.
keys
():
# noqa
module_name
,
_
,
_
=
parse_fine_tuned_lora_name
(
lora_module
)
module_name
,
_
,
_
=
parse_fine_tuned_lora_name
(
lora_module
,
weights_mapper
)
part_name
=
module_name
.
split
(
"."
)[
-
1
]
if
part_name
not
in
expected_lora_modules
:
unexpected_modules
.
append
(
module_name
)
...
...
@@ -289,7 +292,8 @@ class LoRAModel(AdapterModel):
embeddings
=
embeddings
,
target_embedding_padding
=
target_embedding_padding
,
embedding_modules
=
embedding_modules
,
embedding_padding_modules
=
embedding_padding_modules
)
embedding_padding_modules
=
embedding_padding_modules
,
weights_mapper
=
weights_mapper
)
class
LoRAModelManager
(
AdapterModelManager
):
...
...
vllm/lora/utils.py
View file @
96ae75ad
...
...
@@ -30,6 +30,7 @@ from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA,
# yapf: enable
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
from
vllm.model_executor.models.utils
import
WeightsMapper
logger
=
init_logger
(
__name__
)
...
...
@@ -91,28 +92,46 @@ def replace_submodule(model: nn.Module, module_name: str,
return
new_module
def
parse_fine_tuned_lora_name
(
name
:
str
)
->
Tuple
[
str
,
bool
,
bool
]:
def
parse_fine_tuned_lora_name
(
name
:
str
,
weights_mapper
:
Optional
[
WeightsMapper
]
=
None
)
->
Tuple
[
str
,
bool
,
bool
]:
"""Parse the name of lora weights.
args:
name: the name of the fine-tuned LoRA, e.g.
base_model.model.dense1.weight
weights_mapper: maps the name of weight, e.g.
`model.` -> `language_model.model.`,
return:
Tuple(module_name, is_lora_a):
module_name: the name of the module, e.g. model.dense1,
is_lora_a whether the tensor is lora_a or lora_b.
is_bias whether the tensor is lora bias.
"""
# LoRA weight qualified name always starts with `base_model.model.`,
# so we remove the prefix `base_model.model.` to make the following
# mapping correctly.
if
"base_model.model."
in
name
:
name
=
name
.
replace
(
"base_model.model."
,
""
)
name
=
weights_mapper
.
_map_name
(
name
)
if
weights_mapper
else
name
# recover the prefix `base_model.model.`
name
=
"base_model.model."
+
name
parts
=
name
.
split
(
"."
)
if
parts
[
-
1
]
==
"weight"
and
(
parts
[
-
2
]
==
"lora_A"
or
parts
[
-
2
]
==
"lora_B"
):
return
"."
.
join
(
parts
[
2
:
-
2
]),
parts
[
-
2
]
==
"lora_A"
,
False
new_name
=
"."
.
join
(
parts
[
2
:
-
2
])
return
new_name
,
parts
[
-
2
]
==
"lora_A"
,
False
if
parts
[
-
1
]
==
"lora_embedding_A"
or
parts
[
-
1
]
==
"lora_embedding_B"
:
return
"."
.
join
(
parts
[
2
:
-
1
]),
parts
[
-
1
]
==
"lora_embedding_A"
,
False
new_name
=
"."
.
join
(
parts
[
2
:
-
1
])
return
new_name
,
parts
[
-
1
]
==
"lora_embedding_A"
,
False
if
parts
[
-
1
]
==
"bias"
:
return
"."
.
join
(
parts
[
2
:
-
2
]),
False
,
True
new_name
=
"."
.
join
(
parts
[
2
:
-
2
])
return
new_name
,
False
,
True
raise
ValueError
(
f
"
{
name
}
is unsupported LoRA weight"
)
...
...
vllm/lora/worker_manager.py
View file @
96ae75ad
...
...
@@ -91,7 +91,17 @@ class WorkerLoRAManager(AbstractWorkerManager):
packed_modules_mapping
[
module
])
else
:
expected_lora_modules
.
append
(
module
)
expected_lora_modules
=
list
(
set
(
expected_lora_modules
))
lora_path
=
get_adapter_absolute_path
(
lora_request
.
lora_path
)
# For some models like Qwen2VL, we need to use hf_to_vllm_mapper
# to ensure correct loading of lora weights.
hf_to_vllm_mapper
=
None
if
(
hasattr
(
model
,
"hf_to_vllm_mapper"
)
and
model
.
hf_to_vllm_mapper
is
not
None
):
hf_to_vllm_mapper
=
model
.
hf_to_vllm_mapper
lora
=
self
.
_lora_model_cls
.
from_local_checkpoint
(
lora_path
,
expected_lora_modules
,
...
...
@@ -103,7 +113,8 @@ class WorkerLoRAManager(AbstractWorkerManager):
self
.
lora_config
.
lora_extra_vocab_size
,
embedding_modules
=
self
.
embedding_modules
,
embedding_padding_modules
=
self
.
embedding_padding_modules
,
)
weights_mapper
=
hf_to_vllm_mapper
)
except
Exception
as
e
:
raise
RuntimeError
(
f
"Loading lora
{
lora_path
}
failed"
)
from
e
if
lora
.
rank
>
self
.
lora_config
.
max_lora_rank
:
...
...
vllm/model_executor/guided_decoding/__init__.py
View file @
96ae75ad
...
...
@@ -3,6 +3,9 @@ from __future__ import annotations
from
typing
import
TYPE_CHECKING
from
vllm.logger
import
init_logger
from
vllm.model_executor.guided_decoding.utils
import
(
convert_lark_to_gbnf
,
grammar_is_likely_lark
,
has_lmf_unsupported_json_features
,
has_xgrammar_unsupported_json_features
)
from
vllm.platforms
import
CpuArchEnum
,
current_platform
if
TYPE_CHECKING
:
...
...
@@ -15,49 +18,24 @@ if TYPE_CHECKING:
logger
=
init_logger
(
__name__
)
def has_xgrammar_unsupported_json_features(schema: dict) -> bool:
    """Check if JSON schema contains features unsupported by xgrammar.

    The schema is walked recursively through nested dicts and through
    dicts found inside lists. It is reported as unsupported as soon as a
    visited dict either contains a ``"pattern"`` key, or declares an
    ``integer``/``number`` type together with one of the numeric-range
    keywords ``minimum``, ``maximum``, ``exclusiveMinimum``,
    ``exclusiveMaximum`` or ``multipleOf``.

    Args:
        schema: The JSON schema, already parsed into a dict.

    Returns:
        True if an unsupported feature is found, False otherwise
        (including when ``schema`` is not a dict).
    """
    numeric_types = ("integer", "number")
    numeric_range_keys = ("minimum", "maximum", "exclusiveMinimum",
                          "exclusiveMaximum", "multipleOf")

    def check_object(obj: dict) -> bool:
        if not isinstance(obj, dict):
            return False

        # Check for pattern restrictions.
        # NOTE(review): the walk below also visits the "properties" map
        # itself, so a property that is merely *named* "pattern" trips this
        # check too (a conservative false positive).
        if "pattern" in obj:
            return True

        # Check for numeric ranges. JSON Schema allows "type" to be either
        # a single name or a list of names (e.g. ["integer", "null"]), so
        # normalise to a list before testing membership.
        declared = obj.get("type")
        declared_types = declared if isinstance(declared, list) else [declared]
        if (any(name in numeric_types for name in declared_types)
                and any(key in obj for key in numeric_range_keys)):
            return True

        # Recursively check all nested objects and arrays.
        for value in obj.values():
            if isinstance(value, dict):
                if check_object(value):
                    return True
            elif isinstance(value, list):
                if any(
                        isinstance(item, dict) and check_object(item)
                        for item in value):
                    return True

        return False

    return check_object(schema)
def
maybe_backend_fallback
(
guided_params
:
GuidedDecodingParams
)
->
GuidedDecodingParams
:
# lm-format-enforce doesn't support grammar, fallback to xgrammar
if
(
guided_params
.
backend
==
"lm-format-enforcer"
and
guided_params
.
grammar
is
not
None
):
logger
.
warning
(
"lm-format-enforcer does not support grammar guided decoding. "
"Falling back to use xgrammar instead."
)
guided_params
.
backend
=
"xgrammar"
if
guided_params
.
backend
==
"lm-format-enforcer"
:
if
guided_params
.
grammar
is
not
None
:
logger
.
warning
(
"lm-format-enforcer does not support grammar guided decoding. "
"Falling back to use xgrammar instead."
)
guided_params
.
backend
=
"xgrammar"
# lm-format-enforcer doesn't support some JSON schema features
elif
(
guided_params
.
json
is
not
None
and
has_lmf_unsupported_json_features
(
guided_params
.
json
)):
logger
.
warning
(
"lm-format-enforcer does not support advanced JSON schema "
"features like patterns or numeric ranges. "
"Falling back to use outlines instead."
)
guided_params
.
backend
=
"outlines"
if
guided_params
.
backend
==
"xgrammar"
:
# xgrammar only has x86 wheels for linux, fallback to outlines
...
...
@@ -82,6 +60,27 @@ def maybe_backend_fallback(
"Falling back to use outlines instead."
)
guided_params
.
backend
=
"outlines"
# xgrammar only supports GBNF grammars, so we must convert Lark.
# We must check if the grammar is likely Lark and if that
# grammar is convertible to GBNF
elif
(
guided_params
.
grammar
is
not
None
and
grammar_is_likely_lark
(
guided_params
.
grammar
)):
try
:
convert_lark_to_gbnf
(
guided_params
.
grammar
)
except
Exception
:
logger
.
warning
(
"xgrammar does not support Lark grammars and the "
"grammar failed to convert to GBNF. "
"Falling back to use outlines instead."
)
guided_params
.
backend
=
"outlines"
if
(
guided_params
.
backend
==
"outlines"
and
guided_params
.
json_object
is
not
None
):
# outlines doesn't support json_object, fallback to xgrammar
logger
.
warning
(
"outlines does not support json_object. "
"Falling back to use xgrammar instead."
)
guided_params
.
backend
=
"xgrammar"
return
guided_params
...
...
vllm/model_executor/guided_decoding/outlines_logits_processors.py
View file @
96ae75ad
...
...
@@ -21,10 +21,11 @@ from typing import Callable, DefaultDict, Dict, List, Union
import
numpy
as
np
import
torch
from
lark
import
Lark
from
outlines
import
grammars
from
outlines.caching
import
cache
from
outlines.fsm.guide
import
CFGGuide
,
Generate
,
Guide
,
RegexGuide
,
Write
from
outlines.fsm.guide
import
(
CFGGuide
,
CFGState
,
Generate
,
Guide
,
RegexGuide
,
Write
)
from
outlines.fsm.parsing
import
PartialLark
from
outlines_core.fsm.json_schema
import
build_regex_from_schema
from
pydantic
import
BaseModel
from
transformers
import
PreTrainedTokenizerBase
...
...
@@ -34,7 +35,9 @@ class BaseLogitsProcessor:
def
__init__
(
self
,
guide
:
Guide
):
self
.
_guide
:
Guide
=
guide
self
.
_fsm_state
:
DefaultDict
[
int
,
int
]
=
defaultdict
(
int
)
# CFGState is used for the FSM state for CFGGuide
self
.
_fsm_state
:
DefaultDict
[
int
,
Union
[
int
,
CFGState
]]
=
defaultdict
(
int
)
def
__call__
(
self
,
input_ids
:
List
[
int
],
scores
:
torch
.
Tensor
)
->
torch
.
Tensor
:
...
...
@@ -54,15 +57,13 @@ class BaseLogitsProcessor:
# On the first time this is called, we simply re-create
# the Lark object.
if
isinstance
(
self
.
_guide
,
CFGGuide
):
self
.
_guide
.
parser
=
Lark
(
self
.
_guide
.
parser
=
Partial
Lark
(
self
.
_guide
.
cfg_string
,
parser
=
"lalr"
,
lexer
=
"contextual"
,
propagate_positions
=
False
,
maybe_placeholders
=
False
,
regex
=
True
,
import_paths
=
[
grammars
.
GRAMMAR_PATH
],
)
self
.
_fsm_state
[
seq_id
]
=
CFGState
(
parser_state
=
self
.
_guide
.
parser
.
parse
(
""
),
prev_token
=
None
)
instruction
=
self
.
_guide
.
get_next_instruction
(
state
=
self
.
_fsm_state
[
seq_id
])
...
...
@@ -200,7 +201,8 @@ def _adapt_tokenizer(tokenizer: PreTrainedTokenizerBase):
string
=
tokenizer
.
convert_tokens_to_string
([
token
])
# A hack to handle missing spaces to HF's Llama tokenizers
if
token
.
startswith
(
SPIECE_UNDERLINE
)
or
token
==
"<0x20>"
:
if
(
type
(
token
)
is
str
and
token
.
startswith
(
SPIECE_UNDERLINE
)
or
token
==
"<0x20>"
):
return
" "
+
string
return
string
...
...
@@ -211,6 +213,9 @@ def _adapt_tokenizer(tokenizer: PreTrainedTokenizerBase):
"""Sync vLLM's decoder with the outlines by returning list."""
def
new_decoder
(
inp_tokens
:
List
[
int
])
->
List
[
str
]:
if
(
isinstance
(
inp_tokens
,
list
)
and
len
(
inp_tokens
)
==
1
and
isinstance
(
inp_tokens
[
0
],
list
)):
inp_tokens
=
inp_tokens
[
0
]
return
[
decoder
(
inp_tokens
)]
return
new_decoder
...
...
vllm/model_executor/guided_decoding/
xgrammar_
utils.py
→
vllm/model_executor/guided_decoding/utils.py
View file @
96ae75ad
import
re
def has_xgrammar_unsupported_json_features(schema: dict) -> bool:
    """Check if JSON schema contains features unsupported by xgrammar.

    The schema is walked recursively through nested dicts and through
    dicts found inside lists. It is reported as unsupported as soon as a
    visited dict either contains a ``"pattern"`` key, or declares an
    ``integer``/``number`` type together with one of the numeric-range
    keywords ``minimum``, ``maximum``, ``exclusiveMinimum``,
    ``exclusiveMaximum`` or ``multipleOf``.

    Args:
        schema: The JSON schema, already parsed into a dict.

    Returns:
        True if an unsupported feature is found, False otherwise
        (including when ``schema`` is not a dict).
    """
    numeric_types = ("integer", "number")
    numeric_range_keys = ("minimum", "maximum", "exclusiveMinimum",
                          "exclusiveMaximum", "multipleOf")

    def check_object(obj: dict) -> bool:
        if not isinstance(obj, dict):
            return False

        # Check for pattern restrictions.
        # NOTE(review): the walk below also visits the "properties" map
        # itself, so a property that is merely *named* "pattern" trips this
        # check too (a conservative false positive).
        if "pattern" in obj:
            return True

        # Check for numeric ranges. JSON Schema allows "type" to be either
        # a single name or a list of names (e.g. ["integer", "null"]), so
        # normalise to a list before testing membership.
        declared = obj.get("type")
        declared_types = declared if isinstance(declared, list) else [declared]
        if (any(name in numeric_types for name in declared_types)
                and any(key in obj for key in numeric_range_keys)):
            return True

        # Recursively check all nested objects and arrays.
        for value in obj.values():
            if isinstance(value, dict):
                if check_object(value):
                    return True
            elif isinstance(value, list):
                if any(
                        isinstance(item, dict) and check_object(item)
                        for item in value):
                    return True

        return False

    return check_object(schema)
def has_lmf_unsupported_json_features(schema: dict) -> bool:
    """
    Check if JSON schema contains features unsupported
    by lm_format_enforcer.

    Known issues:
    - Regex patterns:
        "grade": {
            "type": "string",
            "pattern": "^[A-D]$"  # Regex pattern
        },

    Args:
        schema: The JSON schema, already parsed into a dict.

    Returns:
        True if any dict reachable from ``schema`` (through nested dicts
        or dicts inside lists) contains a ``"pattern"`` key, False
        otherwise (including when ``schema`` is not a dict).
    """

    def check_object(obj: dict) -> bool:
        if not isinstance(obj, dict):
            return False

        # Check for pattern restrictions
        if "pattern" in obj:
            return True

        # Recursively check all nested objects and arrays. A list value
        # contributes its dict items; any other value is treated as a
        # single candidate and ignored unless it is itself a dict.
        for value in obj.values():
            candidates = value if isinstance(value, list) else (value, )
            if any(
                    isinstance(child, dict) and check_object(child)
                    for child in candidates):
                return True

        return False

    return check_object(schema)
def
grammar_is_likely_lark
(
grammar_str
:
str
)
->
bool
:
"""
Check if grammar appears to use Lark syntax.
...
...
vllm/model_executor/guided_decoding/xgrammar_decoding.py
View file @
96ae75ad
...
...
@@ -3,7 +3,7 @@ from __future__ import annotations
import
json
from
dataclasses
import
dataclass
,
field
from
typing
import
TYPE_CHECKING
,
Any
,
NamedTuple
from
typing
import
TYPE_CHECKING
,
Any
import
torch
from
transformers
import
PreTrainedTokenizerFast
...
...
@@ -14,8 +14,9 @@ try:
except
ImportError
:
pass
from
vllm.model_executor.guided_decoding.xgrammar_utils
import
(
convert_lark_to_gbnf
,
grammar_is_likely_lark
)
from
vllm.model_executor.guided_decoding.utils
import
(
convert_lark_to_gbnf
,
grammar_is_likely_lark
)
from
vllm.transformers_utils.tokenizers.mistral
import
MistralTokenizer
if
TYPE_CHECKING
:
from
transformers
import
PreTrainedTokenizer
...
...
@@ -37,11 +38,21 @@ def get_local_xgrammar_guided_decoding_logits_processor(
return
XGrammarLogitsProcessor
(
config
)
class
TokenizerData
(
NamedTuple
):
@
dataclass
(
frozen
=
True
)
class
TokenizerData
:
"""Immutable container for cached tokenizer data."""
encoded_vocab
:
list
[
str
]
stop_token_ids
:
list
[
int
]
|
None
backend_str
:
str
encoded_vocab
:
list
[
str
]
=
field
(
default_factory
=
list
)
stop_token_ids
:
list
[
int
]
|
None
=
None
# These fields are mutually exclusive: `backend_str` is used to create a
# TokenizeInfo with `TokenizerInfo.from_huggingface` while `vocab_type` is
# used within the constructor of TokenizeInfo
backend_str
:
str
|
None
=
None
vocab_type
:
xgr
.
VocabType
|
None
=
None
def
__post_init__
(
self
):
# Check for mutual exclusive
assert
not
(
self
.
backend_str
and
self
.
vocab_type
),
\
"backend_str and vocab_type are mutual exclusive"
class
TokenizerDataCache
:
...
...
@@ -68,18 +79,27 @@ class TokenizerDataCache:
"get_vocab method."
)
from
e
stop_token_ids
=
None
backend_str
=
xgr
.
VocabType
.
RAW
backend_str
=
""
vocab_type
=
xgr
.
VocabType
.
RAW
if
stop_token_ids
is
None
and
hasattr
(
tokenizer
,
"eos_token_id"
)
and
tokenizer
.
eos_token_id
is
not
None
:
stop_token_ids
=
[
tokenizer
.
eos_token_id
]
if
isinstance
(
tokenizer
,
PreTrainedTokenizerFast
):
backend_str
=
tokenizer
.
backend_tokenizer
.
to_str
()
if
stop_token_ids
is
None
and
hasattr
(
tokenizer
,
"eos_token_id"
)
and
tokenizer
.
eos_token_id
is
not
None
:
stop_token_ids
=
[
tokenizer
.
eos_token_id
]
vocab_type
=
None
elif
isinstance
(
tokenizer
,
MistralTokenizer
):
# REF: https://github.com/mlc-ai/xgrammar/blob/5e141f6ff1ca02bc31f9e512e68b61f2a8ae88e5/tests/python/test_tokenizer_info.py#L43 # noqa: E501
vocab_type
=
xgr
.
VocabType
.
BYTE_FALLBACK
cls
.
_cache
[
tokenizer_hash
]
=
TokenizerData
(
encoded_vocab
=
encoded_vocab
,
stop_token_ids
=
stop_token_ids
,
backend_str
=
backend_str
)
backend_str
=
backend_str
,
vocab_type
=
vocab_type
)
return
cls
.
_cache
[
tokenizer_hash
]
...
...
@@ -98,11 +118,30 @@ class GrammarCompilerCache:
cache_key
=
str
(
config
.
tokenizer_hash
)
if
cache_key
not
in
cls
.
_cache
:
assert
config
.
encoded_vocab
is
not
None
tokenizer_info
=
xgr
.
TokenizerInfo
.
_create_from_handle
(
xgr_core
.
TokenizerInfo
.
from_huggingface
(
config
.
encoded_vocab
,
config
.
backend_str
,
config
.
vocab_size
,
config
.
stop_token_ids
))
assert
config
.
tokenizer_data
is
not
None
assert
config
.
tokenizer_data
.
encoded_vocab
is
not
None
config_data
=
config
.
tokenizer_data
# In TokenizerDataCache.get_tokenizer_data, a serializable
# tokenizer_data is created and cached. This data is used to build
# a tokenizer_info and create an xgrammar compiler.
# - If tokenizer_data has backend_str set, use
# xgr_core.TokenizerInfo.from_huggingface (a C++ bind).
# - Otherwise, use the default constructor with vocab_type.
# - xgr_core.TokenizerInfo.from_huggingface !=
# xgr.TokenizerInfo.from_huggingface.
if
config_data
.
backend_str
:
tokenizer_info
=
xgr
.
TokenizerInfo
.
_create_from_handle
(
xgr_core
.
TokenizerInfo
.
from_huggingface
(
config_data
.
encoded_vocab
,
config_data
.
backend_str
,
config
.
vocab_size
,
config_data
.
stop_token_ids
))
else
:
tokenizer_info
=
xgr
.
TokenizerInfo
(
config_data
.
encoded_vocab
,
config_data
.
vocab_type
,
vocab_size
=
config
.
vocab_size
,
stop_token_ids
=
config_data
.
stop_token_ids
)
cls
.
_cache
[
cache_key
]
=
xgr
.
GrammarCompiler
(
tokenizer_info
,
max_threads
=
config
.
max_threads
)
...
...
@@ -118,10 +157,7 @@ class GrammarConfig:
grammar_str
:
str
|
None
=
None
json_object
:
bool
|
None
=
None
max_threads
:
int
=
8
# Only populated if tokenizer_hash not in cache
encoded_vocab
:
list
[
str
]
|
None
=
None
stop_token_ids
:
list
[
int
]
|
None
=
None
backend_str
:
str
|
None
=
None
tokenizer_data
:
TokenizerData
|
None
=
None
@
classmethod
def
from_guided_params
(
cls
,
...
...
@@ -132,9 +168,6 @@ class GrammarConfig:
tokenizer_hash
=
hash
(
tokenizer
)
tokenizer_data
=
TokenizerDataCache
.
get_tokenizer_data
(
tokenizer
)
encoded_vocab
=
tokenizer_data
.
encoded_vocab
stop_token_ids
=
tokenizer_data
.
stop_token_ids
backend_str
=
tokenizer_data
.
backend_str
if
guided_params
.
json
:
if
not
isinstance
(
guided_params
.
json
,
str
):
...
...
@@ -152,11 +185,9 @@ class GrammarConfig:
return
cls
(
json_str
=
json_str
,
vocab_size
=
model_config
.
hf_text_config
.
vocab_size
,
encoded_vocab
=
encoded_vocab
,
stop_token_ids
=
stop_token_ids
,
backend_str
=
backend_str
,
tokenizer_hash
=
tokenizer_hash
,
max_threads
=
max_threads
)
max_threads
=
max_threads
,
tokenizer_data
=
tokenizer_data
)
elif
guided_params
.
grammar
:
# XGrammar only supports GBNF grammars, so we must convert Lark
if
grammar_is_likely_lark
(
guided_params
.
grammar
):
...
...
@@ -181,19 +212,17 @@ class GrammarConfig:
return
cls
(
grammar_str
=
grammar_str
,
vocab_size
=
model_config
.
hf_text_config
.
vocab_size
,
encoded_vocab
=
encoded_vocab
,
stop_token_ids
=
stop_token_ids
,
backend_str
=
backend_str
,
tokenizer_hash
=
tokenizer_hash
,
max_threads
=
max_threads
)
max_threads
=
max_threads
,
tokenizer_data
=
tokenizer_data
)
elif
guided_params
.
json_object
:
return
cls
(
json_object
=
True
,
vocab_size
=
model_config
.
hf_text_config
.
vocab_siz
e
,
encoded_vocab
=
encoded_vocab
,
stop_token_ids
=
stop_token_ids
,
backend_str
=
backend_str
,
tokenizer_
hash
=
tokenizer_
hash
,
max_threads
=
max_threads
)
return
cls
(
json_object
=
Tru
e
,
vocab_size
=
model_config
.
hf_text_config
.
vocab_size
,
tokenizer_hash
=
tokenizer_hash
,
max_threads
=
max_threads
,
tokenizer_
data
=
tokenizer_
data
,
)
else
:
raise
ValueError
(
"Currently only support JSON and EBNF grammar mode for xgrammar"
...
...
@@ -269,10 +298,14 @@ class XGrammarLogitsProcessor:
# fill_next_token_bitmask so we move it to the device of scores
device_type
=
scores
.
device
.
type
if
device_type
!=
"cuda"
:
scores
=
scores
.
to
(
"cpu"
)
scores
=
scores
.
to
(
"cpu"
).
unsqueeze
(
0
)
# Note: In this method, if the tensors have different dimensions
# on CPU device fails, but on GPU it runs without error. Hence the
# unsqueeze above for scores, to match the token bitmask shape
xgr
.
apply_token_bitmask_inplace
(
scores
,
self
.
token_bitmask
.
to
(
scores
.
device
))
if
device_type
!=
"cuda"
:
scores
=
scores
.
to
(
device_type
)
scores
=
scores
.
to
(
device_type
)
.
squeeze
()
return
scores
vllm/model_executor/layers/fused_moe/fused_moe.py
View file @
96ae75ad
...
...
@@ -2,7 +2,7 @@
import
functools
import
json
import
os
from
typing
import
Any
,
Callable
,
Dict
,
Optional
,
Tuple
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
,
Tuple
import
torch
import
triton
...
...
@@ -11,6 +11,8 @@ import triton.language as tl
import
vllm.envs
as
envs
from
vllm
import
_custom_ops
as
ops
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
per_token_group_quant_fp8
)
from
vllm.platforms
import
current_platform
from
vllm.utils
import
direct_register_custom_op
...
...
@@ -45,8 +47,14 @@ def fused_moe_kernel(
stride_bn
,
stride_cm
,
stride_cn
,
stride_asm
,
stride_ask
,
stride_bse
,
stride_bsk
,
stride_bsn
,
# Block size for block-wise quantization
group_n
:
tl
.
constexpr
,
group_k
:
tl
.
constexpr
,
# Meta-parameters
BLOCK_SIZE_M
:
tl
.
constexpr
,
BLOCK_SIZE_N
:
tl
.
constexpr
,
...
...
@@ -125,8 +133,14 @@ def fused_moe_kernel(
b_scale
=
tl
.
load
(
b_scale_ptrs
)
if
use_fp8_w8a8
:
a_scale
=
tl
.
load
(
a_scale_ptr
)
b_scale
=
tl
.
load
(
b_scale_ptr
+
off_experts
)
if
group_k
>
0
and
group_n
>
0
:
a_scale_ptrs
=
a_scale_ptr
+
(
offs_token
//
top_k
)
*
stride_asm
offs_bsn
=
offs_bn
//
group_n
b_scale_ptrs
=
(
b_scale_ptr
+
off_experts
*
stride_bse
+
offs_bsn
*
stride_bsn
)
else
:
a_scale
=
tl
.
load
(
a_scale_ptr
)
b_scale
=
tl
.
load
(
b_scale_ptr
+
off_experts
)
# -----------------------------------------------------------
# Iterate to compute a block of the C matrix.
...
...
@@ -149,7 +163,18 @@ def fused_moe_kernel(
if
use_int8_w8a16
:
accumulator
=
tl
.
dot
(
a
,
b
.
to
(
compute_type
),
acc
=
accumulator
)
elif
use_fp8_w8a8
:
accumulator
=
tl
.
dot
(
a
,
b
,
acc
=
accumulator
)
if
group_k
>
0
and
group_n
>
0
:
k_start
=
k
*
BLOCK_SIZE_K
offs_ks
=
k_start
//
group_k
a_scale
=
tl
.
load
(
a_scale_ptrs
+
offs_ks
*
stride_ask
,
mask
=
token_mask
,
other
=
0.0
)
b_scale
=
tl
.
load
(
b_scale_ptrs
+
offs_ks
*
stride_bsk
)
accumulator
+=
tl
.
dot
(
a
,
b
)
*
a_scale
[:,
None
]
*
b_scale
[
None
,
:]
else
:
accumulator
=
tl
.
dot
(
a
,
b
,
acc
=
accumulator
)
else
:
accumulator
+=
tl
.
dot
(
a
,
b
)
# Advance the ptrs to the next K block.
...
...
@@ -164,7 +189,10 @@ def fused_moe_kernel(
if
use_int8_w8a16
:
accumulator
=
(
accumulator
*
b_scale
).
to
(
compute_type
)
elif
use_fp8_w8a8
:
accumulator
=
(
accumulator
*
a_scale
*
b_scale
).
to
(
compute_type
)
if
group_k
>
0
and
group_n
>
0
:
accumulator
=
accumulator
.
to
(
compute_type
)
else
:
accumulator
=
(
accumulator
*
a_scale
*
b_scale
).
to
(
compute_type
)
else
:
accumulator
=
accumulator
.
to
(
compute_type
)
# -----------------------------------------------------------
...
...
@@ -233,22 +261,37 @@ def moe_align_block_size(
return
sorted_ids
,
expert_ids
,
num_tokens_post_pad
def
invoke_fused_moe_kernel
(
A
:
torch
.
Tensor
,
B
:
torch
.
Tensor
,
C
:
torch
.
Tensor
,
def
invoke_fused_moe_kernel
(
A
:
torch
.
Tensor
,
B
:
torch
.
Tensor
,
C
:
torch
.
Tensor
,
A_scale
:
Optional
[
torch
.
Tensor
],
B_scale
:
Optional
[
torch
.
Tensor
],
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
sorted_token_ids
:
torch
.
Tensor
,
expert_ids
:
torch
.
Tensor
,
num_tokens_post_padded
:
torch
.
Tensor
,
mul_routed_weight
:
bool
,
top_k
:
int
,
config
:
Dict
[
str
,
Any
],
compute_type
:
tl
.
dtype
,
use_fp8_w8a8
:
bool
,
use_int8_w8a16
:
bool
)
->
None
:
mul_routed_weight
:
bool
,
top_k
:
int
,
config
:
Dict
[
str
,
Any
],
compute_type
:
tl
.
dtype
,
use_fp8_w8a8
:
bool
,
use_int8_w8a16
:
bool
,
block_shape
:
Optional
[
List
[
int
]]
=
None
)
->
None
:
assert
topk_weights
.
stride
(
1
)
==
1
assert
sorted_token_ids
.
stride
(
0
)
==
1
if
use_fp8_w8a8
:
A
,
A_scale
=
ops
.
scaled_fp8_quant
(
A
,
A_scale
)
assert
B_scale
is
not
None
if
block_shape
is
None
:
A
,
A_scale
=
ops
.
scaled_fp8_quant
(
A
,
A_scale
)
else
:
assert
len
(
block_shape
)
==
2
block_n
,
block_k
=
block_shape
[
0
],
block_shape
[
1
]
A
,
A_scale
=
per_token_group_quant_fp8
(
A
,
block_k
)
assert
triton
.
cdiv
(
A
.
shape
[
-
1
],
block_k
)
==
A_scale
.
shape
[
-
1
]
assert
triton
.
cdiv
(
B
.
shape
[
-
2
],
block_n
)
==
B_scale
.
shape
[
-
2
]
assert
triton
.
cdiv
(
B
.
shape
[
-
1
],
block_k
)
==
B_scale
.
shape
[
-
1
]
elif
use_int8_w8a16
:
assert
B_scale
is
not
None
else
:
...
...
@@ -279,8 +322,13 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
B
.
stride
(
1
),
C
.
stride
(
1
),
C
.
stride
(
2
),
B_scale
.
stride
(
0
)
if
B_scale
is
not
None
and
use_int8_w8a16
else
0
,
B_scale
.
stride
(
1
)
if
B_scale
is
not
None
and
use_int8_w8a16
else
0
,
A_scale
.
stride
(
0
)
if
A_scale
is
not
None
and
A_scale
.
ndim
==
2
else
0
,
A_scale
.
stride
(
1
)
if
A_scale
is
not
None
and
A_scale
.
ndim
==
2
else
0
,
B_scale
.
stride
(
0
)
if
B_scale
is
not
None
and
B_scale
.
ndim
>=
2
else
0
,
B_scale
.
stride
(
2
)
if
B_scale
is
not
None
and
B_scale
.
ndim
==
3
else
0
,
B_scale
.
stride
(
1
)
if
B_scale
is
not
None
and
B_scale
.
ndim
>=
2
else
0
,
0
if
block_shape
is
None
else
block_shape
[
0
],
0
if
block_shape
is
None
else
block_shape
[
1
],
MUL_ROUTED_WEIGHT
=
mul_routed_weight
,
top_k
=
top_k
,
compute_type
=
compute_type
,
...
...
@@ -362,6 +410,7 @@ def try_get_optimal_moe_config(
dtype
:
Optional
[
str
],
M
:
int
,
is_marlin
:
bool
=
False
,
block_shape
:
Optional
[
List
[
int
]]
=
None
,
):
from
vllm.model_executor.layers.fused_moe
import
get_config
override_config
=
get_config
()
...
...
@@ -380,6 +429,12 @@ def try_get_optimal_moe_config(
# Else use the default config
config
=
get_default_config
(
M
,
E
,
N
,
w1_shape
[
2
],
top_k
,
dtype
,
is_marlin
)
# NOTE: For block-wise quant,
# BLOCK_K must be divisible by block_shape[1]
# BLOCK_N and BLOCK_M has no requirements
if
block_shape
is
not
None
:
config
[
"BLOCK_SIZE_N"
]
=
block_shape
[
0
]
config
[
"BLOCK_SIZE_K"
]
=
block_shape
[
1
]
return
config
...
...
@@ -421,18 +476,29 @@ def fused_topk(
return
topk_weights
,
topk_ids
# This is used by the Deepseek-V2 model
# This is used by the Deepseek-V2
and Deepseek-V3
model
def
grouped_topk
(
hidden_states
:
torch
.
Tensor
,
gating_output
:
torch
.
Tensor
,
topk
:
int
,
renormalize
:
bool
,
num_expert_group
:
int
=
0
,
topk_group
:
int
=
0
):
topk_group
:
int
=
0
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
):
assert
hidden_states
.
shape
[
0
]
==
gating_output
.
shape
[
0
],
(
"Number of tokens mismatch"
)
scores
=
torch
.
softmax
(
gating_output
,
dim
=-
1
)
if
scoring_func
==
"softmax"
:
scores
=
torch
.
softmax
(
gating_output
,
dim
=-
1
)
elif
scoring_func
==
"sigmoid"
:
scores
=
gating_output
.
sigmoid
()
else
:
raise
ValueError
(
f
"Unsupported scoring function:
{
scoring_func
}
"
)
if
e_score_correction_bias
is
not
None
:
scores
.
add_
(
e_score_correction_bias
.
unsqueeze
(
0
))
num_token
=
scores
.
shape
[
0
]
group_scores
=
scores
.
view
(
num_token
,
num_expert_group
,
-
1
).
max
(
dim
=-
1
).
values
# [n, n_group]
...
...
@@ -479,10 +545,11 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
)
->
None
:
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
block_shape
:
Optional
[
List
[
int
]]
=
None
)
->
None
:
fused_experts_impl
(
hidden_states
,
w1
,
w2
,
topk_weights
,
topk_ids
,
True
,
use_fp8_w8a8
,
use_int8_w8a16
,
w1_scale
,
w2_scale
,
a1_scale
,
a2_scale
)
a1_scale
,
a2_scale
,
block_shape
)
def
inplace_fused_experts_fake
(
...
...
@@ -496,7 +563,8 @@ def inplace_fused_experts_fake(
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
)
->
None
:
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
block_shape
:
Optional
[
List
[
int
]]
=
None
)
->
None
:
pass
...
...
@@ -519,10 +587,11 @@ def outplace_fused_experts(
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
block_shape
:
Optional
[
List
[
int
]]
=
None
)
->
torch
.
Tensor
:
return
fused_experts_impl
(
hidden_states
,
w1
,
w2
,
topk_weights
,
topk_ids
,
False
,
use_fp8_w8a8
,
use_int8_w8a16
,
w1_scale
,
w2_scale
,
a1_scale
,
a2_scale
)
w2_scale
,
a1_scale
,
a2_scale
,
block_shape
)
def
outplace_fused_experts_fake
(
...
...
@@ -536,7 +605,8 @@ def outplace_fused_experts_fake(
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
block_shape
:
Optional
[
List
[
int
]]
=
None
)
->
torch
.
Tensor
:
return
torch
.
empty_like
(
hidden_states
)
...
...
@@ -559,18 +629,22 @@ def fused_experts(hidden_states: torch.Tensor,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
):
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
block_shape
:
Optional
[
List
[
int
]]
=
None
):
if
inplace
:
torch
.
ops
.
vllm
.
inplace_fused_experts
(
hidden_states
,
w1
,
w2
,
topk_weights
,
topk_ids
,
use_fp8_w8a8
,
use_int8_w8a16
,
w1_scale
,
w2_scale
,
a1_scale
,
a2_scale
)
a2_scale
,
block_shape
)
return
hidden_states
else
:
return
torch
.
ops
.
vllm
.
outplace_fused_experts
(
hidden_states
,
w1
,
w2
,
topk_weights
,
topk_ids
,
use_fp8_w8a8
,
use_int8_w8a16
,
w1_scale
,
w2_scale
,
a1_scale
,
a2_scale
)
return
torch
.
ops
.
vllm
.
outplace_fused_experts
(
hidden_states
,
w1
,
w2
,
topk_weights
,
topk_ids
,
use_fp8_w8a8
,
use_int8_w8a16
,
w1_scale
,
w2_scale
,
a1_scale
,
a2_scale
,
block_shape
)
def
fused_experts_impl
(
hidden_states
:
torch
.
Tensor
,
...
...
@@ -584,7 +658,8 @@ def fused_experts_impl(hidden_states: torch.Tensor,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
):
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
block_shape
:
Optional
[
List
[
int
]]
=
None
):
# Check constraints.
assert
hidden_states
.
shape
[
1
]
==
w1
.
shape
[
2
],
"Hidden size mismatch"
assert
topk_weights
.
shape
==
topk_ids
.
shape
,
"topk shape mismatch"
...
...
@@ -611,6 +686,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
w2
.
shape
,
topk_ids
.
shape
[
1
],
config_dtype
,
block_shape
=
block_shape
,
)
config
=
get_config_func
(
M
)
...
...
@@ -674,7 +750,8 @@ def fused_experts_impl(hidden_states: torch.Tensor,
config
,
compute_type
=
compute_type
,
use_fp8_w8a8
=
use_fp8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
)
use_int8_w8a16
=
use_int8_w8a16
,
block_shape
=
block_shape
)
ops
.
silu_and_mul
(
intermediate_cache2
,
intermediate_cache1
.
view
(
-
1
,
N
))
...
...
@@ -693,7 +770,8 @@ def fused_experts_impl(hidden_states: torch.Tensor,
config
,
compute_type
=
compute_type
,
use_fp8_w8a8
=
use_fp8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
)
use_int8_w8a16
=
use_int8_w8a16
,
block_shape
=
block_shape
)
ops
.
moe_sum
(
intermediate_cache3
.
view
(
*
intermediate_cache3
.
shape
),
out_hidden_states
[
begin_chunk_idx
:
end_chunk_idx
])
...
...
@@ -718,6 +796,7 @@ def fused_moe(
w2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
block_shape
:
Optional
[
List
[
int
]]
=
None
,
)
->
torch
.
Tensor
:
"""
This function computes a Mixture of Experts (MoE) layer using two sets of
...
...
@@ -745,6 +824,12 @@ def fused_moe(
w1.
- w2_scale (Optional[torch.Tensor]): Optional scale to be used for
w2.
- a1_scale (Optional[torch.Tensor]): Optional scale to be used for
a1.
- a2_scale (Optional[torch.Tensor]): Optional scale to be used for
a2.
- block_shape: (Optional[List[int]]): Optional block size for block-wise
quantization.
Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
...
...
@@ -775,4 +860,5 @@ def fused_moe(
w1_scale
=
w1_scale
,
w2_scale
=
w2_scale
,
a1_scale
=
a1_scale
,
a2_scale
=
a2_scale
)
\ No newline at end of file
a2_scale
=
a2_scale
,
block_shape
=
block_shape
)
vllm/model_executor/layers/fused_moe/layer.py
View file @
96ae75ad
...
...
@@ -29,6 +29,7 @@ class FusedMoeWeightScaleSupported(Enum):
TENSOR
=
"tensor"
CHANNEL
=
"channel"
GROUP
=
"group"
BLOCK
=
"block"
class
FusedMoEMethodBase
(
QuantizeMethodBase
):
...
...
@@ -40,9 +41,20 @@ class FusedMoEMethodBase(QuantizeMethodBase):
raise
NotImplementedError
@
abstractmethod
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
top_k
:
int
,
renormalize
:
bool
,
use_grouped_topk
:
bool
)
->
torch
.
Tensor
:
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
top_k
:
int
,
renormalize
:
bool
,
use_grouped_topk
:
bool
=
False
,
topk_group
:
Optional
[
int
]
=
None
,
num_expert_group
:
Optional
[
int
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
raise
NotImplementedError
...
...
@@ -72,16 +84,18 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
set_weight_attrs
(
w2_weight
,
extra_weight_attrs
)
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
top_k
:
int
,
renormalize
:
bool
,
use_grouped_topk
:
bool
,
topk_group
:
Optional
[
int
]
=
None
,
num_expert_group
:
Optional
[
int
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
top_k
:
int
,
renormalize
:
bool
,
use_grouped_topk
:
bool
=
False
,
topk_group
:
Optional
[
int
]
=
None
,
num_expert_group
:
Optional
[
int
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
return
self
.
forward
(
x
=
x
,
layer
=
layer
,
...
...
@@ -91,19 +105,23 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
use_grouped_topk
=
use_grouped_topk
,
topk_group
=
topk_group
,
num_expert_group
=
num_expert_group
,
custom_routing_function
=
custom_routing_function
)
custom_routing_function
=
custom_routing_function
,
scoring_func
=
scoring_func
,
e_score_correction_bias
=
e_score_correction_bias
)
def
forward_cuda
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
use_grouped_topk
:
bool
,
top_k
:
int
,
router_logits
:
torch
.
Tensor
,
renormalize
:
bool
,
topk_group
:
Optional
[
int
]
=
None
,
num_expert_group
:
Optional
[
int
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
use_grouped_topk
:
bool
,
top_k
:
int
,
router_logits
:
torch
.
Tensor
,
renormalize
:
bool
,
topk_group
:
Optional
[
int
]
=
None
,
num_expert_group
:
Optional
[
int
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
hidden_states
=
x
,
...
...
@@ -113,7 +131,9 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
renormalize
=
renormalize
,
topk_group
=
topk_group
,
num_expert_group
=
num_expert_group
,
custom_routing_function
=
custom_routing_function
)
custom_routing_function
=
custom_routing_function
,
scoring_func
=
scoring_func
,
e_score_correction_bias
=
e_score_correction_bias
)
return
fused_experts
(
hidden_states
=
x
,
w1
=
layer
.
w13_weight
,
...
...
@@ -127,21 +147,29 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
"The CPU backend currently does not support MoE."
)
def
forward_tpu
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
use_grouped_topk
:
bool
,
top_k
:
int
,
router_logits
:
torch
.
Tensor
,
renormalize
:
bool
,
topk_group
:
Optional
[
int
]
=
None
,
num_expert_group
:
Optional
[
int
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
use_grouped_topk
:
bool
,
top_k
:
int
,
router_logits
:
torch
.
Tensor
,
renormalize
:
bool
,
topk_group
:
Optional
[
int
]
=
None
,
num_expert_group
:
Optional
[
int
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
assert
not
use_grouped_topk
assert
num_expert_group
is
None
assert
topk_group
is
None
assert
custom_routing_function
is
None
if
scoring_func
!=
"softmax"
:
raise
NotImplementedError
(
"Only softmax scoring function is supported for TPU."
)
if
e_score_correction_bias
is
not
None
:
raise
NotImplementedError
(
"Expert score correction bias is not supported for TPU."
)
return
fused_moe_pallas
(
hidden_states
=
x
,
w1
=
layer
.
w13_weight
,
w2
=
layer
.
w2_weight
,
...
...
@@ -155,7 +183,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
class
FusedMoE
(
torch
.
nn
.
Module
):
"""FusedMoE layer for MoE models.
This layer contains both MergedColumnParallel weights (gate_up_proj /
This layer contains both MergedColumnParallel weights (gate_up_proj /
w13) and RowParallelLinear weights (down_proj/ w2).
Note: Mixtral uses w1, w2, and w3 for gate, up, and down_proj. We
...
...
@@ -189,6 +217,8 @@ class FusedMoE(torch.nn.Module):
tp_size
:
Optional
[
int
]
=
None
,
prefix
:
str
=
""
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
):
super
().
__init__
()
...
...
@@ -199,6 +229,7 @@ class FusedMoE(torch.nn.Module):
get_tensor_model_parallel_world_size
())
self
.
top_k
=
top_k
self
.
num_experts
=
num_experts
assert
intermediate_size
%
self
.
tp_size
==
0
self
.
intermediate_size_per_partition
=
intermediate_size
//
self
.
tp_size
self
.
reduce_results
=
reduce_results
self
.
renormalize
=
renormalize
...
...
@@ -208,6 +239,12 @@ class FusedMoE(torch.nn.Module):
self
.
num_expert_group
=
num_expert_group
self
.
topk_group
=
topk_group
self
.
custom_routing_function
=
custom_routing_function
self
.
scoring_func
=
scoring_func
self
.
e_score_correction_bias
=
e_score_correction_bias
if
self
.
scoring_func
!=
"softmax"
and
not
self
.
use_grouped_topk
:
raise
ValueError
(
"Only softmax scoring function is supported for "
"non-grouped topk."
)
if
quant_config
is
None
:
self
.
quant_method
:
Optional
[
QuantizeMethodBase
]
=
(
...
...
@@ -398,7 +435,10 @@ class FusedMoE(torch.nn.Module):
loaded_weight
=
loaded_weight
,
expert_data
=
expert_data
,
tp_rank
=
tp_rank
)
elif
quant_method
==
FusedMoeWeightScaleSupported
.
GROUP
.
value
:
elif
quant_method
in
[
FusedMoeWeightScaleSupported
.
GROUP
.
value
,
FusedMoeWeightScaleSupported
.
BLOCK
.
value
,
]:
self
.
_load_model_weight_or_group_weight_scale
(
shard_id
=
shard_id
,
shard_dim
=
shard_dim
,
...
...
@@ -441,7 +481,9 @@ class FusedMoE(torch.nn.Module):
renormalize
:
bool
,
topk_group
:
Optional
[
int
]
=
None
,
num_expert_group
:
Optional
[
int
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
):
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
):
from
vllm.model_executor.layers.fused_moe.fused_moe
import
(
fused_topk
,
grouped_topk
)
...
...
@@ -455,7 +497,9 @@ class FusedMoE(torch.nn.Module):
topk
=
top_k
,
renormalize
=
renormalize
,
num_expert_group
=
num_expert_group
,
topk_group
=
topk_group
)
topk_group
=
topk_group
,
scoring_func
=
scoring_func
,
e_score_correction_bias
=
e_score_correction_bias
)
elif
custom_routing_function
is
None
:
topk_weights
,
topk_ids
=
fused_topk
(
hidden_states
=
hidden_states
,
gating_output
=
router_logits
,
...
...
@@ -484,7 +528,9 @@ class FusedMoE(torch.nn.Module):
use_grouped_topk
=
self
.
use_grouped_topk
,
topk_group
=
self
.
topk_group
,
num_expert_group
=
self
.
num_expert_group
,
custom_routing_function
=
self
.
custom_routing_function
)
custom_routing_function
=
self
.
custom_routing_function
,
scoring_func
=
self
.
scoring_func
,
e_score_correction_bias
=
self
.
e_score_correction_bias
)
if
self
.
reduce_results
and
self
.
tp_size
>
1
:
final_hidden_states
=
tensor_model_parallel_all_reduce
(
...
...
vllm/model_executor/layers/linear.py
View file @
96ae75ad
...
...
@@ -14,11 +14,14 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
,
QuantizeMethodBase
)
# yapf: disable
from
vllm.model_executor.parameter
import
(
BasevLLMParameter
,
BlockQuantScaleParameter
,
PackedColumnParameter
,
PackedvLLMParameter
,
PerTensorScaleParameter
,
RowvLLMParameter
)
# yapf: enable
from
vllm.model_executor.utils
import
set_weight_attrs
import
os
...
...
@@ -642,8 +645,24 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
assert
loaded_shard_id
<
len
(
self
.
output_sizes
)
tp_size
=
get_tensor_model_parallel_world_size
()
shard_offset
=
sum
(
self
.
output_sizes
[:
loaded_shard_id
])
//
tp_size
shard_size
=
self
.
output_sizes
[
loaded_shard_id
]
//
tp_size
if
isinstance
(
param
,
BlockQuantScaleParameter
):
from
vllm.model_executor.layers.quantization.fp8
import
(
Fp8LinearMethod
,
Fp8MoEMethod
)
assert
self
.
quant_method
is
not
None
assert
isinstance
(
self
.
quant_method
,
(
Fp8LinearMethod
,
Fp8MoEMethod
))
weight_block_size
=
self
.
quant_method
.
quant_config
.
weight_block_size
assert
weight_block_size
is
not
None
block_n
,
_
=
weight_block_size
[
0
],
weight_block_size
[
1
]
shard_offset
=
(
(
sum
(
self
.
output_sizes
[:
loaded_shard_id
])
+
block_n
-
1
)
//
block_n
)
//
tp_size
shard_size
=
((
self
.
output_sizes
[
loaded_shard_id
]
+
block_n
-
1
)
//
block_n
//
tp_size
)
else
:
shard_offset
=
sum
(
self
.
output_sizes
[:
loaded_shard_id
])
//
tp_size
shard_size
=
self
.
output_sizes
[
loaded_shard_id
]
//
tp_size
param
.
load_merged_column_weight
(
loaded_weight
=
loaded_weight
,
shard_id
=
loaded_shard_id
,
...
...
vllm/model_executor/layers/quantization/awq_marlin.py
View file @
96ae75ad
...
...
@@ -440,11 +440,13 @@ class AWQMoEMethod(FusedMoEMethodBase):
x
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
top_k
:
int
,
renormalize
:
bool
=
True
,
renormalize
:
bool
,
use_grouped_topk
:
bool
=
False
,
num_expert_group
:
Optional
[
int
]
=
None
,
topk_group
:
Optional
[
int
]
=
None
,
num_expert_group
:
Optional
[
int
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
hidden_states
=
x
,
...
...
@@ -454,7 +456,9 @@ class AWQMoEMethod(FusedMoEMethodBase):
renormalize
=
renormalize
,
topk_group
=
topk_group
,
num_expert_group
=
num_expert_group
,
custom_routing_function
=
custom_routing_function
)
custom_routing_function
=
custom_routing_function
,
scoring_func
=
scoring_func
,
e_score_correction_bias
=
e_score_correction_bias
)
return
torch
.
ops
.
vllm
.
fused_marlin_moe
(
x
,
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
View file @
96ae75ad
from
typing
import
Any
,
Dict
,
List
,
Optional
,
cast
from
typing
import
Any
,
Dict
,
List
,
Literal
,
Optional
,
cast
import
torch
from
compressed_tensors.config
import
CompressionFormat
from
compressed_tensors.config
import
(
CompressionFormat
,
SparsityCompressionConfig
,
SparsityStructure
)
from
compressed_tensors.quantization
import
(
QuantizationArgs
,
QuantizationStrategy
,
QuantizationType
)
...
...
@@ -15,7 +17,7 @@ from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501
from
vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe
import
(
# noqa: E501
CompressedTensorsMoEMethod
)
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
import
(
W4A16SPARSE24_SUPPORTED_BITS
,
WNA16_SUPPORTED_BITS
,
W4A16SPARSE24_SUPPORTED_BITS
,
WNA16_SUPPORTED_BITS
,
CompressedTensors24
,
CompressedTensorsScheme
,
CompressedTensorsW4A16Sparse24
,
CompressedTensorsW8A8Fp8
,
CompressedTensorsW8A8Int8
,
CompressedTensorsW8A16Fp8
,
CompressedTensorsWNA16
)
...
...
@@ -27,20 +29,29 @@ from vllm.platforms import current_platform
__all__
=
[
"CompressedTensorsLinearMethod"
]
SPARSITY_CONFIG_NAME
:
Literal
[
"sparsity_config"
]
=
"sparsity_config"
QUANTIZATION_SCHEME_MAP_TYPE
=
Dict
[
str
,
Optional
[
Dict
[
str
,
QuantizationArgs
]]]
class
CompressedTensorsConfig
(
QuantizationConfig
):
def
__init__
(
self
,
target_scheme_map
:
Dict
[
str
,
Any
],
ignore
:
List
[
str
],
quant_format
:
str
,
kv_cache_scheme
:
Optional
[
Dict
[
str
,
Any
]]
=
None
):
def
__init__
(
self
,
target_scheme_map
:
Dict
[
str
,
Any
],
ignore
:
List
[
str
],
quant_format
:
str
,
sparsity_scheme_map
:
Dict
[
str
,
SparsityCompressionConfig
],
kv_cache_scheme
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
config
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
):
self
.
ignore
=
ignore
self
.
quant_format
=
quant_format
# Map from [target -> scheme]
self
.
target_scheme_map
=
target_scheme_map
self
.
kv_cache_scheme
=
kv_cache_scheme
self
.
sparsity_scheme_map
=
sparsity_scheme_map
self
.
config
=
config
def
get_linear_method
(
self
)
->
"CompressedTensorsLinearMethod"
:
return
CompressedTensorsLinearMethod
(
self
)
...
...
@@ -78,8 +89,50 @@ class CompressedTensorsConfig(QuantizationConfig):
@
classmethod
def
from_config
(
cls
,
config
:
Dict
[
str
,
Any
])
->
"CompressedTensorsConfig"
:
ignore
:
List
[
str
]
=
cast
(
List
[
str
],
config
.
get
(
"ignore"
,
[]))
quant_format
=
cast
(
str
,
config
.
get
(
"format"
))
target_scheme_map
=
cls
.
_quantization_scheme_map_from_config
(
config
=
config
)
sparsity_scheme_map
=
cls
.
_sparsity_scheme_map_from_config
(
config
=
config
)
return
cls
(
target_scheme_map
=
target_scheme_map
,
ignore
=
ignore
,
quant_format
=
quant_format
,
sparsity_scheme_map
=
sparsity_scheme_map
,
config
=
config
,
)
@
classmethod
def
_sparsity_scheme_map_from_config
(
cls
,
config
:
Dict
[
str
,
Any
])
->
Dict
[
str
,
SparsityCompressionConfig
]:
"""
:param config: The `quantization_config` dictionary from config.json
:return: A dictionary mapping target layer names to their corresponding
sparsity compression configurations
"""
if
(
sparsity_config
:
=
config
.
get
(
SPARSITY_CONFIG_NAME
))
is
None
:
return
dict
()
sparsity_config
=
SparsityCompressionConfig
.
model_validate
(
sparsity_config
)
sparse_scheme_map
:
Dict
[
str
,
SparsityCompressionConfig
]
=
{
target
:
sparsity_config
for
target
in
sparsity_config
.
targets
or
list
()
}
return
sparse_scheme_map
@
classmethod
def
_quantization_scheme_map_from_config
(
cls
,
config
:
Dict
[
str
,
Any
])
->
QUANTIZATION_SCHEME_MAP_TYPE
:
"""
:param config: The `quantization_config` dictionary from config.json
:return: A dictionary mapping target layer names to their corresponding
quantization_args for weights and input activations
"""
target_scheme_map
:
Dict
[
str
,
Any
]
=
dict
()
ignore
=
cast
(
List
[
str
],
config
.
get
(
"ignore"
))
quant_format
=
cast
(
str
,
config
.
get
(
"format"
))
# The quant_config has multiple config_groups, each containing
...
...
@@ -90,12 +143,14 @@ class CompressedTensorsConfig(QuantizationConfig):
# details follow the structure defined by the QuantizationArgs
# pydantic model, which is used to verify the structure of the
# quant_config and also store the details for later use.
for
_
,
quant_config
in
config
[
"config_groups"
].
items
():
config_groups
=
config
.
get
(
"config_groups"
,
dict
())
for
_
,
quant_config
in
config_groups
.
items
():
targets
=
quant_config
.
get
(
"targets"
)
for
target
in
targets
:
target_scheme_map
[
target
]
=
{}
target_scheme_map
[
target
][
"weights"
]
=
QuantizationArgs
.
parse_obj
(
"weights"
]
=
QuantizationArgs
.
model_validate
(
quant_config
.
get
(
"weights"
))
target_scheme_map
[
target
][
"input_activations"
]
=
None
...
...
@@ -110,13 +165,9 @@ class CompressedTensorsConfig(QuantizationConfig):
"weights"
].
type
==
QuantizationType
.
FLOAT
else
:
target_scheme_map
[
target
][
"input_activations"
]
=
QuantizationArgs
.
parse_obj
(
"input_activations"
]
=
QuantizationArgs
.
model_validate
(
# noqa: E501
quant_config
.
get
(
"input_activations"
))
return
cls
(
target_scheme_map
=
target_scheme_map
,
ignore
=
ignore
,
quant_format
=
quant_format
,
kv_cache_scheme
=
config
.
get
(
"kv_cache_scheme"
))
return
target_scheme_map
@
classmethod
def
get_config_filenames
(
cls
)
->
List
[
str
]:
...
...
@@ -315,23 +366,105 @@ class CompressedTensorsConfig(QuantizationConfig):
# TODO (@robertgshaw): add compressed-tensors as dep
# so we do not have to re-write these functions
# need to make accelerate optional in ct to do this
matched_target
=
find_matched_target
(
layer_name
=
layer_name
,
module
=
layer
,
targets
=
self
.
target_scheme_map
.
keys
())
# Find the quant_scheme
scheme_dict
=
self
.
target_scheme_map
[
matched_target
]
scheme
=
self
.
_get_scheme_from_parts
(
weight_quant
=
scheme_dict
[
"weights"
],
input_quant
=
scheme_dict
[
"input_activations"
])
# Will be empty for models with only sparsity
if
self
.
target_scheme_map
:
matched_target
=
find_matched_target
(
layer_name
=
layer_name
,
module
=
layer
,
targets
=
self
.
target_scheme_map
.
keys
())
scheme_dict
=
self
.
target_scheme_map
[
matched_target
]
weight_quant
=
scheme_dict
.
get
(
"weights"
)
input_quant
=
scheme_dict
.
get
(
"input_activations"
)
elif
self
.
sparsity_scheme_map
:
matched_target
=
find_matched_target
(
layer_name
=
layer_name
,
module
=
layer
,
targets
=
self
.
sparsity_scheme_map
.
keys
())
weight_quant
=
None
input_quant
=
None
# For models with sparsity, assumes that the sparse layers are also
# quantized for cutlass 2:4 support
sparsity_scheme
:
Optional
[
SparsityCompressionConfig
]
=
self
.
sparsity_scheme_map
.
get
(
matched_target
)
if
self
.
supports_cutlass_24
(
weight_quant
=
weight_quant
,
input_quant
=
input_quant
,
sparsity_scheme
=
sparsity_scheme
):
# Have a valid sparsity scheme
# Validate layer is supported by Cutlass 2:4 Kernel
scheme
=
CompressedTensors24
(
quantized
=
weight_quant
is
not
None
or
input_quant
is
not
None
,
weight_quant
=
weight_quant
,
input_quant
=
input_quant
)
else
:
# Find the quant_scheme
scheme
=
self
.
_get_scheme_from_parts
(
# type: ignore
weight_quant
=
weight_quant
,
input_quant
=
input_quant
,
)
# Raise error if device does not support the scheme
# (e.g. fp8 needs ada lovelace)
self
.
_check_scheme_supported
(
scheme
.
get_min_capability
())
return
scheme
@
staticmethod
def
supports_cutlass_24
(
weight_quant
:
Optional
[
QuantizationArgs
],
input_quant
:
Optional
[
QuantizationArgs
],
sparsity_scheme
:
Optional
[
SparsityCompressionConfig
]
=
None
)
->
bool
:
"""
Check if the layer is supported by the Cutlass 2:4 Kernel
Conditions:
- Overarching condition: Sparsity Structure is 2:4
- Unquantized cases are supported
- Weight only quantization is not-supported
- Supported weight quantization strategies are TENSOR and CHANNEL
- Supported input quantization strategies are TENSOR and TOKEN
- Only 8 bit quantization is supported
:return: True if the layer is supported by the Cutlass 2:4 Kernel
False otherwise
"""
is_valid_sparsity
=
(
sparsity_scheme
is
not
None
and
sparsity_scheme
.
sparsity_structure
==
SparsityStructure
.
TWO_FOUR
.
value
and
sparsity_scheme
.
format
==
"dense"
)
if
not
is_valid_sparsity
:
return
False
# Unquantized cases are supported
if
weight_quant
is
None
and
input_quant
is
None
:
return
True
# Weight only quantization is not-supported
if
weight_quant
is
not
None
and
input_quant
is
None
:
return
False
supported_weight_quant_strategies
=
[
QuantizationStrategy
.
TENSOR
.
value
,
QuantizationStrategy
.
CHANNEL
.
value
]
assert
weight_quant
is
not
None
assert
input_quant
is
not
None
if
weight_quant
.
strategy
not
in
supported_weight_quant_strategies
:
return
False
supported_input_quant_strategies
=
[
QuantizationStrategy
.
TENSOR
.
value
,
QuantizationStrategy
.
TOKEN
.
value
]
if
input_quant
.
strategy
not
in
supported_input_quant_strategies
:
return
False
return
weight_quant
.
num_bits
==
input_quant
.
num_bits
==
8
class
CompressedTensorsLinearMethod
(
LinearMethodBase
):
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
View file @
96ae75ad
...
...
@@ -203,13 +203,14 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
x
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
top_k
:
int
,
renormalize
:
bool
=
True
,
renormalize
:
bool
,
use_grouped_topk
:
bool
=
False
,
num_expert_group
:
Optional
[
int
]
=
None
,
topk_group
:
Optional
[
int
]
=
None
,
num_expert_group
:
Optional
[
int
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
from
vllm.model_executor.layers.fused_moe
import
fused_experts
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
...
...
@@ -220,7 +221,9 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
renormalize
=
renormalize
,
topk_group
=
topk_group
,
num_expert_group
=
num_expert_group
,
custom_routing_function
=
custom_routing_function
)
custom_routing_function
=
custom_routing_function
,
scoring_func
=
scoring_func
,
e_score_correction_bias
=
e_score_correction_bias
)
return
fused_experts
(
x
,
layer
.
w13_weight
,
...
...
@@ -476,12 +479,15 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
x
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
top_k
:
int
,
renormalize
:
bool
=
True
,
renormalize
:
bool
,
use_grouped_topk
:
bool
=
False
,
num_expert_group
:
Optional
[
int
]
=
None
,
topk_group
:
Optional
[
int
]
=
None
,
num_expert_group
:
Optional
[
int
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
hidden_states
=
x
,
router_logits
=
router_logits
,
...
...
@@ -490,7 +496,9 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
renormalize
=
renormalize
,
topk_group
=
topk_group
,
num_expert_group
=
num_expert_group
,
custom_routing_function
=
custom_routing_function
)
custom_routing_function
=
custom_routing_function
,
scoring_func
=
scoring_func
,
e_score_correction_bias
=
e_score_correction_bias
)
return
torch
.
ops
.
vllm
.
fused_marlin_moe
(
x
,
...
...
vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py
View file @
96ae75ad
...
...
@@ -7,13 +7,12 @@ from .compressed_tensors_w8a16_fp8 import CompressedTensorsW8A16Fp8
from
.compressed_tensors_wNa16
import
(
WNA16_SUPPORTED_BITS
,
CompressedTensorsWNA16
)
from
.compressed_tensors_24
import
CompressedTensors24
# isort: skip
__all__
=
[
"CompressedTensorsScheme"
,
"CompressedTensorsWNA16"
,
"CompressedTensorsW8A16Fp8"
,
"CompressedTensorsW4A16Sparse24"
,
"CompressedTensorsW8A8Int8"
,
"CompressedTensorsW8A8Fp8"
,
"WNA16_SUPPORTED_BITS"
,
"W4A16SPARSE24_SUPPORTED_BITS"
,
"CompressedTensorsScheme"
,
"CompressedTensorsWNA16"
,
"CompressedTensorsW8A16Fp8"
,
"CompressedTensorsW4A16Sparse24"
,
"CompressedTensorsW8A8Int8"
,
"CompressedTensorsW8A8Fp8"
,
"WNA16_SUPPORTED_BITS"
,
"W4A16SPARSE24_SUPPORTED_BITS"
,
"CompressedTensors24"
]
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py
0 → 100644
View file @
96ae75ad
from
typing
import
Callable
,
List
,
Optional
import
torch
from
compressed_tensors.quantization
import
(
QuantizationArgs
,
QuantizationStrategy
,
QuantizationType
)
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
import
(
CompressedTensorsScheme
)
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
convert_to_channelwise
,
sparse_cutlass_supported
)
from
vllm.model_executor.parameter
import
(
BasevLLMParameter
,
ChannelQuantScaleParameter
,
ModelWeightParameter
,
PerTensorScaleParameter
)
__all__
=
[
"CompressedTensors24"
]
class CompressedTensors24(CompressedTensorsScheme):
    """Scheme for linear layers whose weights are stored 2:4 sparse.

    Weights are loaded uncompressed, compressed once loading has finished
    (see ``process_weights_after_loading``) and multiplied through
    ``ops.cutlass_scaled_sparse_mm`` (see ``apply_weights``).

    The scheme may additionally be quantized; in that case weights and
    activations must both be 8-bit and of the same type (both INT or both
    FLOAT), see ``_get_params_dtype``.
    """

    def __init__(self,
                 quantized: bool = False,
                 weight_quant: Optional[QuantizationArgs] = None,
                 input_quant: Optional[QuantizationArgs] = None):
        """
        :param quantized: Whether the layer is quantized in addition to
            being 2:4 sparse
        :param weight_quant: Quantization arguments of the weights, if any
        :param input_quant: Quantization arguments of the activations,
            if any
        """
        self.quantized = quantized
        self.weight_quant = weight_quant
        self.input_quant = input_quant

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum device capability required by this scheme."""
        # Only cutlass 3.x kernels are implemented so far
        return 90

    def create_weights(self, layer: torch.nn.Module, input_size: int,
                       output_partition_sizes: List[int],
                       input_size_per_partition: int,
                       params_dtype: torch.dtype, weight_loader: Callable,
                       **kwargs):
        """
        Allocate and register the parameters of the layer.

        Always registers ``weight`` and ``weight_scale`` on ``layer``.
        ``input_scale`` is registered for the sparse-only case and for
        quantized layers with non-dynamic input quantization.

        :param layer: The layer the parameters are registered on
        :param input_size: Unused by this scheme
        :param output_partition_sizes: Output sizes of the (possibly fused)
            partitions; their sum is the first weight dimension
        :param input_size_per_partition: Second weight dimension
        :param params_dtype: Dtype of the layer output; also the weight
            dtype when the layer is not quantized
        :param weight_loader: Loader attached to the created parameters
        :raises ValueError: If sparse CUTLASS kernels are not supported, or
            if the quantization configuration is rejected by
            ``_get_params_dtype``
        """
        if not sparse_cutlass_supported():
            # NOTE: the trailing space after "with" is required; adjacent
            # string literals are concatenated verbatim.
            raise ValueError(
                "Sparse CUTLASS not supported. vLLM must be built with "
                "CUDA 12.2 or later to use this feature")

        self.output_dtype = params_dtype
        layer.logical_widths = output_partition_sizes
        self.weights_dtype: torch.dtype = self._get_params_dtype(params_dtype)

        # parameter to store uncompressed weight
        weight = ModelWeightParameter(data=torch.empty(
            sum(output_partition_sizes),
            input_size_per_partition,
            dtype=self.weights_dtype),
                                      input_dim=1,
                                      output_dim=0,
                                      weight_loader=weight_loader)

        # Check if quantized, not just 2:4 Sparse
        if self.quantized:
            if (self.weight_quant and self.weight_quant.strategy
                    == QuantizationStrategy.CHANNEL.value):
                # One scale per output channel.
                weight_scale = ChannelQuantScaleParameter(
                    data=torch.empty((sum(output_partition_sizes), 1),
                                     dtype=torch.float32),
                    output_dim=0,
                    weight_loader=weight_loader)
            else:
                assert (self.weight_quant and self.weight_quant.strategy
                        == QuantizationStrategy.TENSOR.value)
                # One scale per logical (fused) partition.
                weight_scale = PerTensorScaleParameter(
                    data=torch.empty(len(output_partition_sizes),
                                     dtype=torch.float32),
                    weight_loader=weight_loader)

            layer.register_parameter("weight_scale", weight_scale)

            # input quant will be non-none
            if self.input_quant and not self.input_quant.dynamic:
                # register input quant scale
                assert (self.input_quant.strategy ==
                        QuantizationStrategy.TENSOR.value)
                input_scale = BasevLLMParameter(
                    data=torch.empty(1, dtype=torch.float32),
                    weight_loader=weight_loader)

                layer.register_parameter("input_scale", input_scale)

        else:
            # for sparse-only, pass in 1 for weight/input scales
            weight_scale = torch.nn.Parameter(data=torch.ones(
                1, dtype=torch.float32),
                                              requires_grad=False)
            input_scale = torch.nn.Parameter(data=torch.ones(
                1, dtype=torch.float32),
                                             requires_grad=False)
            layer.register_parameter("input_scale", input_scale)
            layer.register_parameter("weight_scale", weight_scale)

        layer.register_parameter("weight", weight)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """
        Compress weights after loading. Store compressed weight and meta
            tensor

        :post-condition: layer.weight and layer.meta are
            set to the compressed weight and meta tensor returned by
            ``ops.cutlass_sparse_compress``
        :param layer: The layer with the weights to be processed
        """
        # torch.compile workaround
        if hasattr(layer, "input_scale"):
            layer.input_scale = torch.nn.Parameter(layer.input_scale.data,
                                                   requires_grad=False)

        if self.weight_quant:
            if self.weight_quant.strategy == QuantizationStrategy.TENSOR.value:
                # Expand the per-partition scales to channelwise scales.
                layer.weight_scale = torch.nn.Parameter(convert_to_channelwise(
                    weight_scale=layer.weight_scale,
                    logical_widths=layer.logical_widths),
                                                        requires_grad=False)
            else:
                # torch.compile workaround
                layer.weight_scale = torch.nn.Parameter(
                    layer.weight_scale.data, requires_grad=False)

        # Replace the dense weight by its compressed form and keep the
        # accompanying metadata tensor next to it.
        w_compressed, meta = ops.cutlass_sparse_compress(layer.weight.data)
        layer.weight = torch.nn.Parameter(w_compressed, requires_grad=False)
        layer.meta = torch.nn.Parameter(meta, requires_grad=False)

    def apply_weights(self,
                      layer: torch.nn.Module,
                      x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Returns the output tensor for the layer with 2:4
        sparse compressed weights, given the input tensor
        and bias

        :param layer: The layer with 2:4 sparse compressed
            weights to be used for the computation
        :param x: The input tensor to the layer
        :param bias: The bias to be added to the output tensor
        :return: The output tensor of the layer
        """
        if self.quantized:
            # A static input scale exists only if create_weights
            # registered one; otherwise the quant ops receive no scale.
            scale = None
            if hasattr(layer, "input_scale"):
                scale = layer.input_scale

            if self.weights_dtype == torch.int8:
                ops_output = ops.scaled_int8_quant(x, scale=scale)
                q_input = ops_output[0]
                input_scale = ops_output[1]
            else:
                assert self.weights_dtype == torch.float8_e4m3fn
                if scale is not None:
                    q_input, input_scale = ops.scaled_fp8_quant(x,
                                                                scale=scale)
                else:
                    q_input, input_scale = ops.scaled_fp8_quant(
                        x, use_per_token_if_dynamic=True)

        else:
            # Not quantized, nothing to do with the input_scales, use as is
            input_scale = layer.input_scale
            q_input = x

        out = ops.cutlass_scaled_sparse_mm(a=q_input,
                                           bt_nzs=layer.weight,
                                           bt_meta=layer.meta,
                                           scale_a=input_scale,
                                           scale_b=layer.weight_scale,
                                           out_dtype=self.output_dtype,
                                           bias=bias)
        assert out.is_contiguous()
        return out

    def _get_params_dtype(self, params_dtype: torch.dtype) -> torch.dtype:
        """
        Select the dtype used to store the weights.

        :param params_dtype: Dtype returned unchanged when the scheme is
            not quantized
        :return: ``params_dtype`` if not quantized, ``torch.float8_e4m3fn``
            when weights and inputs are both FLOAT, ``torch.int8`` when
            both are INT
        :raises ValueError: If weights and inputs are not both 8-bit, or
            if their quantization types are neither both FLOAT nor both INT
        """
        if not self.quantized:
            return params_dtype

        assert self.weight_quant is not None
        assert self.input_quant is not None

        is_8_bits = self.weight_quant.num_bits == self.input_quant.num_bits == 8

        if not is_8_bits:
            raise ValueError("Cutlass only supports 8-bit quantization")

        if (self.weight_quant.type == QuantizationType.FLOAT
                and self.input_quant.type == QuantizationType.FLOAT):
            return torch.float8_e4m3fn

        if (self.weight_quant.type == QuantizationType.INT
                and self.input_quant.type == QuantizationType.INT):
            return torch.int8

        raise ValueError("Quantization type not supported by Cutlass")
def check_24(tensor: torch.Tensor) -> bool:
    """Return whether ``tensor`` satisfies the 2:4 sparsity pattern.

    The elements are grouped, in row-major order, into consecutive runs of
    four; the pattern holds when every run contains at least two zeros.

    :param tensor: Tensor to inspect; its number of elements must be a
        multiple of 4, otherwise the regrouping fails
    :return: True if every group of four has two or more zeros
    """
    # reshape (rather than view) also accepts non-contiguous inputs such as
    # transposed weights; for contiguous inputs the result is identical.
    groups_of_four = tensor.reshape(-1, 4)
    zero_counts = (groups_of_four == 0).sum(dim=1)
    return (zero_counts >= 2).all().item()
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py
View file @
96ae75ad
...
...
@@ -61,6 +61,10 @@ class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme):
params_dtype
:
torch
.
dtype
,
weight_loader
:
Callable
,
**
kwargs
):
assert
params_dtype
==
torch
.
float16
,
(
"float16 is required for marlin24 compressd models. Set dtype=torch.float16"
# noqa: E501
)
pack_factor
=
32
//
self
.
quant_type
.
size_bits
output_size_per_partition
=
sum
(
output_partition_sizes
)
...
...
vllm/model_executor/layers/quantization/compressed_tensors/utils.py
View file @
96ae75ad
...
...
@@ -30,7 +30,7 @@ def should_ignore_layer(layer_name: Optional[str],
# in the safetensors checkpoint. So, we convert the name
# from the fused version to unfused + check to make sure that
# each shard of the fused layer has the same scheme.
if
proj_name
in
FUSED_LAYER_NAME_MAPPING
:
if
proj_name
in
FUSED_LAYER_NAME_MAPPING
and
layer_name
not
in
ignore
:
shard_proj_names
=
FUSED_LAYER_NAME_MAPPING
[
proj_name
]
# Convert fused_name --> [shard_names]
...
...
vllm/model_executor/layers/quantization/experts_int8.py
View file @
96ae75ad
...
...
@@ -99,11 +99,13 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
x
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
top_k
:
int
,
renormalize
:
bool
=
True
,
renormalize
:
bool
,
use_grouped_topk
:
bool
=
False
,
num_expert_group
:
Optional
[
int
]
=
None
,
topk_group
:
Optional
[
int
]
=
None
,
num_expert_group
:
Optional
[
int
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
from
vllm.model_executor.layers.fused_moe
import
fused_experts
...
...
@@ -115,7 +117,9 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
renormalize
=
renormalize
,
topk_group
=
topk_group
,
num_expert_group
=
num_expert_group
,
custom_routing_function
=
custom_routing_function
)
custom_routing_function
=
custom_routing_function
,
scoring_func
=
scoring_func
,
e_score_correction_bias
=
e_score_correction_bias
)
return
fused_experts
(
x
,
layer
.
w13_weight
,
...
...
vllm/model_executor/layers/quantization/fp8.py
View file @
96ae75ad
...
...
@@ -6,6 +6,7 @@ from torch.nn.parameter import Parameter
import
vllm.envs
as
envs
from
vllm
import
_custom_ops
as
ops
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.fused_moe
import
(
FusedMoE
,
FusedMoEMethodBase
,
FusedMoeWeightScaleSupported
)
...
...
@@ -14,6 +15,8 @@ from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
,
QuantizeMethodBase
)
from
vllm.model_executor.layers.quantization.kv_cache
import
BaseKVCacheMethod
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
apply_w8a8_block_fp8_linear
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils_fp8
import
(
apply_fp8_marlin_linear
,
prepare_fp8_layer_for_marlin
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
...
...
@@ -22,7 +25,8 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
all_close_1d
,
apply_fp8_linear
,
convert_to_channelwise
,
cutlass_fp8_supported
,
normalize_e4m3fn_to_e4m3fnuz
,
per_tensor_dequantize
,
requantize_with_max_scale
)
from
vllm.model_executor.parameter
import
(
ModelWeightParameter
,
from
vllm.model_executor.parameter
import
(
BlockQuantScaleParameter
,
ModelWeightParameter
,
PerTensorScaleParameter
)
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.platforms
import
current_platform
...
...
@@ -41,6 +45,7 @@ class Fp8Config(QuantizationConfig):
is_checkpoint_fp8_serialized
:
bool
=
False
,
activation_scheme
:
str
=
"dynamic"
,
ignored_layers
:
Optional
[
List
[
str
]]
=
None
,
weight_block_size
:
Optional
[
List
[
int
]]
=
None
,
)
->
None
:
self
.
is_checkpoint_fp8_serialized
=
is_checkpoint_fp8_serialized
if
is_checkpoint_fp8_serialized
:
...
...
@@ -51,6 +56,20 @@ class Fp8Config(QuantizationConfig):
f
"Unsupported activation scheme
{
activation_scheme
}
"
)
self
.
activation_scheme
=
activation_scheme
self
.
ignored_layers
=
ignored_layers
or
[]
if
weight_block_size
is
not
None
:
if
not
is_checkpoint_fp8_serialized
:
raise
ValueError
(
"The block-wise quantization only supports fp8-serialized "
"checkpoint for now."
)
if
len
(
weight_block_size
)
!=
2
:
raise
ValueError
(
"The quantization block size of weight must have 2 "
f
"dimensions, but got
{
len
(
weight_block_size
)
}
dimensions"
)
if
activation_scheme
!=
"dynamic"
:
raise
ValueError
(
"The block-wise quantization only supports "
"dynamic activation scheme for now, but got "
f
"
{
activation_scheme
}
activation scheme."
)
self
.
weight_block_size
=
weight_block_size
@
classmethod
def
get_name
(
cls
)
->
str
:
...
...
@@ -74,9 +93,12 @@ class Fp8Config(QuantizationConfig):
is_checkpoint_fp8_serialized
=
(
"fp8"
in
quant_method
)
activation_scheme
=
cls
.
get_from_keys
(
config
,
[
"activation_scheme"
])
ignored_layers
=
cls
.
get_from_keys_or
(
config
,
[
"ignored_layers"
],
None
)
weight_block_size
=
cls
.
get_from_keys_or
(
config
,
[
"weight_block_size"
],
None
)
return
cls
(
is_checkpoint_fp8_serialized
=
is_checkpoint_fp8_serialized
,
activation_scheme
=
activation_scheme
,
ignored_layers
=
ignored_layers
)
ignored_layers
=
ignored_layers
,
weight_block_size
=
weight_block_size
)
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
,
prefix
:
str
)
->
Optional
[
"QuantizeMethodBase"
]:
...
...
@@ -123,6 +145,11 @@ class Fp8LinearMethod(LinearMethodBase):
if
current_platform
.
is_rocm
():
self
.
use_marlin
=
False
self
.
block_quant
=
self
.
quant_config
.
weight_block_size
is
not
None
if
self
.
block_quant
:
# Marlin doesn't support block-wise fp8
self
.
use_marlin
=
False
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
...
...
@@ -133,10 +160,34 @@ class Fp8LinearMethod(LinearMethodBase):
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
,
):
del
input_size
,
output_size
output_size_per_partition
=
sum
(
output_partition_sizes
)
weight_loader
=
extra_weight_attrs
.
get
(
"weight_loader"
)
if
self
.
block_quant
:
tp_size
=
get_tensor_model_parallel_world_size
()
assert
self
.
quant_config
.
weight_block_size
is
not
None
block_n
,
block_k
=
(
self
.
quant_config
.
weight_block_size
[
0
],
self
.
quant_config
.
weight_block_size
[
1
],
)
# Required by row parallel
if
(
tp_size
>
1
and
input_size
//
input_size_per_partition
==
tp_size
and
input_size_per_partition
%
block_k
!=
0
):
raise
ValueError
(
f
"Weight input_size_per_partition = "
f
"
{
input_size_per_partition
}
is not divisible by "
f
"weight quantization block_k =
{
block_k
}
."
)
# Required by column parallel or enabling merged weights
if
(
tp_size
>
1
and
output_size
//
output_size_per_partition
==
tp_size
)
or
len
(
output_partition_sizes
)
>
1
:
for
output_partition_size
in
output_partition_sizes
:
if
output_partition_size
%
block_n
!=
0
:
raise
ValueError
(
f
"Weight output_partition_size = "
f
"
{
output_partition_size
}
is not divisible by "
f
"weight quantization block_n =
{
block_n
}
."
)
layer
.
logical_widths
=
output_partition_sizes
layer
.
input_size_per_partition
=
input_size_per_partition
...
...
@@ -161,12 +212,29 @@ class Fp8LinearMethod(LinearMethodBase):
# Otherwise, wait until process_weights_after_loading.
if
self
.
quant_config
.
is_checkpoint_fp8_serialized
:
# WEIGHT SCALE
scale
=
PerTensorScaleParameter
(
data
=
torch
.
empty
(
len
(
output_partition_sizes
),
dtype
=
torch
.
float32
),
weight_loader
=
weight_loader
)
scale
[:]
=
torch
.
finfo
(
torch
.
float32
).
min
layer
.
register_parameter
(
"weight_scale"
,
scale
)
if
not
self
.
block_quant
:
scale
=
PerTensorScaleParameter
(
data
=
torch
.
empty
(
len
(
output_partition_sizes
),
dtype
=
torch
.
float32
),
weight_loader
=
weight_loader
,
)
scale
[:]
=
torch
.
finfo
(
torch
.
float32
).
min
layer
.
register_parameter
(
"weight_scale"
,
scale
)
else
:
assert
self
.
quant_config
.
activation_scheme
==
"dynamic"
scale
=
BlockQuantScaleParameter
(
data
=
torch
.
empty
(
(
output_size_per_partition
+
block_n
-
1
)
//
block_n
,
(
input_size_per_partition
+
block_k
-
1
)
//
block_k
,
dtype
=
torch
.
float32
,
),
input_dim
=
1
,
output_dim
=
0
,
weight_loader
=
weight_loader
,
)
scale
[:]
=
torch
.
finfo
(
torch
.
float32
).
min
# The weight_scale_inv name is intentional for deepseekv3
layer
.
register_parameter
(
"weight_scale_inv"
,
scale
)
# INPUT ACTIVATION SCALE
if
self
.
quant_config
.
activation_scheme
==
"static"
:
...
...
@@ -180,6 +248,9 @@ class Fp8LinearMethod(LinearMethodBase):
layer
.
register_parameter
(
"input_scale"
,
None
)
def
process_weights_after_loading
(
self
,
layer
:
Module
)
->
None
:
# Block quant doesn't need to process weights after loading
if
self
.
block_quant
:
return
layer
.
weight
=
torch
.
nn
.
Parameter
(
layer
.
weight
.
data
,
requires_grad
=
False
)
# If checkpoint not serialized fp8, quantize the weights.
...
...
@@ -266,6 +337,17 @@ class Fp8LinearMethod(LinearMethodBase):
size_k
=
layer
.
input_size_per_partition
,
bias
=
bias
)
if
self
.
block_quant
:
assert
self
.
quant_config
.
weight_block_size
is
not
None
return
apply_w8a8_block_fp8_linear
(
input
=
x
,
weight
=
layer
.
weight
,
block_size
=
self
.
quant_config
.
weight_block_size
,
weight_scale
=
layer
.
weight_scale_inv
,
input_scale
=
layer
.
input_scale
,
bias
=
bias
,
)
return
apply_fp8_linear
(
input
=
x
,
weight
=
layer
.
weight
,
...
...
@@ -291,6 +373,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
def
__init__
(
self
,
quant_config
:
Fp8Config
):
self
.
quant_config
=
quant_config
self
.
block_quant
=
self
.
quant_config
.
weight_block_size
is
not
None
def
create_weights
(
self
,
layer
:
Module
,
num_experts
:
int
,
hidden_size
:
int
,
intermediate_size
:
int
,
params_dtype
:
torch
.
dtype
,
...
...
@@ -298,6 +381,27 @@ class Fp8MoEMethod(FusedMoEMethodBase):
if
self
.
quant_config
.
is_checkpoint_fp8_serialized
:
params_dtype
=
torch
.
float8_e4m3fn
if
self
.
block_quant
:
assert
self
.
quant_config
.
weight_block_size
is
not
None
tp_size
=
get_tensor_model_parallel_world_size
()
block_n
,
block_k
=
(
self
.
quant_config
.
weight_block_size
[
0
],
self
.
quant_config
.
weight_block_size
[
1
],
)
# NOTE: To ensure proper alignment of the block-wise quantization
# scales, the output_size of the weights for both the gate and up
# layers must be divisible by block_n.
# Required by column parallel or enabling merged weights
if
intermediate_size
%
block_n
!=
0
:
raise
ValueError
(
f
"The output_size of gate's and up's weight = "
f
"
{
intermediate_size
}
is not divisible by "
f
"weight quantization block_n =
{
block_n
}
."
)
if
(
tp_size
>
1
and
intermediate_size
%
block_k
!=
0
):
# Required by row parallel
raise
ValueError
(
f
"The input_size of down's weight = "
f
"
{
intermediate_size
}
is not divisible by "
f
"weight quantization block_k =
{
block_k
}
."
)
# WEIGHTS
w13_weight
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
num_experts
,
...
...
@@ -317,21 +421,45 @@ class Fp8MoEMethod(FusedMoEMethodBase):
set_weight_attrs
(
w2_weight
,
extra_weight_attrs
)
# WEIGHT_SCALES
# Allocate 2 scales for w1 and w3 respectively.
# They will be combined to a single scale after weight loading.
w13_weight_scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
num_experts
,
2
,
dtype
=
torch
.
float32
),
requires_grad
=
False
)
layer
.
register_parameter
(
"w13_weight_scale"
,
w13_weight_scale
)
if
not
self
.
block_quant
:
# Allocate 2 scales for w1 and w3 respectively.
# They will be combined to a single scale after weight loading.
w13_weight_scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
num_experts
,
2
,
dtype
=
torch
.
float32
),
requires_grad
=
False
)
w2_weight_scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
num_experts
,
dtype
=
torch
.
float32
),
requires_grad
=
False
)
layer
.
register_parameter
(
"w13_weight_scale"
,
w13_weight_scale
)
layer
.
register_parameter
(
"w2_weight_scale"
,
w2_weight_scale
)
else
:
w13_weight_scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
num_experts
,
2
*
((
intermediate_size
+
block_n
-
1
)
//
block_n
),
(
hidden_size
+
block_k
-
1
)
//
block_k
,
dtype
=
torch
.
float32
,
),
requires_grad
=
False
,
)
w2_weight_scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
num_experts
,
(
hidden_size
+
block_n
-
1
)
//
block_n
,
(
intermediate_size
+
block_k
-
1
)
//
block_k
,
dtype
=
torch
.
float32
,
),
requires_grad
=
False
,
)
layer
.
register_parameter
(
"w13_weight_scale_inv"
,
w13_weight_scale
)
layer
.
register_parameter
(
"w2_weight_scale_inv"
,
w2_weight_scale
)
assert
self
.
quant_config
.
activation_scheme
==
"dynamic"
w2_weight_scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
num_experts
,
dtype
=
torch
.
float32
),
requires_grad
=
False
)
layer
.
register_parameter
(
"w2_weight_scale"
,
w2_weight_scale
)
# Add the quantization method used (per tensor/grouped/channel)
# to ensure the weight scales are loaded in properly
extra_weight_attrs
.
update
(
{
"quant_method"
:
FusedMoeWeightScaleSupported
.
BLOCK
.
value
}
if
self
.
block_quant
else
{
"quant_method"
:
FusedMoeWeightScaleSupported
.
TENSOR
.
value
})
# If loading fp8 checkpoint, pass the weight loaders.
# If loading an fp16 checkpoint, do not (we will quantize in
...
...
@@ -364,7 +492,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
layer
.
w2_input_scale
=
None
def
process_weights_after_loading
(
self
,
layer
:
Module
)
->
None
:
# Block quant doesn't need to process weights after loading
if
self
.
block_quant
:
return
# If checkpoint is fp16, quantize in place.
if
not
self
.
quant_config
.
is_checkpoint_fp8_serialized
:
# If rocm, use float8_e4m3fnuz as dtype
...
...
@@ -471,12 +601,13 @@ class Fp8MoEMethod(FusedMoEMethodBase):
router_logits
:
torch
.
Tensor
,
top_k
:
int
,
renormalize
:
bool
,
use_grouped_topk
:
bool
,
use_grouped_topk
:
bool
=
False
,
topk_group
:
Optional
[
int
]
=
None
,
num_expert_group
:
Optional
[
int
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
from
vllm.model_executor.layers.fused_moe
import
fused_experts
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
...
...
@@ -487,19 +618,27 @@ class Fp8MoEMethod(FusedMoEMethodBase):
renormalize
=
renormalize
,
topk_group
=
topk_group
,
num_expert_group
=
num_expert_group
,
custom_routing_function
=
custom_routing_function
)
return
fused_experts
(
x
,
layer
.
w13_weight
,
layer
.
w2_weight
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
inplace
=
True
,
use_fp8_w8a8
=
True
,
w1_scale
=
layer
.
w13_weight_scale
,
w2_scale
=
layer
.
w2_weight_scale
,
a1_scale
=
layer
.
w13_input_scale
,
a2_scale
=
layer
.
w2_input_scale
)
custom_routing_function
=
custom_routing_function
,
scoring_func
=
scoring_func
,
e_score_correction_bias
=
e_score_correction_bias
,
)
return
fused_experts
(
x
,
layer
.
w13_weight
,
layer
.
w2_weight
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
inplace
=
True
,
use_fp8_w8a8
=
True
,
w1_scale
=
(
layer
.
w13_weight_scale_inv
if
self
.
block_quant
else
layer
.
w13_weight_scale
),
w2_scale
=
(
layer
.
w2_weight_scale_inv
if
self
.
block_quant
else
layer
.
w2_weight_scale
),
a1_scale
=
layer
.
w13_input_scale
,
a2_scale
=
layer
.
w2_input_scale
,
block_shape
=
self
.
quant_config
.
weight_block_size
,
)
class
Fp8KVCacheMethod
(
BaseKVCacheMethod
):
...
...
vllm/model_executor/layers/quantization/gptq_marlin.py
View file @
96ae75ad
...
...
@@ -532,11 +532,13 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
x
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
top_k
:
int
,
renormalize
:
bool
=
True
,
renormalize
:
bool
,
use_grouped_topk
:
bool
=
False
,
num_expert_group
:
Optional
[
int
]
=
None
,
topk_group
:
Optional
[
int
]
=
None
,
num_expert_group
:
Optional
[
int
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
# The input must currently be float16
orig_dtype
=
x
.
dtype
...
...
@@ -550,7 +552,9 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
renormalize
=
renormalize
,
topk_group
=
topk_group
,
num_expert_group
=
num_expert_group
,
custom_routing_function
=
None
)
custom_routing_function
=
custom_routing_function
,
scoring_func
=
scoring_func
,
e_score_correction_bias
=
e_score_correction_bias
)
return
torch
.
ops
.
vllm
.
fused_marlin_moe
(
x
,
...
...
Prev
1
…
11
12
13
14
15
16
17
18
19
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment