"components/vscode:/vscode.git/clone" did not exist on "41d7d5490fc8e723fa1ef88ec946d9c7f4ec89b4"
Commit 96ae75ad authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.6.6.post1' into v0.6.6.post1-dev

parents f9f4a735 2339d59f
...@@ -28,7 +28,7 @@ from vllm.lora.utils import (from_layer, from_layer_logits_processor, ...@@ -28,7 +28,7 @@ from vllm.lora.utils import (from_layer, from_layer_logits_processor,
parse_fine_tuned_lora_name, replace_submodule) parse_fine_tuned_lora_name, replace_submodule)
from vllm.model_executor.models import SupportsLoRA, supports_multimodal from vllm.model_executor.models import SupportsLoRA, supports_multimodal
from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.models.utils import PPMissingLayer from vllm.model_executor.models.utils import PPMissingLayer, WeightsMapper
from vllm.utils import is_pin_memory_available from vllm.utils import is_pin_memory_available
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -113,13 +113,14 @@ class LoRAModel(AdapterModel): ...@@ -113,13 +113,14 @@ class LoRAModel(AdapterModel):
target_embedding_padding: Optional[int] = None, target_embedding_padding: Optional[int] = None,
embedding_modules: Optional[Dict[str, str]] = None, embedding_modules: Optional[Dict[str, str]] = None,
embedding_padding_modules: Optional[List[str]] = None, embedding_padding_modules: Optional[List[str]] = None,
weights_mapper: Optional[WeightsMapper] = None,
) -> "LoRAModel": ) -> "LoRAModel":
"""Create a LoRAModel from a dictionary of tensors.""" """Create a LoRAModel from a dictionary of tensors."""
pin_memory = str(device) == "cpu" and is_pin_memory_available() pin_memory = str(device) == "cpu" and is_pin_memory_available()
loras: Dict[str, LoRALayerWeights] = {} loras: Dict[str, LoRALayerWeights] = {}
for tensor_name, tensor in tensors.items(): for tensor_name, tensor in tensors.items():
module_name, is_lora_a, is_bias = parse_fine_tuned_lora_name( module_name, is_lora_a, is_bias = parse_fine_tuned_lora_name(
tensor_name) tensor_name, weights_mapper)
if module_name not in loras: if module_name not in loras:
lora_embeddings_tensor = None lora_embeddings_tensor = None
if embeddings: if embeddings:
...@@ -187,6 +188,7 @@ class LoRAModel(AdapterModel): ...@@ -187,6 +188,7 @@ class LoRAModel(AdapterModel):
target_embedding_padding: Optional[int] = None, target_embedding_padding: Optional[int] = None,
embedding_modules: Optional[Dict[str, str]] = None, embedding_modules: Optional[Dict[str, str]] = None,
embedding_padding_modules: Optional[List[str]] = None, embedding_padding_modules: Optional[List[str]] = None,
weights_mapper: Optional[WeightsMapper] = None,
) -> "LoRAModel": ) -> "LoRAModel":
"""Create a LoRAModel from a local checkpoint. """Create a LoRAModel from a local checkpoint.
...@@ -229,7 +231,8 @@ class LoRAModel(AdapterModel): ...@@ -229,7 +231,8 @@ class LoRAModel(AdapterModel):
with safetensors.safe_open(lora_tensor_path, with safetensors.safe_open(lora_tensor_path,
framework="pt") as f: # type: ignore framework="pt") as f: # type: ignore
for lora_module in f.keys(): # noqa for lora_module in f.keys(): # noqa
module_name, _, _ = parse_fine_tuned_lora_name(lora_module) module_name, _, _ = parse_fine_tuned_lora_name(
lora_module, weights_mapper)
part_name = module_name.split(".")[-1] part_name = module_name.split(".")[-1]
if part_name not in expected_lora_modules: if part_name not in expected_lora_modules:
unexpected_modules.append(module_name) unexpected_modules.append(module_name)
...@@ -289,7 +292,8 @@ class LoRAModel(AdapterModel): ...@@ -289,7 +292,8 @@ class LoRAModel(AdapterModel):
embeddings=embeddings, embeddings=embeddings,
target_embedding_padding=target_embedding_padding, target_embedding_padding=target_embedding_padding,
embedding_modules=embedding_modules, embedding_modules=embedding_modules,
embedding_padding_modules=embedding_padding_modules) embedding_padding_modules=embedding_padding_modules,
weights_mapper=weights_mapper)
class LoRAModelManager(AdapterModelManager): class LoRAModelManager(AdapterModelManager):
......
...@@ -30,6 +30,7 @@ from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA, ...@@ -30,6 +30,7 @@ from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA,
# yapf: enable # yapf: enable
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.models.utils import WeightsMapper
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -91,28 +92,46 @@ def replace_submodule(model: nn.Module, module_name: str, ...@@ -91,28 +92,46 @@ def replace_submodule(model: nn.Module, module_name: str,
return new_module return new_module
def parse_fine_tuned_lora_name(name: str) -> Tuple[str, bool, bool]: def parse_fine_tuned_lora_name(
name: str,
weights_mapper: Optional[WeightsMapper] = None
) -> Tuple[str, bool, bool]:
"""Parse the name of lora weights. """Parse the name of lora weights.
args: args:
name: the name of the fine-tuned LoRA, e.g. name: the name of the fine-tuned LoRA, e.g.
base_model.model.dense1.weight base_model.model.dense1.weight
weights_mapper: maps the name of weight, e.g.
`model.` -> `language_model.model.`,
return: return:
Tuple(module_name, is_lora_a): Tuple(module_name, is_lora_a):
module_name: the name of the module, e.g. model.dense1, module_name: the name of the module, e.g. model.dense1,
is_lora_a whether the tensor is lora_a or lora_b. is_lora_a whether the tensor is lora_a or lora_b.
is_bias whether the tensor is lora bias. is_bias whether the tensor is lora bias.
""" """
# LoRA weight qualified name always starts with `base_model.model.`,
# so we remove the prefix `base_model.model.` to make the following
# mapping correctly.
if "base_model.model." in name:
name = name.replace("base_model.model.", "")
name = weights_mapper._map_name(name) if weights_mapper else name
# recover the prefix `base_model.model.`
name = "base_model.model." + name
parts = name.split(".") parts = name.split(".")
if parts[-1] == "weight" and (parts[-2] == "lora_A" if parts[-1] == "weight" and (parts[-2] == "lora_A"
or parts[-2] == "lora_B"): or parts[-2] == "lora_B"):
return ".".join(parts[2:-2]), parts[-2] == "lora_A", False new_name = ".".join(parts[2:-2])
return new_name, parts[-2] == "lora_A", False
if parts[-1] == "lora_embedding_A" or parts[-1] == "lora_embedding_B": if parts[-1] == "lora_embedding_A" or parts[-1] == "lora_embedding_B":
return ".".join(parts[2:-1]), parts[-1] == "lora_embedding_A", False new_name = ".".join(parts[2:-1])
return new_name, parts[-1] == "lora_embedding_A", False
if parts[-1] == "bias": if parts[-1] == "bias":
return ".".join(parts[2:-2]), False, True new_name = ".".join(parts[2:-2])
return new_name, False, True
raise ValueError(f"{name} is unsupported LoRA weight") raise ValueError(f"{name} is unsupported LoRA weight")
......
...@@ -91,7 +91,17 @@ class WorkerLoRAManager(AbstractWorkerManager): ...@@ -91,7 +91,17 @@ class WorkerLoRAManager(AbstractWorkerManager):
packed_modules_mapping[module]) packed_modules_mapping[module])
else: else:
expected_lora_modules.append(module) expected_lora_modules.append(module)
expected_lora_modules = list(set(expected_lora_modules))
lora_path = get_adapter_absolute_path(lora_request.lora_path) lora_path = get_adapter_absolute_path(lora_request.lora_path)
# For some models like Qwen2VL, we need to use hf_to_vllm_mapper
# to ensure correct loading of lora weights.
hf_to_vllm_mapper = None
if (hasattr(model, "hf_to_vllm_mapper")
and model.hf_to_vllm_mapper is not None):
hf_to_vllm_mapper = model.hf_to_vllm_mapper
lora = self._lora_model_cls.from_local_checkpoint( lora = self._lora_model_cls.from_local_checkpoint(
lora_path, lora_path,
expected_lora_modules, expected_lora_modules,
...@@ -103,7 +113,8 @@ class WorkerLoRAManager(AbstractWorkerManager): ...@@ -103,7 +113,8 @@ class WorkerLoRAManager(AbstractWorkerManager):
self.lora_config.lora_extra_vocab_size, self.lora_config.lora_extra_vocab_size,
embedding_modules=self.embedding_modules, embedding_modules=self.embedding_modules,
embedding_padding_modules=self.embedding_padding_modules, embedding_padding_modules=self.embedding_padding_modules,
) weights_mapper=hf_to_vllm_mapper)
except Exception as e: except Exception as e:
raise RuntimeError(f"Loading lora {lora_path} failed") from e raise RuntimeError(f"Loading lora {lora_path} failed") from e
if lora.rank > self.lora_config.max_lora_rank: if lora.rank > self.lora_config.max_lora_rank:
......
...@@ -3,6 +3,9 @@ from __future__ import annotations ...@@ -3,6 +3,9 @@ from __future__ import annotations
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.guided_decoding.utils import (
convert_lark_to_gbnf, grammar_is_likely_lark,
has_lmf_unsupported_json_features, has_xgrammar_unsupported_json_features)
from vllm.platforms import CpuArchEnum, current_platform from vllm.platforms import CpuArchEnum, current_platform
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -15,49 +18,24 @@ if TYPE_CHECKING: ...@@ -15,49 +18,24 @@ if TYPE_CHECKING:
logger = init_logger(__name__) logger = init_logger(__name__)
def has_xgrammar_unsupported_json_features(schema: dict) -> bool:
"""Check if JSON schema contains features unsupported by xgrammar."""
def check_object(obj: dict) -> bool:
if not isinstance(obj, dict):
return False
# Check for pattern restrictions
if "pattern" in obj:
return True
# Check for numeric ranges
if obj.get("type") in ("integer", "number") and any(
key in obj for key in [
"minimum", "maximum", "exclusiveMinimum",
"exclusiveMaximum", "multipleOf"
]):
return True
# Recursively check all nested objects and arrays
for value in obj.values():
if isinstance(value, dict):
if check_object(value):
return True
elif isinstance(value, list):
for item in value:
if isinstance(item, dict) and check_object(item):
return True
return False
return check_object(schema)
def maybe_backend_fallback( def maybe_backend_fallback(
guided_params: GuidedDecodingParams) -> GuidedDecodingParams: guided_params: GuidedDecodingParams) -> GuidedDecodingParams:
# lm-format-enforce doesn't support grammar, fallback to xgrammar # lm-format-enforce doesn't support grammar, fallback to xgrammar
if (guided_params.backend == "lm-format-enforcer" if guided_params.backend == "lm-format-enforcer":
and guided_params.grammar is not None): if guided_params.grammar is not None:
logger.warning( logger.warning(
"lm-format-enforcer does not support grammar guided decoding. " "lm-format-enforcer does not support grammar guided decoding. "
"Falling back to use xgrammar instead.") "Falling back to use xgrammar instead.")
guided_params.backend = "xgrammar" guided_params.backend = "xgrammar"
# lm-format-enforcer doesn't support some JSON schema features
elif (guided_params.json is not None
and has_lmf_unsupported_json_features(guided_params.json)):
logger.warning(
"lm-format-enforcer does not support advanced JSON schema "
"features like patterns or numeric ranges. "
"Falling back to use outlines instead.")
guided_params.backend = "outlines"
if guided_params.backend == "xgrammar": if guided_params.backend == "xgrammar":
# xgrammar only has x86 wheels for linux, fallback to outlines # xgrammar only has x86 wheels for linux, fallback to outlines
...@@ -82,6 +60,27 @@ def maybe_backend_fallback( ...@@ -82,6 +60,27 @@ def maybe_backend_fallback(
"Falling back to use outlines instead.") "Falling back to use outlines instead.")
guided_params.backend = "outlines" guided_params.backend = "outlines"
# xgrammar only supports GBNF grammars, so we must convert Lark.
# We must check if the grammar is likely Lark and if that
# grammar is convertible to GBNF
elif (guided_params.grammar is not None
and grammar_is_likely_lark(guided_params.grammar)):
try:
convert_lark_to_gbnf(guided_params.grammar)
except Exception:
logger.warning(
"xgrammar does not support Lark grammars and the "
"grammar failed to convert to GBNF. "
"Falling back to use outlines instead.")
guided_params.backend = "outlines"
if (guided_params.backend == "outlines"
and guided_params.json_object is not None):
# outlines doesn't support json_object, fallback to xgrammar
logger.warning("outlines does not support json_object. "
"Falling back to use xgrammar instead.")
guided_params.backend = "xgrammar"
return guided_params return guided_params
......
...@@ -21,10 +21,11 @@ from typing import Callable, DefaultDict, Dict, List, Union ...@@ -21,10 +21,11 @@ from typing import Callable, DefaultDict, Dict, List, Union
import numpy as np import numpy as np
import torch import torch
from lark import Lark
from outlines import grammars from outlines import grammars
from outlines.caching import cache from outlines.caching import cache
from outlines.fsm.guide import CFGGuide, Generate, Guide, RegexGuide, Write from outlines.fsm.guide import (CFGGuide, CFGState, Generate, Guide,
RegexGuide, Write)
from outlines.fsm.parsing import PartialLark
from outlines_core.fsm.json_schema import build_regex_from_schema from outlines_core.fsm.json_schema import build_regex_from_schema
from pydantic import BaseModel from pydantic import BaseModel
from transformers import PreTrainedTokenizerBase from transformers import PreTrainedTokenizerBase
...@@ -34,7 +35,9 @@ class BaseLogitsProcessor: ...@@ -34,7 +35,9 @@ class BaseLogitsProcessor:
def __init__(self, guide: Guide): def __init__(self, guide: Guide):
self._guide: Guide = guide self._guide: Guide = guide
self._fsm_state: DefaultDict[int, int] = defaultdict(int) # CFGState is used for the FSM state for CFGGuide
self._fsm_state: DefaultDict[int, Union[int,
CFGState]] = defaultdict(int)
def __call__(self, input_ids: List[int], def __call__(self, input_ids: List[int],
scores: torch.Tensor) -> torch.Tensor: scores: torch.Tensor) -> torch.Tensor:
...@@ -54,15 +57,13 @@ class BaseLogitsProcessor: ...@@ -54,15 +57,13 @@ class BaseLogitsProcessor:
# On the first time this is called, we simply re-create # On the first time this is called, we simply re-create
# the Lark object. # the Lark object.
if isinstance(self._guide, CFGGuide): if isinstance(self._guide, CFGGuide):
self._guide.parser = Lark( self._guide.parser = PartialLark(
self._guide.cfg_string, self._guide.cfg_string,
parser="lalr", parser="lalr",
lexer="contextual",
propagate_positions=False,
maybe_placeholders=False,
regex=True,
import_paths=[grammars.GRAMMAR_PATH], import_paths=[grammars.GRAMMAR_PATH],
) )
self._fsm_state[seq_id] = CFGState(
parser_state=self._guide.parser.parse(""), prev_token=None)
instruction = self._guide.get_next_instruction( instruction = self._guide.get_next_instruction(
state=self._fsm_state[seq_id]) state=self._fsm_state[seq_id])
...@@ -200,7 +201,8 @@ def _adapt_tokenizer(tokenizer: PreTrainedTokenizerBase): ...@@ -200,7 +201,8 @@ def _adapt_tokenizer(tokenizer: PreTrainedTokenizerBase):
string = tokenizer.convert_tokens_to_string([token]) string = tokenizer.convert_tokens_to_string([token])
# A hack to handle missing spaces to HF's Llama tokenizers # A hack to handle missing spaces to HF's Llama tokenizers
if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>": if (type(token) is str and token.startswith(SPIECE_UNDERLINE)
or token == "<0x20>"):
return " " + string return " " + string
return string return string
...@@ -211,6 +213,9 @@ def _adapt_tokenizer(tokenizer: PreTrainedTokenizerBase): ...@@ -211,6 +213,9 @@ def _adapt_tokenizer(tokenizer: PreTrainedTokenizerBase):
"""Sync vLLM's decoder with the outlines by returning list.""" """Sync vLLM's decoder with the outlines by returning list."""
def new_decoder(inp_tokens: List[int]) -> List[str]: def new_decoder(inp_tokens: List[int]) -> List[str]:
if (isinstance(inp_tokens, list) and len(inp_tokens) == 1
and isinstance(inp_tokens[0], list)):
inp_tokens = inp_tokens[0]
return [decoder(inp_tokens)] return [decoder(inp_tokens)]
return new_decoder return new_decoder
......
import re import re
def has_xgrammar_unsupported_json_features(schema: dict) -> bool:
"""Check if JSON schema contains features unsupported by xgrammar."""
def check_object(obj: dict) -> bool:
if not isinstance(obj, dict):
return False
# Check for pattern restrictions
if "pattern" in obj:
return True
# Check for numeric ranges
if obj.get("type") in ("integer", "number") and any(
key in obj for key in [
"minimum", "maximum", "exclusiveMinimum",
"exclusiveMaximum", "multipleOf"
]):
return True
# Recursively check all nested objects and arrays
for value in obj.values():
if isinstance(value, dict):
if check_object(value):
return True
elif isinstance(value, list):
for item in value:
if isinstance(item, dict) and check_object(item):
return True
return False
return check_object(schema)
def has_lmf_unsupported_json_features(schema: dict) -> bool:
"""
Check if JSON schema contains features unsupported
by lm_format_enforcer.
Known issues:
- Regex patterns:
"grade": {
"type": "string",
"pattern": "^[A-D]$" # Regex pattern
},
"""
def check_object(obj: dict) -> bool:
if not isinstance(obj, dict):
return False
# Check for pattern restrictions
if "pattern" in obj:
return True
# Recursively check all nested objects and arrays
for value in obj.values():
if isinstance(value, dict):
if check_object(value):
return True
elif isinstance(value, list):
for item in value:
if isinstance(item, dict) and check_object(item):
return True
return False
return check_object(schema)
def grammar_is_likely_lark(grammar_str: str) -> bool: def grammar_is_likely_lark(grammar_str: str) -> bool:
""" """
Check if grammar appears to use Lark syntax. Check if grammar appears to use Lark syntax.
......
...@@ -3,7 +3,7 @@ from __future__ import annotations ...@@ -3,7 +3,7 @@ from __future__ import annotations
import json import json
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, NamedTuple from typing import TYPE_CHECKING, Any
import torch import torch
from transformers import PreTrainedTokenizerFast from transformers import PreTrainedTokenizerFast
...@@ -14,8 +14,9 @@ try: ...@@ -14,8 +14,9 @@ try:
except ImportError: except ImportError:
pass pass
from vllm.model_executor.guided_decoding.xgrammar_utils import ( from vllm.model_executor.guided_decoding.utils import (convert_lark_to_gbnf,
convert_lark_to_gbnf, grammar_is_likely_lark) grammar_is_likely_lark)
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers import PreTrainedTokenizer from transformers import PreTrainedTokenizer
...@@ -37,11 +38,21 @@ def get_local_xgrammar_guided_decoding_logits_processor( ...@@ -37,11 +38,21 @@ def get_local_xgrammar_guided_decoding_logits_processor(
return XGrammarLogitsProcessor(config) return XGrammarLogitsProcessor(config)
class TokenizerData(NamedTuple): @dataclass(frozen=True)
class TokenizerData:
"""Immutable container for cached tokenizer data.""" """Immutable container for cached tokenizer data."""
encoded_vocab: list[str] encoded_vocab: list[str] = field(default_factory=list)
stop_token_ids: list[int] | None stop_token_ids: list[int] | None = None
backend_str: str # These fields are mutually exclusive: `backend_str` is used to create a
# TokenizeInfo with `TokenizerInfo.from_huggingface` while `vocab_type` is
# used within the constructor of TokenizeInfo
backend_str: str | None = None
vocab_type: xgr.VocabType | None = None
def __post_init__(self):
# Check for mutual exclusive
assert not (self.backend_str and self.vocab_type), \
"backend_str and vocab_type are mutual exclusive"
class TokenizerDataCache: class TokenizerDataCache:
...@@ -68,18 +79,27 @@ class TokenizerDataCache: ...@@ -68,18 +79,27 @@ class TokenizerDataCache:
"get_vocab method.") from e "get_vocab method.") from e
stop_token_ids = None stop_token_ids = None
backend_str = xgr.VocabType.RAW backend_str = ""
vocab_type = xgr.VocabType.RAW
if stop_token_ids is None and hasattr(
tokenizer,
"eos_token_id") and tokenizer.eos_token_id is not None:
stop_token_ids = [tokenizer.eos_token_id]
if isinstance(tokenizer, PreTrainedTokenizerFast): if isinstance(tokenizer, PreTrainedTokenizerFast):
backend_str = tokenizer.backend_tokenizer.to_str() backend_str = tokenizer.backend_tokenizer.to_str()
if stop_token_ids is None and hasattr( vocab_type = None
tokenizer,
"eos_token_id") and tokenizer.eos_token_id is not None: elif isinstance(tokenizer, MistralTokenizer):
stop_token_ids = [tokenizer.eos_token_id] # REF: https://github.com/mlc-ai/xgrammar/blob/5e141f6ff1ca02bc31f9e512e68b61f2a8ae88e5/tests/python/test_tokenizer_info.py#L43 # noqa: E501
vocab_type = xgr.VocabType.BYTE_FALLBACK
cls._cache[tokenizer_hash] = TokenizerData( cls._cache[tokenizer_hash] = TokenizerData(
encoded_vocab=encoded_vocab, encoded_vocab=encoded_vocab,
stop_token_ids=stop_token_ids, stop_token_ids=stop_token_ids,
backend_str=backend_str) backend_str=backend_str,
vocab_type=vocab_type)
return cls._cache[tokenizer_hash] return cls._cache[tokenizer_hash]
...@@ -98,11 +118,30 @@ class GrammarCompilerCache: ...@@ -98,11 +118,30 @@ class GrammarCompilerCache:
cache_key = str(config.tokenizer_hash) cache_key = str(config.tokenizer_hash)
if cache_key not in cls._cache: if cache_key not in cls._cache:
assert config.encoded_vocab is not None assert config.tokenizer_data is not None
tokenizer_info = xgr.TokenizerInfo._create_from_handle( assert config.tokenizer_data.encoded_vocab is not None
xgr_core.TokenizerInfo.from_huggingface(
config.encoded_vocab, config.backend_str, config_data = config.tokenizer_data
config.vocab_size, config.stop_token_ids))
# In TokenizerDataCache.get_tokenizer_data, a serializable
# tokenizer_data is created and cached. This data is used to build
# a tokenizer_info and create an xgrammar compiler.
# - If tokenizer_data has backend_str set, use
# xgr_core.TokenizerInfo.from_huggingface (a C++ bind).
# - Otherwise, use the default constructor with vocab_type.
# - xgr_core.TokenizerInfo.from_huggingface !=
# xgr.TokenizerInfo.from_huggingface.
if config_data.backend_str:
tokenizer_info = xgr.TokenizerInfo._create_from_handle(
xgr_core.TokenizerInfo.from_huggingface(
config_data.encoded_vocab, config_data.backend_str,
config.vocab_size, config_data.stop_token_ids))
else:
tokenizer_info = xgr.TokenizerInfo(
config_data.encoded_vocab,
config_data.vocab_type,
vocab_size=config.vocab_size,
stop_token_ids=config_data.stop_token_ids)
cls._cache[cache_key] = xgr.GrammarCompiler( cls._cache[cache_key] = xgr.GrammarCompiler(
tokenizer_info, max_threads=config.max_threads) tokenizer_info, max_threads=config.max_threads)
...@@ -118,10 +157,7 @@ class GrammarConfig: ...@@ -118,10 +157,7 @@ class GrammarConfig:
grammar_str: str | None = None grammar_str: str | None = None
json_object: bool | None = None json_object: bool | None = None
max_threads: int = 8 max_threads: int = 8
# Only populated if tokenizer_hash not in cache tokenizer_data: TokenizerData | None = None
encoded_vocab: list[str] | None = None
stop_token_ids: list[int] | None = None
backend_str: str | None = None
@classmethod @classmethod
def from_guided_params(cls, def from_guided_params(cls,
...@@ -132,9 +168,6 @@ class GrammarConfig: ...@@ -132,9 +168,6 @@ class GrammarConfig:
tokenizer_hash = hash(tokenizer) tokenizer_hash = hash(tokenizer)
tokenizer_data = TokenizerDataCache.get_tokenizer_data(tokenizer) tokenizer_data = TokenizerDataCache.get_tokenizer_data(tokenizer)
encoded_vocab = tokenizer_data.encoded_vocab
stop_token_ids = tokenizer_data.stop_token_ids
backend_str = tokenizer_data.backend_str
if guided_params.json: if guided_params.json:
if not isinstance(guided_params.json, str): if not isinstance(guided_params.json, str):
...@@ -152,11 +185,9 @@ class GrammarConfig: ...@@ -152,11 +185,9 @@ class GrammarConfig:
return cls(json_str=json_str, return cls(json_str=json_str,
vocab_size=model_config.hf_text_config.vocab_size, vocab_size=model_config.hf_text_config.vocab_size,
encoded_vocab=encoded_vocab,
stop_token_ids=stop_token_ids,
backend_str=backend_str,
tokenizer_hash=tokenizer_hash, tokenizer_hash=tokenizer_hash,
max_threads=max_threads) max_threads=max_threads,
tokenizer_data=tokenizer_data)
elif guided_params.grammar: elif guided_params.grammar:
# XGrammar only supports GBNF grammars, so we must convert Lark # XGrammar only supports GBNF grammars, so we must convert Lark
if grammar_is_likely_lark(guided_params.grammar): if grammar_is_likely_lark(guided_params.grammar):
...@@ -181,19 +212,17 @@ class GrammarConfig: ...@@ -181,19 +212,17 @@ class GrammarConfig:
return cls(grammar_str=grammar_str, return cls(grammar_str=grammar_str,
vocab_size=model_config.hf_text_config.vocab_size, vocab_size=model_config.hf_text_config.vocab_size,
encoded_vocab=encoded_vocab,
stop_token_ids=stop_token_ids,
backend_str=backend_str,
tokenizer_hash=tokenizer_hash, tokenizer_hash=tokenizer_hash,
max_threads=max_threads) max_threads=max_threads,
tokenizer_data=tokenizer_data)
elif guided_params.json_object: elif guided_params.json_object:
return cls(json_object=True, return cls(
vocab_size=model_config.hf_text_config.vocab_size, json_object=True,
encoded_vocab=encoded_vocab, vocab_size=model_config.hf_text_config.vocab_size,
stop_token_ids=stop_token_ids, tokenizer_hash=tokenizer_hash,
backend_str=backend_str, max_threads=max_threads,
tokenizer_hash=tokenizer_hash, tokenizer_data=tokenizer_data,
max_threads=max_threads) )
else: else:
raise ValueError( raise ValueError(
"Currently only support JSON and EBNF grammar mode for xgrammar" "Currently only support JSON and EBNF grammar mode for xgrammar"
...@@ -269,10 +298,14 @@ class XGrammarLogitsProcessor: ...@@ -269,10 +298,14 @@ class XGrammarLogitsProcessor:
# fill_next_token_bitmask so we move it to the device of scores # fill_next_token_bitmask so we move it to the device of scores
device_type = scores.device.type device_type = scores.device.type
if device_type != "cuda": if device_type != "cuda":
scores = scores.to("cpu") scores = scores.to("cpu").unsqueeze(0)
# Note: In this method, if the tensors have different dimensions
# on CPU device fails, but on GPU it runs without error. Hence the
# unsqueeze above for scores, to match the token bitmask shape
xgr.apply_token_bitmask_inplace(scores, xgr.apply_token_bitmask_inplace(scores,
self.token_bitmask.to(scores.device)) self.token_bitmask.to(scores.device))
if device_type != "cuda": if device_type != "cuda":
scores = scores.to(device_type) scores = scores.to(device_type).squeeze()
return scores return scores
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
import functools import functools
import json import json
import os import os
from typing import Any, Callable, Dict, Optional, Tuple from typing import Any, Callable, Dict, List, Optional, Tuple
import torch import torch
import triton import triton
...@@ -11,6 +11,8 @@ import triton.language as tl ...@@ -11,6 +11,8 @@ import triton.language as tl
import vllm.envs as envs import vllm.envs as envs
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op from vllm.utils import direct_register_custom_op
...@@ -45,8 +47,14 @@ def fused_moe_kernel( ...@@ -45,8 +47,14 @@ def fused_moe_kernel(
stride_bn, stride_bn,
stride_cm, stride_cm,
stride_cn, stride_cn,
stride_asm,
stride_ask,
stride_bse, stride_bse,
stride_bsk,
stride_bsn, stride_bsn,
# Block size for block-wise quantization
group_n: tl.constexpr,
group_k: tl.constexpr,
# Meta-parameters # Meta-parameters
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
...@@ -125,8 +133,14 @@ def fused_moe_kernel( ...@@ -125,8 +133,14 @@ def fused_moe_kernel(
b_scale = tl.load(b_scale_ptrs) b_scale = tl.load(b_scale_ptrs)
if use_fp8_w8a8: if use_fp8_w8a8:
a_scale = tl.load(a_scale_ptr) if group_k > 0 and group_n > 0:
b_scale = tl.load(b_scale_ptr + off_experts) a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
offs_bsn = offs_bn // group_n
b_scale_ptrs = (b_scale_ptr + off_experts * stride_bse +
offs_bsn * stride_bsn)
else:
a_scale = tl.load(a_scale_ptr)
b_scale = tl.load(b_scale_ptr + off_experts)
# ----------------------------------------------------------- # -----------------------------------------------------------
# Iterate to compute a block of the C matrix. # Iterate to compute a block of the C matrix.
...@@ -149,7 +163,18 @@ def fused_moe_kernel( ...@@ -149,7 +163,18 @@ def fused_moe_kernel(
if use_int8_w8a16: if use_int8_w8a16:
accumulator = tl.dot(a, b.to(compute_type), acc=accumulator) accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
elif use_fp8_w8a8: elif use_fp8_w8a8:
accumulator = tl.dot(a, b, acc=accumulator) if group_k > 0 and group_n > 0:
k_start = k * BLOCK_SIZE_K
offs_ks = k_start // group_k
a_scale = tl.load(a_scale_ptrs + offs_ks * stride_ask,
mask=token_mask,
other=0.0)
b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk)
accumulator += tl.dot(a, b) * a_scale[:,
None] * b_scale[None, :]
else:
accumulator = tl.dot(a, b, acc=accumulator)
else: else:
accumulator += tl.dot(a, b) accumulator += tl.dot(a, b)
# Advance the ptrs to the next K block. # Advance the ptrs to the next K block.
...@@ -164,7 +189,10 @@ def fused_moe_kernel( ...@@ -164,7 +189,10 @@ def fused_moe_kernel(
if use_int8_w8a16: if use_int8_w8a16:
accumulator = (accumulator * b_scale).to(compute_type) accumulator = (accumulator * b_scale).to(compute_type)
elif use_fp8_w8a8: elif use_fp8_w8a8:
accumulator = (accumulator * a_scale * b_scale).to(compute_type) if group_k > 0 and group_n > 0:
accumulator = accumulator.to(compute_type)
else:
accumulator = (accumulator * a_scale * b_scale).to(compute_type)
else: else:
accumulator = accumulator.to(compute_type) accumulator = accumulator.to(compute_type)
# ----------------------------------------------------------- # -----------------------------------------------------------
...@@ -233,22 +261,37 @@ def moe_align_block_size( ...@@ -233,22 +261,37 @@ def moe_align_block_size(
return sorted_ids, expert_ids, num_tokens_post_pad return sorted_ids, expert_ids, num_tokens_post_pad
def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, def invoke_fused_moe_kernel(A: torch.Tensor,
B: torch.Tensor,
C: torch.Tensor,
A_scale: Optional[torch.Tensor], A_scale: Optional[torch.Tensor],
B_scale: Optional[torch.Tensor], B_scale: Optional[torch.Tensor],
topk_weights: torch.Tensor, topk_ids: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
sorted_token_ids: torch.Tensor, sorted_token_ids: torch.Tensor,
expert_ids: torch.Tensor, expert_ids: torch.Tensor,
num_tokens_post_padded: torch.Tensor, num_tokens_post_padded: torch.Tensor,
mul_routed_weight: bool, top_k: int, mul_routed_weight: bool,
config: Dict[str, Any], compute_type: tl.dtype, top_k: int,
use_fp8_w8a8: bool, use_int8_w8a16: bool) -> None: config: Dict[str, Any],
compute_type: tl.dtype,
use_fp8_w8a8: bool,
use_int8_w8a16: bool,
block_shape: Optional[List[int]] = None) -> None:
assert topk_weights.stride(1) == 1 assert topk_weights.stride(1) == 1
assert sorted_token_ids.stride(0) == 1 assert sorted_token_ids.stride(0) == 1
if use_fp8_w8a8: if use_fp8_w8a8:
A, A_scale = ops.scaled_fp8_quant(A, A_scale)
assert B_scale is not None assert B_scale is not None
if block_shape is None:
A, A_scale = ops.scaled_fp8_quant(A, A_scale)
else:
assert len(block_shape) == 2
block_n, block_k = block_shape[0], block_shape[1]
A, A_scale = per_token_group_quant_fp8(A, block_k)
assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2]
assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1]
elif use_int8_w8a16: elif use_int8_w8a16:
assert B_scale is not None assert B_scale is not None
else: else:
...@@ -279,8 +322,13 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, ...@@ -279,8 +322,13 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
B.stride(1), B.stride(1),
C.stride(1), C.stride(1),
C.stride(2), C.stride(2),
B_scale.stride(0) if B_scale is not None and use_int8_w8a16 else 0, A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0,
B_scale.stride(1) if B_scale is not None and use_int8_w8a16 else 0, A_scale.stride(1) if A_scale is not None and A_scale.ndim == 2 else 0,
B_scale.stride(0) if B_scale is not None and B_scale.ndim >= 2 else 0,
B_scale.stride(2) if B_scale is not None and B_scale.ndim == 3 else 0,
B_scale.stride(1) if B_scale is not None and B_scale.ndim >= 2 else 0,
0 if block_shape is None else block_shape[0],
0 if block_shape is None else block_shape[1],
MUL_ROUTED_WEIGHT=mul_routed_weight, MUL_ROUTED_WEIGHT=mul_routed_weight,
top_k=top_k, top_k=top_k,
compute_type=compute_type, compute_type=compute_type,
...@@ -362,6 +410,7 @@ def try_get_optimal_moe_config( ...@@ -362,6 +410,7 @@ def try_get_optimal_moe_config(
dtype: Optional[str], dtype: Optional[str],
M: int, M: int,
is_marlin: bool = False, is_marlin: bool = False,
block_shape: Optional[List[int]] = None,
): ):
from vllm.model_executor.layers.fused_moe import get_config from vllm.model_executor.layers.fused_moe import get_config
override_config = get_config() override_config = get_config()
...@@ -380,6 +429,12 @@ def try_get_optimal_moe_config( ...@@ -380,6 +429,12 @@ def try_get_optimal_moe_config(
# Else use the default config # Else use the default config
config = get_default_config(M, E, N, w1_shape[2], top_k, dtype, config = get_default_config(M, E, N, w1_shape[2], top_k, dtype,
is_marlin) is_marlin)
# NOTE: For block-wise quant,
# BLOCK_K must be divisible by block_shape[1]
# BLOCK_N and BLOCK_M has no requirements
if block_shape is not None:
config["BLOCK_SIZE_N"] = block_shape[0]
config["BLOCK_SIZE_K"] = block_shape[1]
return config return config
...@@ -421,18 +476,29 @@ def fused_topk( ...@@ -421,18 +476,29 @@ def fused_topk(
return topk_weights, topk_ids return topk_weights, topk_ids
# This is used by the Deepseek-V2 model # This is used by the Deepseek-V2 and Deepseek-V3 model
def grouped_topk(hidden_states: torch.Tensor, def grouped_topk(hidden_states: torch.Tensor,
gating_output: torch.Tensor, gating_output: torch.Tensor,
topk: int, topk: int,
renormalize: bool, renormalize: bool,
num_expert_group: int = 0, num_expert_group: int = 0,
topk_group: int = 0): topk_group: int = 0,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None):
assert hidden_states.shape[0] == gating_output.shape[0], ( assert hidden_states.shape[0] == gating_output.shape[0], (
"Number of tokens mismatch") "Number of tokens mismatch")
scores = torch.softmax(gating_output, dim=-1) if scoring_func == "softmax":
scores = torch.softmax(gating_output, dim=-1)
elif scoring_func == "sigmoid":
scores = gating_output.sigmoid()
else:
raise ValueError(f"Unsupported scoring function: {scoring_func}")
if e_score_correction_bias is not None:
scores.add_(e_score_correction_bias.unsqueeze(0))
num_token = scores.shape[0] num_token = scores.shape[0]
group_scores = scores.view(num_token, num_expert_group, group_scores = scores.view(num_token, num_expert_group,
-1).max(dim=-1).values # [n, n_group] -1).max(dim=-1).values # [n, n_group]
...@@ -479,10 +545,11 @@ def inplace_fused_experts(hidden_states: torch.Tensor, ...@@ -479,10 +545,11 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
w1_scale: Optional[torch.Tensor] = None, w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None) -> None: a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None) -> None:
fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True, fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True,
use_fp8_w8a8, use_int8_w8a16, w1_scale, w2_scale, use_fp8_w8a8, use_int8_w8a16, w1_scale, w2_scale,
a1_scale, a2_scale) a1_scale, a2_scale, block_shape)
def inplace_fused_experts_fake( def inplace_fused_experts_fake(
...@@ -496,7 +563,8 @@ def inplace_fused_experts_fake( ...@@ -496,7 +563,8 @@ def inplace_fused_experts_fake(
w1_scale: Optional[torch.Tensor] = None, w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None) -> None: a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None) -> None:
pass pass
...@@ -519,10 +587,11 @@ def outplace_fused_experts( ...@@ -519,10 +587,11 @@ def outplace_fused_experts(
w1_scale: Optional[torch.Tensor] = None, w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None) -> torch.Tensor: a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None) -> torch.Tensor:
return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids,
False, use_fp8_w8a8, use_int8_w8a16, w1_scale, False, use_fp8_w8a8, use_int8_w8a16, w1_scale,
w2_scale, a1_scale, a2_scale) w2_scale, a1_scale, a2_scale, block_shape)
def outplace_fused_experts_fake( def outplace_fused_experts_fake(
...@@ -536,7 +605,8 @@ def outplace_fused_experts_fake( ...@@ -536,7 +605,8 @@ def outplace_fused_experts_fake(
w1_scale: Optional[torch.Tensor] = None, w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None) -> torch.Tensor: a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None) -> torch.Tensor:
return torch.empty_like(hidden_states) return torch.empty_like(hidden_states)
...@@ -559,18 +629,22 @@ def fused_experts(hidden_states: torch.Tensor, ...@@ -559,18 +629,22 @@ def fused_experts(hidden_states: torch.Tensor,
w1_scale: Optional[torch.Tensor] = None, w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None): a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None):
if inplace: if inplace:
torch.ops.vllm.inplace_fused_experts(hidden_states, w1, w2, torch.ops.vllm.inplace_fused_experts(hidden_states, w1, w2,
topk_weights, topk_ids, topk_weights, topk_ids,
use_fp8_w8a8, use_int8_w8a16, use_fp8_w8a8, use_int8_w8a16,
w1_scale, w2_scale, a1_scale, w1_scale, w2_scale, a1_scale,
a2_scale) a2_scale, block_shape)
return hidden_states return hidden_states
else: else:
return torch.ops.vllm.outplace_fused_experts( return torch.ops.vllm.outplace_fused_experts(hidden_states, w1, w2,
hidden_states, w1, w2, topk_weights, topk_ids, use_fp8_w8a8, topk_weights, topk_ids,
use_int8_w8a16, w1_scale, w2_scale, a1_scale, a2_scale) use_fp8_w8a8,
use_int8_w8a16, w1_scale,
w2_scale, a1_scale,
a2_scale, block_shape)
def fused_experts_impl(hidden_states: torch.Tensor, def fused_experts_impl(hidden_states: torch.Tensor,
...@@ -584,7 +658,8 @@ def fused_experts_impl(hidden_states: torch.Tensor, ...@@ -584,7 +658,8 @@ def fused_experts_impl(hidden_states: torch.Tensor,
w1_scale: Optional[torch.Tensor] = None, w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None): a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None):
# Check constraints. # Check constraints.
assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
...@@ -611,6 +686,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, ...@@ -611,6 +686,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
w2.shape, w2.shape,
topk_ids.shape[1], topk_ids.shape[1],
config_dtype, config_dtype,
block_shape=block_shape,
) )
config = get_config_func(M) config = get_config_func(M)
...@@ -674,7 +750,8 @@ def fused_experts_impl(hidden_states: torch.Tensor, ...@@ -674,7 +750,8 @@ def fused_experts_impl(hidden_states: torch.Tensor,
config, config,
compute_type=compute_type, compute_type=compute_type,
use_fp8_w8a8=use_fp8_w8a8, use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a16=use_int8_w8a16) use_int8_w8a16=use_int8_w8a16,
block_shape=block_shape)
ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
...@@ -693,7 +770,8 @@ def fused_experts_impl(hidden_states: torch.Tensor, ...@@ -693,7 +770,8 @@ def fused_experts_impl(hidden_states: torch.Tensor,
config, config,
compute_type=compute_type, compute_type=compute_type,
use_fp8_w8a8=use_fp8_w8a8, use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a16=use_int8_w8a16) use_int8_w8a16=use_int8_w8a16,
block_shape=block_shape)
ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.shape), ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.shape),
out_hidden_states[begin_chunk_idx:end_chunk_idx]) out_hidden_states[begin_chunk_idx:end_chunk_idx])
...@@ -718,6 +796,7 @@ def fused_moe( ...@@ -718,6 +796,7 @@ def fused_moe(
w2_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
This function computes a Mixture of Experts (MoE) layer using two sets of This function computes a Mixture of Experts (MoE) layer using two sets of
...@@ -745,6 +824,12 @@ def fused_moe( ...@@ -745,6 +824,12 @@ def fused_moe(
w1. w1.
- w2_scale (Optional[torch.Tensor]): Optional scale to be used for - w2_scale (Optional[torch.Tensor]): Optional scale to be used for
w2. w2.
- a1_scale (Optional[torch.Tensor]): Optional scale to be used for
a1.
- a2_scale (Optional[torch.Tensor]): Optional scale to be used for
a2.
- block_shape: (Optional[List[int]]): Optional block size for block-wise
quantization.
Returns: Returns:
- torch.Tensor: The output tensor after applying the MoE layer. - torch.Tensor: The output tensor after applying the MoE layer.
...@@ -775,4 +860,5 @@ def fused_moe( ...@@ -775,4 +860,5 @@ def fused_moe(
w1_scale=w1_scale, w1_scale=w1_scale,
w2_scale=w2_scale, w2_scale=w2_scale,
a1_scale=a1_scale, a1_scale=a1_scale,
a2_scale=a2_scale) a2_scale=a2_scale,
\ No newline at end of file block_shape=block_shape)
...@@ -29,6 +29,7 @@ class FusedMoeWeightScaleSupported(Enum): ...@@ -29,6 +29,7 @@ class FusedMoeWeightScaleSupported(Enum):
TENSOR = "tensor" TENSOR = "tensor"
CHANNEL = "channel" CHANNEL = "channel"
GROUP = "group" GROUP = "group"
BLOCK = "block"
class FusedMoEMethodBase(QuantizeMethodBase): class FusedMoEMethodBase(QuantizeMethodBase):
...@@ -40,9 +41,20 @@ class FusedMoEMethodBase(QuantizeMethodBase): ...@@ -40,9 +41,20 @@ class FusedMoEMethodBase(QuantizeMethodBase):
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
def apply(self, layer: torch.nn.Module, x: torch.Tensor, def apply(
router_logits: torch.Tensor, top_k: int, renormalize: bool, self,
use_grouped_topk: bool) -> torch.Tensor: layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None
) -> torch.Tensor:
raise NotImplementedError raise NotImplementedError
...@@ -72,16 +84,18 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -72,16 +84,18 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
set_weight_attrs(w2_weight, extra_weight_attrs) set_weight_attrs(w2_weight, extra_weight_attrs)
def apply( def apply(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
top_k: int, top_k: int,
renormalize: bool, renormalize: bool,
use_grouped_topk: bool, use_grouped_topk: bool = False,
topk_group: Optional[int] = None, topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None, num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None
) -> torch.Tensor: ) -> torch.Tensor:
return self.forward(x=x, return self.forward(x=x,
layer=layer, layer=layer,
...@@ -91,19 +105,23 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -91,19 +105,23 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
use_grouped_topk=use_grouped_topk, use_grouped_topk=use_grouped_topk,
topk_group=topk_group, topk_group=topk_group,
num_expert_group=num_expert_group, num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function) custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias)
def forward_cuda( def forward_cuda(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
use_grouped_topk: bool, use_grouped_topk: bool,
top_k: int, top_k: int,
router_logits: torch.Tensor, router_logits: torch.Tensor,
renormalize: bool, renormalize: bool,
topk_group: Optional[int] = None, topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None, num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None
) -> torch.Tensor: ) -> torch.Tensor:
topk_weights, topk_ids = FusedMoE.select_experts( topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x, hidden_states=x,
...@@ -113,7 +131,9 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -113,7 +131,9 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
renormalize=renormalize, renormalize=renormalize,
topk_group=topk_group, topk_group=topk_group,
num_expert_group=num_expert_group, num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function) custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias)
return fused_experts(hidden_states=x, return fused_experts(hidden_states=x,
w1=layer.w13_weight, w1=layer.w13_weight,
...@@ -127,21 +147,29 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -127,21 +147,29 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
"The CPU backend currently does not support MoE.") "The CPU backend currently does not support MoE.")
def forward_tpu( def forward_tpu(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
use_grouped_topk: bool, use_grouped_topk: bool,
top_k: int, top_k: int,
router_logits: torch.Tensor, router_logits: torch.Tensor,
renormalize: bool, renormalize: bool,
topk_group: Optional[int] = None, topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None, num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None
) -> torch.Tensor: ) -> torch.Tensor:
assert not use_grouped_topk assert not use_grouped_topk
assert num_expert_group is None assert num_expert_group is None
assert topk_group is None assert topk_group is None
assert custom_routing_function is None assert custom_routing_function is None
if scoring_func != "softmax":
raise NotImplementedError(
"Only softmax scoring function is supported for TPU.")
if e_score_correction_bias is not None:
raise NotImplementedError(
"Expert score correction bias is not supported for TPU.")
return fused_moe_pallas(hidden_states=x, return fused_moe_pallas(hidden_states=x,
w1=layer.w13_weight, w1=layer.w13_weight,
w2=layer.w2_weight, w2=layer.w2_weight,
...@@ -155,7 +183,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -155,7 +183,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
class FusedMoE(torch.nn.Module): class FusedMoE(torch.nn.Module):
"""FusedMoE layer for MoE models. """FusedMoE layer for MoE models.
This layer contains both MergedColumnParallel weights (gate_up_proj / This layer contains both MergedColumnParallel weights (gate_up_proj /
w13) and RowParallelLinear weights (down_proj/ w2). w13) and RowParallelLinear weights (down_proj/ w2).
Note: Mixtral uses w1, w2, and w3 for gate, up, and down_proj. We Note: Mixtral uses w1, w2, and w3 for gate, up, and down_proj. We
...@@ -189,6 +217,8 @@ class FusedMoE(torch.nn.Module): ...@@ -189,6 +217,8 @@ class FusedMoE(torch.nn.Module):
tp_size: Optional[int] = None, tp_size: Optional[int] = None,
prefix: str = "", prefix: str = "",
custom_routing_function: Optional[Callable] = None, custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
): ):
super().__init__() super().__init__()
...@@ -199,6 +229,7 @@ class FusedMoE(torch.nn.Module): ...@@ -199,6 +229,7 @@ class FusedMoE(torch.nn.Module):
get_tensor_model_parallel_world_size()) get_tensor_model_parallel_world_size())
self.top_k = top_k self.top_k = top_k
self.num_experts = num_experts self.num_experts = num_experts
assert intermediate_size % self.tp_size == 0
self.intermediate_size_per_partition = intermediate_size // self.tp_size self.intermediate_size_per_partition = intermediate_size // self.tp_size
self.reduce_results = reduce_results self.reduce_results = reduce_results
self.renormalize = renormalize self.renormalize = renormalize
...@@ -208,6 +239,12 @@ class FusedMoE(torch.nn.Module): ...@@ -208,6 +239,12 @@ class FusedMoE(torch.nn.Module):
self.num_expert_group = num_expert_group self.num_expert_group = num_expert_group
self.topk_group = topk_group self.topk_group = topk_group
self.custom_routing_function = custom_routing_function self.custom_routing_function = custom_routing_function
self.scoring_func = scoring_func
self.e_score_correction_bias = e_score_correction_bias
if self.scoring_func != "softmax" and not self.use_grouped_topk:
raise ValueError("Only softmax scoring function is supported for "
"non-grouped topk.")
if quant_config is None: if quant_config is None:
self.quant_method: Optional[QuantizeMethodBase] = ( self.quant_method: Optional[QuantizeMethodBase] = (
...@@ -398,7 +435,10 @@ class FusedMoE(torch.nn.Module): ...@@ -398,7 +435,10 @@ class FusedMoE(torch.nn.Module):
loaded_weight=loaded_weight, loaded_weight=loaded_weight,
expert_data=expert_data, expert_data=expert_data,
tp_rank=tp_rank) tp_rank=tp_rank)
elif quant_method == FusedMoeWeightScaleSupported.GROUP.value: elif quant_method in [
FusedMoeWeightScaleSupported.GROUP.value,
FusedMoeWeightScaleSupported.BLOCK.value,
]:
self._load_model_weight_or_group_weight_scale( self._load_model_weight_or_group_weight_scale(
shard_id=shard_id, shard_id=shard_id,
shard_dim=shard_dim, shard_dim=shard_dim,
...@@ -441,7 +481,9 @@ class FusedMoE(torch.nn.Module): ...@@ -441,7 +481,9 @@ class FusedMoE(torch.nn.Module):
renormalize: bool, renormalize: bool,
topk_group: Optional[int] = None, topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None, num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None): custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None):
from vllm.model_executor.layers.fused_moe.fused_moe import ( from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_topk, grouped_topk) fused_topk, grouped_topk)
...@@ -455,7 +497,9 @@ class FusedMoE(torch.nn.Module): ...@@ -455,7 +497,9 @@ class FusedMoE(torch.nn.Module):
topk=top_k, topk=top_k,
renormalize=renormalize, renormalize=renormalize,
num_expert_group=num_expert_group, num_expert_group=num_expert_group,
topk_group=topk_group) topk_group=topk_group,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias)
elif custom_routing_function is None: elif custom_routing_function is None:
topk_weights, topk_ids = fused_topk(hidden_states=hidden_states, topk_weights, topk_ids = fused_topk(hidden_states=hidden_states,
gating_output=router_logits, gating_output=router_logits,
...@@ -484,7 +528,9 @@ class FusedMoE(torch.nn.Module): ...@@ -484,7 +528,9 @@ class FusedMoE(torch.nn.Module):
use_grouped_topk=self.use_grouped_topk, use_grouped_topk=self.use_grouped_topk,
topk_group=self.topk_group, topk_group=self.topk_group,
num_expert_group=self.num_expert_group, num_expert_group=self.num_expert_group,
custom_routing_function=self.custom_routing_function) custom_routing_function=self.custom_routing_function,
scoring_func=self.scoring_func,
e_score_correction_bias=self.e_score_correction_bias)
if self.reduce_results and self.tp_size > 1: if self.reduce_results and self.tp_size > 1:
final_hidden_states = tensor_model_parallel_all_reduce( final_hidden_states = tensor_model_parallel_all_reduce(
......
...@@ -14,11 +14,14 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank, ...@@ -14,11 +14,14 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase) QuantizationConfig, QuantizeMethodBase)
# yapf: disable
from vllm.model_executor.parameter import (BasevLLMParameter, from vllm.model_executor.parameter import (BasevLLMParameter,
BlockQuantScaleParameter,
PackedColumnParameter, PackedColumnParameter,
PackedvLLMParameter, PackedvLLMParameter,
PerTensorScaleParameter, PerTensorScaleParameter,
RowvLLMParameter) RowvLLMParameter)
# yapf: enable
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
import os import os
...@@ -642,8 +645,24 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -642,8 +645,24 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
assert loaded_shard_id < len(self.output_sizes) assert loaded_shard_id < len(self.output_sizes)
tp_size = get_tensor_model_parallel_world_size() tp_size = get_tensor_model_parallel_world_size()
shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
shard_size = self.output_sizes[loaded_shard_id] // tp_size if isinstance(param, BlockQuantScaleParameter):
from vllm.model_executor.layers.quantization.fp8 import (
Fp8LinearMethod, Fp8MoEMethod)
assert self.quant_method is not None
assert isinstance(self.quant_method,
(Fp8LinearMethod, Fp8MoEMethod))
weight_block_size = self.quant_method.quant_config.weight_block_size
assert weight_block_size is not None
block_n, _ = weight_block_size[0], weight_block_size[1]
shard_offset = (
(sum(self.output_sizes[:loaded_shard_id]) + block_n - 1) //
block_n) // tp_size
shard_size = ((self.output_sizes[loaded_shard_id] + block_n - 1) //
block_n // tp_size)
else:
shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
shard_size = self.output_sizes[loaded_shard_id] // tp_size
param.load_merged_column_weight(loaded_weight=loaded_weight, param.load_merged_column_weight(loaded_weight=loaded_weight,
shard_id=loaded_shard_id, shard_id=loaded_shard_id,
......
...@@ -440,11 +440,13 @@ class AWQMoEMethod(FusedMoEMethodBase): ...@@ -440,11 +440,13 @@ class AWQMoEMethod(FusedMoEMethodBase):
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
top_k: int, top_k: int,
renormalize: bool = True, renormalize: bool,
use_grouped_topk: bool = False, use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None,
topk_group: Optional[int] = None, topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None, custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
topk_weights, topk_ids = FusedMoE.select_experts( topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x, hidden_states=x,
...@@ -454,7 +456,9 @@ class AWQMoEMethod(FusedMoEMethodBase): ...@@ -454,7 +456,9 @@ class AWQMoEMethod(FusedMoEMethodBase):
renormalize=renormalize, renormalize=renormalize,
topk_group=topk_group, topk_group=topk_group,
num_expert_group=num_expert_group, num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function) custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias)
return torch.ops.vllm.fused_marlin_moe( return torch.ops.vllm.fused_marlin_moe(
x, x,
......
from typing import Any, Dict, List, Optional, cast from typing import Any, Dict, List, Literal, Optional, cast
import torch import torch
from compressed_tensors.config import CompressionFormat from compressed_tensors.config import (CompressionFormat,
SparsityCompressionConfig,
SparsityStructure)
from compressed_tensors.quantization import (QuantizationArgs, from compressed_tensors.quantization import (QuantizationArgs,
QuantizationStrategy, QuantizationStrategy,
QuantizationType) QuantizationType)
...@@ -15,7 +17,7 @@ from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501 ...@@ -15,7 +17,7 @@ from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import ( # noqa: E501 from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import ( # noqa: E501
CompressedTensorsMoEMethod) CompressedTensorsMoEMethod)
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
W4A16SPARSE24_SUPPORTED_BITS, WNA16_SUPPORTED_BITS, W4A16SPARSE24_SUPPORTED_BITS, WNA16_SUPPORTED_BITS, CompressedTensors24,
CompressedTensorsScheme, CompressedTensorsW4A16Sparse24, CompressedTensorsScheme, CompressedTensorsW4A16Sparse24,
CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8, CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8,
CompressedTensorsW8A16Fp8, CompressedTensorsWNA16) CompressedTensorsW8A16Fp8, CompressedTensorsWNA16)
...@@ -27,20 +29,29 @@ from vllm.platforms import current_platform ...@@ -27,20 +29,29 @@ from vllm.platforms import current_platform
__all__ = ["CompressedTensorsLinearMethod"] __all__ = ["CompressedTensorsLinearMethod"]
SPARSITY_CONFIG_NAME: Literal["sparsity_config"] = "sparsity_config"
QUANTIZATION_SCHEME_MAP_TYPE = Dict[str, Optional[Dict[str, QuantizationArgs]]]
class CompressedTensorsConfig(QuantizationConfig): class CompressedTensorsConfig(QuantizationConfig):
def __init__(self, def __init__(
target_scheme_map: Dict[str, Any], self,
ignore: List[str], target_scheme_map: Dict[str, Any],
quant_format: str, ignore: List[str],
kv_cache_scheme: Optional[Dict[str, Any]] = None): quant_format: str,
sparsity_scheme_map: Dict[str, SparsityCompressionConfig],
kv_cache_scheme: Optional[Dict[str, Any]] = None,
config: Optional[Dict[str, Any]] = None,
):
self.ignore = ignore self.ignore = ignore
self.quant_format = quant_format self.quant_format = quant_format
# Map from [target -> scheme] # Map from [target -> scheme]
self.target_scheme_map = target_scheme_map self.target_scheme_map = target_scheme_map
self.kv_cache_scheme = kv_cache_scheme self.kv_cache_scheme = kv_cache_scheme
self.sparsity_scheme_map = sparsity_scheme_map
self.config = config
def get_linear_method(self) -> "CompressedTensorsLinearMethod": def get_linear_method(self) -> "CompressedTensorsLinearMethod":
return CompressedTensorsLinearMethod(self) return CompressedTensorsLinearMethod(self)
...@@ -78,8 +89,50 @@ class CompressedTensorsConfig(QuantizationConfig): ...@@ -78,8 +89,50 @@ class CompressedTensorsConfig(QuantizationConfig):
@classmethod @classmethod
def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig": def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig":
ignore: List[str] = cast(List[str], config.get("ignore", []))
quant_format = cast(str, config.get("format"))
target_scheme_map = cls._quantization_scheme_map_from_config(
config=config)
sparsity_scheme_map = cls._sparsity_scheme_map_from_config(
config=config)
return cls(
target_scheme_map=target_scheme_map,
ignore=ignore,
quant_format=quant_format,
sparsity_scheme_map=sparsity_scheme_map,
config=config,
)
@classmethod
def _sparsity_scheme_map_from_config(
cls, config: Dict[str,
Any]) -> Dict[str, SparsityCompressionConfig]:
"""
:param config: The `quantization_config` dictionary from config.json
:return: A dictionary mapping target layer names to their corresponding
sparsity compression configurations
"""
if (sparsity_config := config.get(SPARSITY_CONFIG_NAME)) is None:
return dict()
sparsity_config = SparsityCompressionConfig.model_validate(
sparsity_config)
sparse_scheme_map: Dict[str, SparsityCompressionConfig] = {
target: sparsity_config
for target in sparsity_config.targets or list()
}
return sparse_scheme_map
@classmethod
def _quantization_scheme_map_from_config(
cls, config: Dict[str, Any]) -> QUANTIZATION_SCHEME_MAP_TYPE:
"""
:param config: The `quantization_config` dictionary from config.json
:return: A dictionary mapping target layer names to their corresponding
quantization_args for weights and input activations
"""
target_scheme_map: Dict[str, Any] = dict() target_scheme_map: Dict[str, Any] = dict()
ignore = cast(List[str], config.get("ignore"))
quant_format = cast(str, config.get("format")) quant_format = cast(str, config.get("format"))
# The quant_config has multiple config_groups, each containing # The quant_config has multiple config_groups, each containing
...@@ -90,12 +143,14 @@ class CompressedTensorsConfig(QuantizationConfig): ...@@ -90,12 +143,14 @@ class CompressedTensorsConfig(QuantizationConfig):
# details follow the structure defined by the QuantizationArgs # details follow the structure defined by the QuantizationArgs
# pydantic model, which is used to verify the structure of the # pydantic model, which is used to verify the structure of the
# quant_config and also store the details for later use. # quant_config and also store the details for later use.
for _, quant_config in config["config_groups"].items():
config_groups = config.get("config_groups", dict())
for _, quant_config in config_groups.items():
targets = quant_config.get("targets") targets = quant_config.get("targets")
for target in targets: for target in targets:
target_scheme_map[target] = {} target_scheme_map[target] = {}
target_scheme_map[target][ target_scheme_map[target][
"weights"] = QuantizationArgs.parse_obj( "weights"] = QuantizationArgs.model_validate(
quant_config.get("weights")) quant_config.get("weights"))
target_scheme_map[target]["input_activations"] = None target_scheme_map[target]["input_activations"] = None
...@@ -110,13 +165,9 @@ class CompressedTensorsConfig(QuantizationConfig): ...@@ -110,13 +165,9 @@ class CompressedTensorsConfig(QuantizationConfig):
"weights"].type == QuantizationType.FLOAT "weights"].type == QuantizationType.FLOAT
else: else:
target_scheme_map[target][ target_scheme_map[target][
"input_activations"] = QuantizationArgs.parse_obj( "input_activations"] = QuantizationArgs.model_validate( # noqa: E501
quant_config.get("input_activations")) quant_config.get("input_activations"))
return target_scheme_map
return cls(target_scheme_map=target_scheme_map,
ignore=ignore,
quant_format=quant_format,
kv_cache_scheme=config.get("kv_cache_scheme"))
@classmethod @classmethod
def get_config_filenames(cls) -> List[str]: def get_config_filenames(cls) -> List[str]:
...@@ -315,23 +366,105 @@ class CompressedTensorsConfig(QuantizationConfig): ...@@ -315,23 +366,105 @@ class CompressedTensorsConfig(QuantizationConfig):
# TODO (@robertgshaw): add compressed-tensors as dep # TODO (@robertgshaw): add compressed-tensors as dep
# so we do not have to re-write these functions # so we do not have to re-write these functions
# need to make accelerate optional in ct to do this # need to make accelerate optional in ct to do this
matched_target = find_matched_target(
layer_name=layer_name,
module=layer,
targets=self.target_scheme_map.keys())
# Find the quant_scheme # Will be empty for models with only sparsity
scheme_dict = self.target_scheme_map[matched_target] if self.target_scheme_map:
scheme = self._get_scheme_from_parts( matched_target = find_matched_target(
weight_quant=scheme_dict["weights"], layer_name=layer_name,
input_quant=scheme_dict["input_activations"]) module=layer,
targets=self.target_scheme_map.keys())
scheme_dict = self.target_scheme_map[matched_target]
weight_quant = scheme_dict.get("weights")
input_quant = scheme_dict.get("input_activations")
elif self.sparsity_scheme_map:
matched_target = find_matched_target(
layer_name=layer_name,
module=layer,
targets=self.sparsity_scheme_map.keys())
weight_quant = None
input_quant = None
# For models with sparsity, assumes that the sparse layers are also
# quantized for cutlass 2:4 support
sparsity_scheme: Optional[
SparsityCompressionConfig] = self.sparsity_scheme_map.get(
matched_target)
if self.supports_cutlass_24(weight_quant=weight_quant,
input_quant=input_quant,
sparsity_scheme=sparsity_scheme):
# Have a valid sparsity scheme
# Validate layer is supported by Cutlass 2:4 Kernel
scheme = CompressedTensors24(quantized=weight_quant is not None
or input_quant is not None,
weight_quant=weight_quant,
input_quant=input_quant)
else:
# Find the quant_scheme
scheme = self._get_scheme_from_parts( # type: ignore
weight_quant=weight_quant,
input_quant=input_quant,
)
# Raise error if device does not support the scheme # Raise error if device does not support the scheme
# (e.g. fp8 needs ada lovelace) # (e.g. fp8 needs ada lovelace)
self._check_scheme_supported(scheme.get_min_capability()) self._check_scheme_supported(scheme.get_min_capability())
return scheme return scheme
@staticmethod
def supports_cutlass_24(
weight_quant: Optional[QuantizationArgs],
input_quant: Optional[QuantizationArgs],
sparsity_scheme: Optional[SparsityCompressionConfig] = None
) -> bool:
"""
Check if the layer is supported by the Cutlass 2:4 Kernel
Conditions:
- Overarching condition: Sparsity Structure is 2:4
- Unquantized cases are supported
- Weight only quantization is not-supported
- Supported weight quantization strategies are TENSOR and CHANNEL
- Supported input quantization strategies are TENSOR and TOKEN
- Only 8 bit quantization is supported
:return: True if the layer is supported by the Cutlass 2:4 Kernel
False otherwise
"""
is_valid_sparsity = (sparsity_scheme is not None
and sparsity_scheme.sparsity_structure
== SparsityStructure.TWO_FOUR.value
and sparsity_scheme.format == "dense")
if not is_valid_sparsity:
return False
# Unquantized cases are supported
if weight_quant is None and input_quant is None:
return True
# Weight only quantization is not-supported
if weight_quant is not None and input_quant is None:
return False
supported_weight_quant_strategies = [
QuantizationStrategy.TENSOR.value,
QuantizationStrategy.CHANNEL.value
]
assert weight_quant is not None
assert input_quant is not None
if weight_quant.strategy not in supported_weight_quant_strategies:
return False
supported_input_quant_strategies = [
QuantizationStrategy.TENSOR.value, QuantizationStrategy.TOKEN.value
]
if input_quant.strategy not in supported_input_quant_strategies:
return False
return weight_quant.num_bits == input_quant.num_bits == 8
class CompressedTensorsLinearMethod(LinearMethodBase): class CompressedTensorsLinearMethod(LinearMethodBase):
......
...@@ -203,13 +203,14 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -203,13 +203,14 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
top_k: int, top_k: int,
renormalize: bool = True, renormalize: bool,
use_grouped_topk: bool = False, use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None,
topk_group: Optional[int] = None, topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None, custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
topk_weights, topk_ids = FusedMoE.select_experts( topk_weights, topk_ids = FusedMoE.select_experts(
...@@ -220,7 +221,9 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -220,7 +221,9 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
renormalize=renormalize, renormalize=renormalize,
topk_group=topk_group, topk_group=topk_group,
num_expert_group=num_expert_group, num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function) custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias)
return fused_experts(x, return fused_experts(x,
layer.w13_weight, layer.w13_weight,
...@@ -476,12 +479,15 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): ...@@ -476,12 +479,15 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
top_k: int, top_k: int,
renormalize: bool = True, renormalize: bool,
use_grouped_topk: bool = False, use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None,
topk_group: Optional[int] = None, topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None, custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
topk_weights, topk_ids = FusedMoE.select_experts( topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x, hidden_states=x,
router_logits=router_logits, router_logits=router_logits,
...@@ -490,7 +496,9 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): ...@@ -490,7 +496,9 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
renormalize=renormalize, renormalize=renormalize,
topk_group=topk_group, topk_group=topk_group,
num_expert_group=num_expert_group, num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function) custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias)
return torch.ops.vllm.fused_marlin_moe( return torch.ops.vllm.fused_marlin_moe(
x, x,
......
...@@ -7,13 +7,12 @@ from .compressed_tensors_w8a16_fp8 import CompressedTensorsW8A16Fp8 ...@@ -7,13 +7,12 @@ from .compressed_tensors_w8a16_fp8 import CompressedTensorsW8A16Fp8
from .compressed_tensors_wNa16 import (WNA16_SUPPORTED_BITS, from .compressed_tensors_wNa16 import (WNA16_SUPPORTED_BITS,
CompressedTensorsWNA16) CompressedTensorsWNA16)
from .compressed_tensors_24 import CompressedTensors24 # isort: skip
__all__ = [ __all__ = [
"CompressedTensorsScheme", "CompressedTensorsScheme", "CompressedTensorsWNA16",
"CompressedTensorsWNA16", "CompressedTensorsW8A16Fp8", "CompressedTensorsW4A16Sparse24",
"CompressedTensorsW8A16Fp8", "CompressedTensorsW8A8Int8", "CompressedTensorsW8A8Fp8",
"CompressedTensorsW4A16Sparse24", "WNA16_SUPPORTED_BITS", "W4A16SPARSE24_SUPPORTED_BITS",
"CompressedTensorsW8A8Int8", "CompressedTensors24"
"CompressedTensorsW8A8Fp8",
"WNA16_SUPPORTED_BITS",
"W4A16SPARSE24_SUPPORTED_BITS",
] ]
from typing import Callable, List, Optional
import torch
from compressed_tensors.quantization import (QuantizationArgs,
QuantizationStrategy,
QuantizationType)
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
convert_to_channelwise, sparse_cutlass_supported)
from vllm.model_executor.parameter import (BasevLLMParameter,
ChannelQuantScaleParameter,
ModelWeightParameter,
PerTensorScaleParameter)
__all__ = ["CompressedTensors24"]
class CompressedTensors24(CompressedTensorsScheme):
def __init__(self,
quantized: bool = False,
weight_quant: Optional[QuantizationArgs] = None,
input_quant: Optional[QuantizationArgs] = None):
self.quantized = quantized
self.weight_quant = weight_quant
self.input_quant = input_quant
@classmethod
def get_min_capability(cls) -> int:
# Only cutlass 3.x kernels are implemented so far
return 90
def create_weights(self, layer: torch.nn.Module, input_size: int,
output_partition_sizes: List[int],
input_size_per_partition: int,
params_dtype: torch.dtype, weight_loader: Callable,
**kwargs):
if not sparse_cutlass_supported():
raise ValueError(
"Sparse CUTLASS not supported. vLLM must be built with"
"CUDA 12.2 or later to use this feature")
self.output_dtype = params_dtype
layer.logical_widths = output_partition_sizes
self.weights_dtype: torch.dtype = self._get_params_dtype(params_dtype)
# parameter to store uncompressed weight
weight = ModelWeightParameter(data=torch.empty(
sum(output_partition_sizes),
input_size_per_partition,
dtype=self.weights_dtype),
input_dim=1,
output_dim=0,
weight_loader=weight_loader)
# Check if quantized, not just 2:4 Sparse
if self.quantized:
if (self.weight_quant and self.weight_quant.strategy
== QuantizationStrategy.CHANNEL.value):
weight_scale = ChannelQuantScaleParameter(
data=torch.empty((sum(output_partition_sizes), 1),
dtype=torch.float32),
output_dim=0,
weight_loader=weight_loader)
else:
assert (self.weight_quant and self.weight_quant.strategy
== QuantizationStrategy.TENSOR.value)
weight_scale = PerTensorScaleParameter(
data=torch.empty(len(output_partition_sizes),
dtype=torch.float32),
weight_loader=weight_loader)
layer.register_parameter("weight_scale", weight_scale)
# input quant will be non-none
if self.input_quant and not self.input_quant.dynamic:
# register input quant scale
assert (self.input_quant.strategy ==
QuantizationStrategy.TENSOR.value)
input_scale = BasevLLMParameter(data=torch.empty(
1, dtype=torch.float32),
weight_loader=weight_loader)
layer.register_parameter("input_scale", input_scale)
else:
# for sparse-only, pass in 1 for weight/input scales
weight_scale = torch.nn.Parameter(data=torch.ones(
1, dtype=torch.float32),
requires_grad=False)
input_scale = torch.nn.Parameter(data=torch.ones(
1, dtype=torch.float32),
requires_grad=False)
layer.register_parameter("input_scale", input_scale)
layer.register_parameter("weight_scale", weight_scale)
layer.register_parameter("weight", weight)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
"""
Compress weights after loading. Store compressed weight and meta
tensor
:post-condition: layer.w_compressed and layer.meta are
set to the compressed weight and meta tensor in the
format expected by the Cutlass kernels
:param layer: The layer with the weights to be processed
"""
# torch.compile workaround
if hasattr(layer, "input_scale"):
layer.input_scale = torch.nn.Parameter(layer.input_scale.data,
requires_grad=False)
if self.weight_quant:
if self.weight_quant.strategy == QuantizationStrategy.TENSOR.value:
layer.weight_scale = torch.nn.Parameter(convert_to_channelwise(
weight_scale=layer.weight_scale,
logical_widths=layer.logical_widths),
requires_grad=False)
else:
# torch.compile workaround
layer.weight_scale = torch.nn.Parameter(
layer.weight_scale.data, requires_grad=False)
w_compressed, meta = ops.cutlass_sparse_compress(layer.weight.data)
layer.weight = torch.nn.Parameter(w_compressed, requires_grad=False)
layer.meta = torch.nn.Parameter(meta, requires_grad=False)
def apply_weights(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
"""
Returns the output tensor for the layer with 2:4
sparse compressed weights, given the input tensor
and bias
:param layer: The layer with 2:4 sparse compressed
weights to be used for the computation
:param x: The input tensor to the layer
:param bias: The bias to be added to the output tensor
:return: The output tensor of the layer
"""
if self.quantized:
scale = None
if hasattr(layer, "input_scale"):
scale = layer.input_scale
if self.weights_dtype == torch.int8:
ops_output = ops.scaled_int8_quant(x, scale=scale)
q_input = ops_output[0]
input_scale = ops_output[1]
else:
assert self.weights_dtype == torch.float8_e4m3fn
if scale is not None:
q_input, input_scale = ops.scaled_fp8_quant(x, scale=scale)
else:
q_input, input_scale = ops.scaled_fp8_quant(
x, use_per_token_if_dynamic=True)
else:
# Not quantized, nothing to do with the input_scales, use as is
input_scale = layer.input_scale
q_input = x
out = ops.cutlass_scaled_sparse_mm(a=q_input,
bt_nzs=layer.weight,
bt_meta=layer.meta,
scale_a=input_scale,
scale_b=layer.weight_scale,
out_dtype=self.output_dtype,
bias=bias)
assert out.is_contiguous()
return out
def _get_params_dtype(self, params_dtype: torch.dtype) -> torch.dtype:
if not self.quantized:
return params_dtype
assert self.weight_quant is not None
assert self.input_quant is not None
is_8_bits = self.weight_quant.num_bits == self.input_quant.num_bits == 8
if not is_8_bits:
raise ValueError("Cutlass only supports 8-bit quantization")
if (self.weight_quant.type == QuantizationType.FLOAT
and self.input_quant.type == QuantizationType.FLOAT):
return torch.float8_e4m3fn
if (self.weight_quant.type == QuantizationType.INT
and self.input_quant.type == QuantizationType.INT):
return torch.int8
raise ValueError("Quantization type not supported by Cutlass")
def check_24(tensor):
new_tensor = tensor.view(-1, 4)
zero_counts = (new_tensor == 0).sum(dim=1)
return (zero_counts >= 2).all().item()
...@@ -61,6 +61,10 @@ class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme): ...@@ -61,6 +61,10 @@ class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme):
params_dtype: torch.dtype, weight_loader: Callable, params_dtype: torch.dtype, weight_loader: Callable,
**kwargs): **kwargs):
assert params_dtype == torch.float16, (
"float16 is required for marlin24 compressd models. Set dtype=torch.float16" # noqa: E501
)
pack_factor = 32 // self.quant_type.size_bits pack_factor = 32 // self.quant_type.size_bits
output_size_per_partition = sum(output_partition_sizes) output_size_per_partition = sum(output_partition_sizes)
......
...@@ -30,7 +30,7 @@ def should_ignore_layer(layer_name: Optional[str], ...@@ -30,7 +30,7 @@ def should_ignore_layer(layer_name: Optional[str],
# in the safetensors checkpoint. So, we convert the name # in the safetensors checkpoint. So, we convert the name
# from the fused version to unfused + check to make sure that # from the fused version to unfused + check to make sure that
# each shard of the fused layer has the same scheme. # each shard of the fused layer has the same scheme.
if proj_name in FUSED_LAYER_NAME_MAPPING: if proj_name in FUSED_LAYER_NAME_MAPPING and layer_name not in ignore:
shard_proj_names = FUSED_LAYER_NAME_MAPPING[proj_name] shard_proj_names = FUSED_LAYER_NAME_MAPPING[proj_name]
# Convert fused_name --> [shard_names] # Convert fused_name --> [shard_names]
......
...@@ -99,11 +99,13 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase): ...@@ -99,11 +99,13 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
top_k: int, top_k: int,
renormalize: bool = True, renormalize: bool,
use_grouped_topk: bool = False, use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None,
topk_group: Optional[int] = None, topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None, custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
...@@ -115,7 +117,9 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase): ...@@ -115,7 +117,9 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
renormalize=renormalize, renormalize=renormalize,
topk_group=topk_group, topk_group=topk_group,
num_expert_group=num_expert_group, num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function) custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias)
return fused_experts(x, return fused_experts(x,
layer.w13_weight, layer.w13_weight,
......
...@@ -6,6 +6,7 @@ from torch.nn.parameter import Parameter ...@@ -6,6 +6,7 @@ from torch.nn.parameter import Parameter
import vllm.envs as envs import vllm.envs as envs
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase, from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
FusedMoeWeightScaleSupported) FusedMoeWeightScaleSupported)
...@@ -14,6 +15,8 @@ from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, ...@@ -14,6 +15,8 @@ from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase) QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
apply_w8a8_block_fp8_linear)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin) apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
...@@ -22,7 +25,8 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( ...@@ -22,7 +25,8 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
all_close_1d, apply_fp8_linear, convert_to_channelwise, all_close_1d, apply_fp8_linear, convert_to_channelwise,
cutlass_fp8_supported, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize, cutlass_fp8_supported, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize,
requantize_with_max_scale) requantize_with_max_scale)
from vllm.model_executor.parameter import (ModelWeightParameter, from vllm.model_executor.parameter import (BlockQuantScaleParameter,
ModelWeightParameter,
PerTensorScaleParameter) PerTensorScaleParameter)
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform from vllm.platforms import current_platform
...@@ -41,6 +45,7 @@ class Fp8Config(QuantizationConfig): ...@@ -41,6 +45,7 @@ class Fp8Config(QuantizationConfig):
is_checkpoint_fp8_serialized: bool = False, is_checkpoint_fp8_serialized: bool = False,
activation_scheme: str = "dynamic", activation_scheme: str = "dynamic",
ignored_layers: Optional[List[str]] = None, ignored_layers: Optional[List[str]] = None,
weight_block_size: Optional[List[int]] = None,
) -> None: ) -> None:
self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
if is_checkpoint_fp8_serialized: if is_checkpoint_fp8_serialized:
...@@ -51,6 +56,20 @@ class Fp8Config(QuantizationConfig): ...@@ -51,6 +56,20 @@ class Fp8Config(QuantizationConfig):
f"Unsupported activation scheme {activation_scheme}") f"Unsupported activation scheme {activation_scheme}")
self.activation_scheme = activation_scheme self.activation_scheme = activation_scheme
self.ignored_layers = ignored_layers or [] self.ignored_layers = ignored_layers or []
if weight_block_size is not None:
if not is_checkpoint_fp8_serialized:
raise ValueError(
"The block-wise quantization only supports fp8-serialized "
"checkpoint for now.")
if len(weight_block_size) != 2:
raise ValueError(
"The quantization block size of weight must have 2 "
f"dimensions, but got {len(weight_block_size)} dimensions")
if activation_scheme != "dynamic":
raise ValueError("The block-wise quantization only supports "
"dynamic activation scheme for now, but got "
f"{activation_scheme} activation scheme.")
self.weight_block_size = weight_block_size
@classmethod @classmethod
def get_name(cls) -> str: def get_name(cls) -> str:
...@@ -74,9 +93,12 @@ class Fp8Config(QuantizationConfig): ...@@ -74,9 +93,12 @@ class Fp8Config(QuantizationConfig):
is_checkpoint_fp8_serialized = ("fp8" in quant_method) is_checkpoint_fp8_serialized = ("fp8" in quant_method)
activation_scheme = cls.get_from_keys(config, ["activation_scheme"]) activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None) ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None)
weight_block_size = cls.get_from_keys_or(config, ["weight_block_size"],
None)
return cls(is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized, return cls(is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
activation_scheme=activation_scheme, activation_scheme=activation_scheme,
ignored_layers=ignored_layers) ignored_layers=ignored_layers,
weight_block_size=weight_block_size)
def get_quant_method(self, layer: torch.nn.Module, def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["QuantizeMethodBase"]: prefix: str) -> Optional["QuantizeMethodBase"]:
...@@ -123,6 +145,11 @@ class Fp8LinearMethod(LinearMethodBase): ...@@ -123,6 +145,11 @@ class Fp8LinearMethod(LinearMethodBase):
if current_platform.is_rocm(): if current_platform.is_rocm():
self.use_marlin = False self.use_marlin = False
self.block_quant = self.quant_config.weight_block_size is not None
if self.block_quant:
# Marlin doesn't support block-wise fp8
self.use_marlin = False
def create_weights( def create_weights(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
...@@ -133,10 +160,34 @@ class Fp8LinearMethod(LinearMethodBase): ...@@ -133,10 +160,34 @@ class Fp8LinearMethod(LinearMethodBase):
params_dtype: torch.dtype, params_dtype: torch.dtype,
**extra_weight_attrs, **extra_weight_attrs,
): ):
del input_size, output_size
output_size_per_partition = sum(output_partition_sizes) output_size_per_partition = sum(output_partition_sizes)
weight_loader = extra_weight_attrs.get("weight_loader") weight_loader = extra_weight_attrs.get("weight_loader")
if self.block_quant:
tp_size = get_tensor_model_parallel_world_size()
assert self.quant_config.weight_block_size is not None
block_n, block_k = (
self.quant_config.weight_block_size[0],
self.quant_config.weight_block_size[1],
)
# Required by row parallel
if (tp_size > 1
and input_size // input_size_per_partition == tp_size
and input_size_per_partition % block_k != 0):
raise ValueError(
f"Weight input_size_per_partition = "
f"{input_size_per_partition} is not divisible by "
f"weight quantization block_k = {block_k}.")
# Required by column parallel or enabling merged weights
if (tp_size > 1 and output_size // output_size_per_partition
== tp_size) or len(output_partition_sizes) > 1:
for output_partition_size in output_partition_sizes:
if output_partition_size % block_n != 0:
raise ValueError(
f"Weight output_partition_size = "
f"{output_partition_size} is not divisible by "
f"weight quantization block_n = {block_n}.")
layer.logical_widths = output_partition_sizes layer.logical_widths = output_partition_sizes
layer.input_size_per_partition = input_size_per_partition layer.input_size_per_partition = input_size_per_partition
...@@ -161,12 +212,29 @@ class Fp8LinearMethod(LinearMethodBase): ...@@ -161,12 +212,29 @@ class Fp8LinearMethod(LinearMethodBase):
# Otherwise, wait until process_weights_after_loading. # Otherwise, wait until process_weights_after_loading.
if self.quant_config.is_checkpoint_fp8_serialized: if self.quant_config.is_checkpoint_fp8_serialized:
# WEIGHT SCALE # WEIGHT SCALE
scale = PerTensorScaleParameter(data=torch.empty( if not self.block_quant:
len(output_partition_sizes), dtype=torch.float32), scale = PerTensorScaleParameter(
weight_loader=weight_loader) data=torch.empty(len(output_partition_sizes),
dtype=torch.float32),
scale[:] = torch.finfo(torch.float32).min weight_loader=weight_loader,
layer.register_parameter("weight_scale", scale) )
scale[:] = torch.finfo(torch.float32).min
layer.register_parameter("weight_scale", scale)
else:
assert self.quant_config.activation_scheme == "dynamic"
scale = BlockQuantScaleParameter(
data=torch.empty(
(output_size_per_partition + block_n - 1) // block_n,
(input_size_per_partition + block_k - 1) // block_k,
dtype=torch.float32,
),
input_dim=1,
output_dim=0,
weight_loader=weight_loader,
)
scale[:] = torch.finfo(torch.float32).min
# The weight_scale_inv name is intentional for deepseekv3
layer.register_parameter("weight_scale_inv", scale)
# INPUT ACTIVATION SCALE # INPUT ACTIVATION SCALE
if self.quant_config.activation_scheme == "static": if self.quant_config.activation_scheme == "static":
...@@ -180,6 +248,9 @@ class Fp8LinearMethod(LinearMethodBase): ...@@ -180,6 +248,9 @@ class Fp8LinearMethod(LinearMethodBase):
layer.register_parameter("input_scale", None) layer.register_parameter("input_scale", None)
def process_weights_after_loading(self, layer: Module) -> None: def process_weights_after_loading(self, layer: Module) -> None:
# Block quant doesn't need to process weights after loading
if self.block_quant:
return
layer.weight = torch.nn.Parameter(layer.weight.data, layer.weight = torch.nn.Parameter(layer.weight.data,
requires_grad=False) requires_grad=False)
# If checkpoint not serialized fp8, quantize the weights. # If checkpoint not serialized fp8, quantize the weights.
...@@ -266,6 +337,17 @@ class Fp8LinearMethod(LinearMethodBase): ...@@ -266,6 +337,17 @@ class Fp8LinearMethod(LinearMethodBase):
size_k=layer.input_size_per_partition, size_k=layer.input_size_per_partition,
bias=bias) bias=bias)
if self.block_quant:
assert self.quant_config.weight_block_size is not None
return apply_w8a8_block_fp8_linear(
input=x,
weight=layer.weight,
block_size=self.quant_config.weight_block_size,
weight_scale=layer.weight_scale_inv,
input_scale=layer.input_scale,
bias=bias,
)
return apply_fp8_linear( return apply_fp8_linear(
input=x, input=x,
weight=layer.weight, weight=layer.weight,
...@@ -291,6 +373,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -291,6 +373,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
def __init__(self, quant_config: Fp8Config): def __init__(self, quant_config: Fp8Config):
self.quant_config = quant_config self.quant_config = quant_config
self.block_quant = self.quant_config.weight_block_size is not None
def create_weights(self, layer: Module, num_experts: int, hidden_size: int, def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
intermediate_size: int, params_dtype: torch.dtype, intermediate_size: int, params_dtype: torch.dtype,
...@@ -298,6 +381,27 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -298,6 +381,27 @@ class Fp8MoEMethod(FusedMoEMethodBase):
if self.quant_config.is_checkpoint_fp8_serialized: if self.quant_config.is_checkpoint_fp8_serialized:
params_dtype = torch.float8_e4m3fn params_dtype = torch.float8_e4m3fn
if self.block_quant:
assert self.quant_config.weight_block_size is not None
tp_size = get_tensor_model_parallel_world_size()
block_n, block_k = (
self.quant_config.weight_block_size[0],
self.quant_config.weight_block_size[1],
)
# NOTE: To ensure proper alignment of the block-wise quantization
# scales, the output_size of the weights for both the gate and up
# layers must be divisible by block_n.
# Required by column parallel or enabling merged weights
if intermediate_size % block_n != 0:
raise ValueError(
f"The output_size of gate's and up's weight = "
f"{intermediate_size} is not divisible by "
f"weight quantization block_n = {block_n}.")
if (tp_size > 1 and intermediate_size % block_k != 0):
# Required by row parallel
raise ValueError(f"The input_size of down's weight = "
f"{intermediate_size} is not divisible by "
f"weight quantization block_k = {block_k}.")
# WEIGHTS # WEIGHTS
w13_weight = torch.nn.Parameter(torch.empty(num_experts, w13_weight = torch.nn.Parameter(torch.empty(num_experts,
...@@ -317,21 +421,45 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -317,21 +421,45 @@ class Fp8MoEMethod(FusedMoEMethodBase):
set_weight_attrs(w2_weight, extra_weight_attrs) set_weight_attrs(w2_weight, extra_weight_attrs)
# WEIGHT_SCALES # WEIGHT_SCALES
# Allocate 2 scales for w1 and w3 respectively. if not self.block_quant:
# They will be combined to a single scale after weight loading. # Allocate 2 scales for w1 and w3 respectively.
w13_weight_scale = torch.nn.Parameter(torch.ones(num_experts, # They will be combined to a single scale after weight loading.
2, w13_weight_scale = torch.nn.Parameter(torch.ones(
dtype=torch.float32), num_experts, 2, dtype=torch.float32),
requires_grad=False) requires_grad=False)
layer.register_parameter("w13_weight_scale", w13_weight_scale) w2_weight_scale = torch.nn.Parameter(torch.ones(
num_experts, dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
else:
w13_weight_scale = torch.nn.Parameter(
torch.ones(
num_experts,
2 * ((intermediate_size + block_n - 1) // block_n),
(hidden_size + block_k - 1) // block_k,
dtype=torch.float32,
),
requires_grad=False,
)
w2_weight_scale = torch.nn.Parameter(
torch.ones(
num_experts,
(hidden_size + block_n - 1) // block_n,
(intermediate_size + block_k - 1) // block_k,
dtype=torch.float32,
),
requires_grad=False,
)
layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
assert self.quant_config.activation_scheme == "dynamic"
w2_weight_scale = torch.nn.Parameter(torch.ones(num_experts,
dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
# Add the quantization method used (per tensor/grouped/channel) # Add the quantization method used (per tensor/grouped/channel)
# to ensure the weight scales are loaded in properly # to ensure the weight scales are loaded in properly
extra_weight_attrs.update( extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.BLOCK.
value} if self.block_quant else
{"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}) {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
# If loading fp8 checkpoint, pass the weight loaders. # If loading fp8 checkpoint, pass the weight loaders.
# If loading an fp16 checkpoint, do not (we will quantize in # If loading an fp16 checkpoint, do not (we will quantize in
...@@ -364,7 +492,9 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -364,7 +492,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
layer.w2_input_scale = None layer.w2_input_scale = None
def process_weights_after_loading(self, layer: Module) -> None: def process_weights_after_loading(self, layer: Module) -> None:
# Block quant doesn't need to process weights after loading
if self.block_quant:
return
# If checkpoint is fp16, quantize in place. # If checkpoint is fp16, quantize in place.
if not self.quant_config.is_checkpoint_fp8_serialized: if not self.quant_config.is_checkpoint_fp8_serialized:
# If rocm, use float8_e4m3fnuz as dtype # If rocm, use float8_e4m3fnuz as dtype
...@@ -471,12 +601,13 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -471,12 +601,13 @@ class Fp8MoEMethod(FusedMoEMethodBase):
router_logits: torch.Tensor, router_logits: torch.Tensor,
top_k: int, top_k: int,
renormalize: bool, renormalize: bool,
use_grouped_topk: bool, use_grouped_topk: bool = False,
topk_group: Optional[int] = None, topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None, num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None, custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
topk_weights, topk_ids = FusedMoE.select_experts( topk_weights, topk_ids = FusedMoE.select_experts(
...@@ -487,19 +618,27 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -487,19 +618,27 @@ class Fp8MoEMethod(FusedMoEMethodBase):
renormalize=renormalize, renormalize=renormalize,
topk_group=topk_group, topk_group=topk_group,
num_expert_group=num_expert_group, num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function) custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
return fused_experts(x, e_score_correction_bias=e_score_correction_bias,
layer.w13_weight, )
layer.w2_weight,
topk_weights=topk_weights, return fused_experts(
topk_ids=topk_ids, x,
inplace=True, layer.w13_weight,
use_fp8_w8a8=True, layer.w2_weight,
w1_scale=layer.w13_weight_scale, topk_weights=topk_weights,
w2_scale=layer.w2_weight_scale, topk_ids=topk_ids,
a1_scale=layer.w13_input_scale, inplace=True,
a2_scale=layer.w2_input_scale) use_fp8_w8a8=True,
w1_scale=(layer.w13_weight_scale_inv
if self.block_quant else layer.w13_weight_scale),
w2_scale=(layer.w2_weight_scale_inv
if self.block_quant else layer.w2_weight_scale),
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
block_shape=self.quant_config.weight_block_size,
)
class Fp8KVCacheMethod(BaseKVCacheMethod): class Fp8KVCacheMethod(BaseKVCacheMethod):
......
...@@ -532,11 +532,13 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): ...@@ -532,11 +532,13 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
top_k: int, top_k: int,
renormalize: bool = True, renormalize: bool,
use_grouped_topk: bool = False, use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None,
topk_group: Optional[int] = None, topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None, custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
# The input must currently be float16 # The input must currently be float16
orig_dtype = x.dtype orig_dtype = x.dtype
...@@ -550,7 +552,9 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): ...@@ -550,7 +552,9 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
renormalize=renormalize, renormalize=renormalize,
topk_group=topk_group, topk_group=topk_group,
num_expert_group=num_expert_group, num_expert_group=num_expert_group,
custom_routing_function=None) custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias)
return torch.ops.vllm.fused_marlin_moe( return torch.ops.vllm.fused_marlin_moe(
x, x,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment