Commit 99b471c2 authored by zhuwenwen's avatar zhuwenwen
Browse files

merge v0.4.1

parents 1925d2e9 468d761b
...@@ -12,9 +12,8 @@ from transformers import PreTrainedTokenizerBase ...@@ -12,9 +12,8 @@ from transformers import PreTrainedTokenizerBase
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
CompletionRequest) CompletionRequest)
from vllm.model_executor.guided_logits_processors import (CFGLogitsProcessor, from vllm.model_executor.guided_decoding.outlines_logits_processors import (
JSONLogitsProcessor, CFGLogitsProcessor, JSONLogitsProcessor, RegexLogitsProcessor)
RegexLogitsProcessor)
class GuidedDecodingMode(Enum): class GuidedDecodingMode(Enum):
...@@ -54,9 +53,9 @@ pair : UNESCAPED_STRING ":" value ...@@ -54,9 +53,9 @@ pair : UNESCAPED_STRING ":" value
global_thread_pool = None # used for generating logits processor fsm global_thread_pool = None # used for generating logits processor fsm
async def get_guided_decoding_logits_processor( async def get_outlines_guided_decoding_logits_processor(
request: Union[CompletionRequest, ChatCompletionRequest], request: Union[CompletionRequest, ChatCompletionRequest],
tokenizer) -> Union[JSONLogitsProcessor, RegexLogitsProcessor]: tokenizer) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, None]:
""" """
Given an OpenAI-compatible request, check for guided decoding parameters Given an OpenAI-compatible request, check for guided decoding parameters
and get the necessary logits processor for the given guide. and get the necessary logits processor for the given guide.
...@@ -85,13 +84,13 @@ async def get_guided_decoding_logits_processor( ...@@ -85,13 +84,13 @@ async def get_guided_decoding_logits_processor(
def _get_guide_and_mode( def _get_guide_and_mode(
request: Union[CompletionRequest, ChatCompletionRequest] request: Union[CompletionRequest, ChatCompletionRequest]
) -> Tuple[str, GuidedDecodingMode]: ) -> Union[Tuple[str, GuidedDecodingMode], Tuple[None, None]]:
if request.guided_json: if request.guided_json:
json = request.guided_json json = request.guided_json
if isinstance(json, dict): if isinstance(json, dict):
# turn dict into hashable string # turn dict into hashable string
json = json_dumps(json, sort_keys=True) json = json_dumps(json)
elif isinstance(json, BaseModel): elif isinstance(json, BaseModel):
# use pydantic signature so that different model classes # use pydantic signature so that different model classes
# with the same fields will get hashed the same # with the same fields will get hashed the same
......
...@@ -13,13 +13,15 @@ ...@@ -13,13 +13,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import copy
import json import json
import math import math
from collections import defaultdict from collections import defaultdict
from functools import lru_cache
from typing import Callable, DefaultDict, Dict, List, Optional, Union from typing import Callable, DefaultDict, Dict, List, Optional, Union
import torch import torch
from outlines.fsm.fsm import CFGFSM, RegexFSM from outlines.fsm.fsm import CFGFSM, FSM, RegexFSM
from outlines.fsm.json_schema import build_regex_from_schema from outlines.fsm.json_schema import build_regex_from_schema
from pydantic import BaseModel from pydantic import BaseModel
from transformers import PreTrainedTokenizerBase from transformers import PreTrainedTokenizerBase
...@@ -27,49 +29,9 @@ from transformers import PreTrainedTokenizerBase ...@@ -27,49 +29,9 @@ from transformers import PreTrainedTokenizerBase
class BaseLogitsProcessor: class BaseLogitsProcessor:
def adapt_tokenizer(self, tokenizer: PreTrainedTokenizerBase): def __init__(self):
"""Adapt vLLM's tokenizer to use to compile the FSM. # Child class should use initialize in their init.
self.fsm: FSM
The API of Outlines tokenizers is slightly different to that of
`transformers`. The decoder of outlines, returns a list whereas
the decode of vLLM returns an str. To sync the vLLM decoder with
outlines internal api, the decoder should be adapted. In addition
we need to handle the missing spaces to Llama's tokenizer to be
able to compile FSMs for this model.
"""
if getattr(tokenizer, "_outlines_adapted", False):
return tokenizer
tokenizer.vocabulary = tokenizer.get_vocab()
tokenizer.special_tokens = set(tokenizer.all_special_tokens)
def convert_token_to_string(token: str) -> str:
from transformers.file_utils import SPIECE_UNDERLINE
string = tokenizer.convert_tokens_to_string([token])
# A hack to handle missing spaces to HF's Llama tokenizers
if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
return " " + string
return string
def change_decoder(
decoder: Callable[[List[int]], str]
) -> Callable[[List[int]], List[str]]:
"""Sync vLLM's decoder with the outlines by returning list."""
def new_decoder(inp_tokens: List[int]) -> List[str]:
return [decoder(inp_tokens)]
return new_decoder
tokenizer.convert_token_to_string = convert_token_to_string
tokenizer.decode = change_decoder(tokenizer.decode)
setattr(tokenizer, "_outlines_adapted", True) # noqa: B010
return tokenizer
def init_state(self): def init_state(self):
"""Initialize the FSM states.""" """Initialize the FSM states."""
...@@ -78,7 +40,6 @@ class BaseLogitsProcessor: ...@@ -78,7 +40,6 @@ class BaseLogitsProcessor:
def __call__(self, input_ids: List[int], def __call__(self, input_ids: List[int],
scores: torch.Tensor) -> torch.Tensor: scores: torch.Tensor) -> torch.Tensor:
"""Use the FSM to bias the logits before sampling the next token.""" """Use the FSM to bias the logits before sampling the next token."""
seq_id = hash(tuple(input_ids)) seq_id = hash(tuple(input_ids))
if len(input_ids) == 0: if len(input_ids) == 0:
...@@ -96,7 +57,6 @@ class BaseLogitsProcessor: ...@@ -96,7 +57,6 @@ class BaseLogitsProcessor:
device=scores.device) device=scores.device)
mask[allowed_tokens] = 0 mask[allowed_tokens] = 0
scores.add_(mask) scores.add_(mask)
return scores return scores
...@@ -113,7 +73,7 @@ class RegexLogitsProcessor(BaseLogitsProcessor): ...@@ -113,7 +73,7 @@ class RegexLogitsProcessor(BaseLogitsProcessor):
The model's tokenizer The model's tokenizer
""" """
tokenizer = self.adapt_tokenizer(tokenizer) tokenizer = _adapt_tokenizer(tokenizer)
fsm = RegexFSM(regex_string, tokenizer) fsm = RegexFSM(regex_string, tokenizer)
self.fsm = fsm self.fsm = fsm
...@@ -167,6 +127,59 @@ class CFGLogitsProcessor(BaseLogitsProcessor): ...@@ -167,6 +127,59 @@ class CFGLogitsProcessor(BaseLogitsProcessor):
The model's tokenizer The model's tokenizer
""" """
tokenizer = self.adapt_tokenizer(tokenizer) tokenizer = _adapt_tokenizer(tokenizer)
fsm = CFGFSM(cfg, tokenizer) fsm = CFGFSM(cfg, tokenizer)
self.fsm = fsm self.fsm = fsm
def init_state(self):
"""Initialize state with a CFGFSM copy."""
super().init_state()
self.fsm = self.fsm.copy()
@lru_cache
def _adapt_tokenizer(tokenizer: PreTrainedTokenizerBase):
"""Adapt vLLM's tokenizer to use to compile the FSM.
The API of Outlines tokenizers is slightly different to that of
`transformers`. The decoder of outlines, returns a list whereas
the decode of vLLM returns an str. To sync the vLLM decoder with
outlines internal api, the decoder should be adapted. In addition
we need to handle the missing spaces to Llama's tokenizer to be
able to compile FSMs for this model.
"""
if getattr(tokenizer, "_outlines_adapted", False):
return tokenizer
tokenizer = copy.deepcopy(tokenizer)
tokenizer.vocabulary = tokenizer.get_vocab()
tokenizer.special_tokens = set(tokenizer.all_special_tokens)
def convert_token_to_string(token: str) -> str:
from transformers.file_utils import SPIECE_UNDERLINE
string = tokenizer.convert_tokens_to_string([token])
# A hack to handle missing spaces to HF's Llama tokenizers
if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
return " " + string
return string
def change_decoder(
decoder: Callable[[List[int]],
str]) -> Callable[[List[int]], List[str]]:
"""Sync vLLM's decoder with the outlines by returning list."""
def new_decoder(inp_tokens: List[int]) -> List[str]:
return [decoder(inp_tokens)]
return new_decoder
tokenizer.convert_token_to_string = convert_token_to_string
tokenizer.decode = change_decoder(tokenizer.decode)
setattr(tokenizer, "_outlines_adapted", True) # noqa: B010
return tokenizer
...@@ -6,11 +6,10 @@ import torch ...@@ -6,11 +6,10 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from vllm._C import ops from vllm import _custom_ops as ops
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.utils import divide
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
......
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"64": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"96": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"128": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"256": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 4
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 4
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 4
},
"1536": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
}
}
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 4
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 4
},
"24": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"48": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"64": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"96": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"128": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"256": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"1024": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 4
},
"1536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 4
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 4
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
}
}
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"48": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"64": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"96": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"128": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"256": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 4
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 4
},
"1536": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
}
}
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 4
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"48": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"64": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"96": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"128": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"256": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"1024": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 4
},
"1536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 4
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 4
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 4
}
}
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"2": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"4": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 4
},
"16": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
},
"24": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
},
"32": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"48": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"64": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"96": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"128": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"256": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
},
"512": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"1024": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 4
},
"1536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 4
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 4
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 4
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 4
}
}
...@@ -8,7 +8,7 @@ import torch ...@@ -8,7 +8,7 @@ import torch
import triton import triton
import triton.language as tl import triton.language as tl
from vllm._C import ops from vllm import _custom_ops as ops
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import is_hip from vllm.utils import is_hip
...@@ -21,6 +21,8 @@ def fused_moe_kernel( ...@@ -21,6 +21,8 @@ def fused_moe_kernel(
a_ptr, a_ptr,
b_ptr, b_ptr,
c_ptr, c_ptr,
a_scale_ptr,
b_scale_ptr,
topk_weights_ptr, topk_weights_ptr,
sorted_token_ids_ptr, sorted_token_ids_ptr,
expert_ids_ptr, expert_ids_ptr,
...@@ -49,6 +51,7 @@ def fused_moe_kernel( ...@@ -49,6 +51,7 @@ def fused_moe_kernel(
MUL_ROUTED_WEIGHT: tl.constexpr, MUL_ROUTED_WEIGHT: tl.constexpr,
top_k: tl.constexpr, top_k: tl.constexpr,
compute_type: tl.constexpr, compute_type: tl.constexpr,
use_fp8: tl.constexpr,
): ):
""" """
Implements the fused computation for a Mixture of Experts (MOE) using Implements the fused computation for a Mixture of Experts (MOE) using
...@@ -111,6 +114,10 @@ def fused_moe_kernel( ...@@ -111,6 +114,10 @@ def fused_moe_kernel(
b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk + b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk +
offs_bn[None, :] * stride_bn) offs_bn[None, :] * stride_bn)
if use_fp8:
a_scale = tl.load(a_scale_ptr)
b_scale = tl.load(b_scale_ptr + off_experts)
# ----------------------------------------------------------- # -----------------------------------------------------------
# Iterate to compute a block of the C matrix. # Iterate to compute a block of the C matrix.
# We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
...@@ -129,7 +136,10 @@ def fused_moe_kernel( ...@@ -129,7 +136,10 @@ def fused_moe_kernel(
mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,
other=0.0) other=0.0)
# We accumulate along the K dimension. # We accumulate along the K dimension.
accumulator += tl.dot(a, b) if use_fp8:
accumulator = tl.dot(a, b, acc=accumulator)
else:
accumulator += tl.dot(a, b)
# Advance the ptrs to the next K block. # Advance the ptrs to the next K block.
a_ptrs += BLOCK_SIZE_K * stride_ak a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk b_ptrs += BLOCK_SIZE_K * stride_bk
...@@ -140,7 +150,10 @@ def fused_moe_kernel( ...@@ -140,7 +150,10 @@ def fused_moe_kernel(
other=0) other=0)
accumulator = accumulator * moe_weight[:, None] accumulator = accumulator * moe_weight[:, None]
accumulator = accumulator.to(compute_type) if use_fp8:
accumulator = (accumulator * a_scale * b_scale).to(compute_type)
else:
accumulator = accumulator.to(compute_type)
# ----------------------------------------------------------- # -----------------------------------------------------------
# Write back the block of the output # Write back the block of the output
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
...@@ -207,15 +220,24 @@ def moe_align_block_size( ...@@ -207,15 +220,24 @@ def moe_align_block_size(
def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
topk_weights: torch.Tensor, topk_ids: torch.Tensor, B_scale: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
sorted_token_ids: torch.Tensor, sorted_token_ids: torch.Tensor,
expert_ids: torch.Tensor, expert_ids: torch.Tensor,
num_tokens_post_padded: torch.Tensor, num_tokens_post_padded: torch.Tensor,
mul_routed_weight: bool, top_k: int, mul_routed_weight: bool, top_k: int,
config: Dict[str, Any]) -> None: config: Dict[str, Any], compute_type: tl.dtype,
use_fp8: bool) -> None:
assert topk_weights.stride(1) == 1 assert topk_weights.stride(1) == 1
assert sorted_token_ids.stride(0) == 1 assert sorted_token_ids.stride(0) == 1
if not use_fp8:
A_scale = None
assert B_scale is None
else:
A, A_scale = ops.scaled_fp8_quant(A)
assert B_scale is not None
grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[ grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[
'BLOCK_SIZE_M']) * triton.cdiv(B.shape[1], META['BLOCK_SIZE_N']), ) 'BLOCK_SIZE_M']) * triton.cdiv(B.shape[1], META['BLOCK_SIZE_N']), )
...@@ -223,6 +245,8 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, ...@@ -223,6 +245,8 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
A, A,
B, B,
C, C,
A_scale,
B_scale,
topk_weights, topk_weights,
sorted_token_ids, sorted_token_ids,
expert_ids, expert_ids,
...@@ -240,18 +264,21 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, ...@@ -240,18 +264,21 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
C.stride(2), C.stride(2),
MUL_ROUTED_WEIGHT=mul_routed_weight, MUL_ROUTED_WEIGHT=mul_routed_weight,
top_k=top_k, top_k=top_k,
compute_type=tl.bfloat16 if A.dtype == torch.bfloat16 else tl.float16, compute_type=compute_type,
use_fp8=use_fp8,
**config, **config,
) )
def get_config_file_name(E: int, N: int) -> str: def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str:
device_name = torch.cuda.get_device_name().replace(" ", "_") device_name = torch.cuda.get_device_name().replace(" ", "_")
return f"E={E},N={N},device_name={device_name}.json" dtype_selector = "" if not dtype else f",dtype={dtype}"
return f"E={E},N={N},device_name={device_name}{dtype_selector}.json"
@functools.lru_cache @functools.lru_cache
def get_moe_configs(E: int, N: int) -> Optional[Dict[int, Any]]: def get_moe_configs(E: int, N: int,
dtype: Optional[str]) -> Optional[Dict[int, Any]]:
""" """
Return optimized configurations for the fused MoE kernel. Return optimized configurations for the fused MoE kernel.
...@@ -263,7 +290,7 @@ def get_moe_configs(E: int, N: int) -> Optional[Dict[int, Any]]: ...@@ -263,7 +290,7 @@ def get_moe_configs(E: int, N: int) -> Optional[Dict[int, Any]]:
# First look up if an optimized configuration is available in the configs # First look up if an optimized configuration is available in the configs
# directory # directory
json_file_name = get_config_file_name(E, N) json_file_name = get_config_file_name(E, N, dtype)
config_file_path = os.path.join( config_file_path = os.path.join(
os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name) os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name)
...@@ -288,6 +315,9 @@ def fused_moe( ...@@ -288,6 +315,9 @@ def fused_moe(
renormalize: bool, renormalize: bool,
inplace: bool = False, inplace: bool = False,
override_config: Optional[Dict[str, Any]] = None, override_config: Optional[Dict[str, Any]] = None,
use_fp8: bool = False,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
This function computes a Mixture of Experts (MoE) layer using two sets of This function computes a Mixture of Experts (MoE) layer using two sets of
...@@ -305,6 +335,12 @@ def fused_moe( ...@@ -305,6 +335,12 @@ def fused_moe(
Defaults to False. Defaults to False.
- override_config (Optional[Dict[str, Any]]): Optional override - override_config (Optional[Dict[str, Any]]): Optional override
for the kernel configuration. for the kernel configuration.
- use_fp8 (bool): If True, use fp8 arithmetic to compute the inner
products for w1 and w2. Defaults to False.
- w1_scale (Optional[torch.Tensor]): Optional scale to be used for
w1.
- w2_scale (Optional[torch.Tensor]): Optional scale to be used for
w2.
Returns: Returns:
- torch.Tensor: The output tensor after applying the MoE layer. - torch.Tensor: The output tensor after applying the MoE layer.
...@@ -358,7 +394,8 @@ def fused_moe( ...@@ -358,7 +394,8 @@ def fused_moe(
config = override_config config = override_config
else: else:
# First try to load optimal config from the file # First try to load optimal config from the file
configs = get_moe_configs(E, w2.shape[2]) configs = get_moe_configs(E, w2.shape[2],
"float8" if use_fp8 else None)
if configs: if configs:
# If an optimal configuration map has been found, look up the # If an optimal configuration map has been found, look up the
...@@ -394,17 +431,37 @@ def fused_moe( ...@@ -394,17 +431,37 @@ def fused_moe(
sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size( sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
topk_ids, config['BLOCK_SIZE_M'], E) topk_ids, config['BLOCK_SIZE_M'], E)
invoke_fused_moe_kernel(hidden_states, w1, intermediate_cache1, invoke_fused_moe_kernel(hidden_states,
topk_weights, topk_ids, sorted_token_ids, w1,
expert_ids, num_tokens_post_padded, False, intermediate_cache1,
topk_ids.shape[1], config) w1_scale,
topk_weights,
topk_ids,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
False,
topk_ids.shape[1],
config,
compute_type=tl.float16,
use_fp8=use_fp8)
ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
invoke_fused_moe_kernel(intermediate_cache2, w2, intermediate_cache3, invoke_fused_moe_kernel(intermediate_cache2,
topk_weights, topk_ids, sorted_token_ids, w2,
expert_ids, num_tokens_post_padded, True, 1, intermediate_cache3,
config) w2_scale,
topk_weights,
topk_ids,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
True,
1,
config,
compute_type=tl.float16,
use_fp8=use_fp8)
if inplace: if inplace:
return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
......
...@@ -4,7 +4,7 @@ from typing import Optional, Tuple, Union ...@@ -4,7 +4,7 @@ from typing import Optional, Tuple, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
from vllm._C import ops from vllm import _custom_ops as ops
class RMSNorm(nn.Module): class RMSNorm(nn.Module):
......
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional from typing import List, Optional
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from torch import nn
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
split_tensor_along_last_dim,
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.parallel_utils.communication_op import (
tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.utils import (
divide, split_tensor_along_last_dim)
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -32,21 +32,43 @@ class LinearMethodBase(ABC): ...@@ -32,21 +32,43 @@ class LinearMethodBase(ABC):
"""Base class for different (maybe quantized) linear methods.""" """Base class for different (maybe quantized) linear methods."""
@abstractmethod @abstractmethod
def create_weights(self, input_size_per_partition: int, def create_weights(self, layer: torch.nn.Module,
output_size_per_partition: int, input_size: int, input_size_per_partition: int,
output_size: int, output_partition_sizes: List[int], input_size: int,
params_dtype: torch.dtype) -> Dict[str, Any]: output_size: int, params_dtype: torch.dtype,
"""Create weights for a linear layer.""" **extra_weight_attrs):
"""Create weights for a linear layer.
The weights will be set as attributes of the layer.
Args:
layer: The layer that is using the LinearMethodBase factory.
input_size_per_partition: Size of the weight input dim on rank X.
output_partition_sizes: Sizes of the output dim of each logical
weight on rank X. E.g., output_partition_sizes for QKVLinear
is a list contains the width of Wq, Wk, Wv on rank X.
input_size: Size of the input dim of the weight across all ranks.
output_size: Size of the output dim of the weight across all ranks.
params_dtype: Datatype of the parameters.
"""
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
def apply_weights(self, def apply_weights(self,
weights: Dict[str, torch.Tensor], layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor: bias: Optional[torch.Tensor] = None) -> torch.Tensor:
"""Apply the weights to the input tensor.""" """Apply the weights in layer to the input tensor.
Expects create_weights to have been called before on the layer."""
raise NotImplementedError raise NotImplementedError
def process_weights_after_loading(self, layer: nn.Module) -> None:
"""Process the weight after loading.
This can be used for example, to transpose weights for computation.
"""
return
class UnquantizedLinearMethod(LinearMethodBase): class UnquantizedLinearMethod(LinearMethodBase):
"""Linear method without quantization. """Linear method without quantization.
...@@ -60,22 +82,25 @@ class UnquantizedLinearMethod(LinearMethodBase): ...@@ -60,22 +82,25 @@ class UnquantizedLinearMethod(LinearMethodBase):
self.separate_bias_add = separate_bias_add self.separate_bias_add = separate_bias_add
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1' self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
def create_weights(self, input_size_per_partition: int, def create_weights(self, layer: torch.nn.Module,
output_size_per_partition: int, input_size: int, input_size_per_partition: int,
output_size: int, output_partition_sizes: List[int], input_size: int,
params_dtype: torch.dtype) -> Dict[str, Any]: output_size: int, params_dtype: torch.dtype,
**extra_weight_attrs):
output_size_per_partition = sum(output_partition_sizes)
weight = Parameter(torch.empty(output_size_per_partition, weight = Parameter(torch.empty(output_size_per_partition,
input_size_per_partition, input_size_per_partition,
dtype=params_dtype), dtype=params_dtype),
requires_grad=False) requires_grad=False)
set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
return {"weight": weight} layer.register_parameter("weight", weight)
set_weight_attrs(weight, extra_weight_attrs)
def apply_weights(self, def apply_weights(self,
weights: Dict[str, torch.Tensor], layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor: bias: Optional[torch.Tensor] = None) -> torch.Tensor:
weight = weights["weight"] weight = layer.weight
if self.separate_bias_add: if self.separate_bias_add:
if bias is not None: if bias is not None:
return F.linear(x, weight) + bias return F.linear(x, weight) + bias
...@@ -124,12 +149,9 @@ class ReplicatedLinear(torch.nn.Module): ...@@ -124,12 +149,9 @@ class ReplicatedLinear(torch.nn.Module):
if linear_method is None: if linear_method is None:
linear_method = UnquantizedLinearMethod() linear_method = UnquantizedLinearMethod()
self.linear_method = linear_method self.linear_method = linear_method
self.linear_weights = self.linear_method.create_weights( self.linear_method.create_weights(self, self.input_size,
self.input_size, self.output_size, self.input_size, [self.output_size], self.input_size,
self.output_size, self.params_dtype) self.output_size, self.params_dtype)
for name, weight in self.linear_weights.items():
if isinstance(weight, torch.Tensor):
self.register_parameter(name, weight)
if bias: if bias:
self.bias = Parameter( self.bias = Parameter(
torch.empty(self.output_size, dtype=self.params_dtype)) torch.empty(self.output_size, dtype=self.params_dtype))
...@@ -139,7 +161,7 @@ class ReplicatedLinear(torch.nn.Module): ...@@ -139,7 +161,7 @@ class ReplicatedLinear(torch.nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
bias = self.bias if not self.skip_bias_add else None bias = self.bias if not self.skip_bias_add else None
output = self.linear_method.apply_weights(self.linear_weights, x, bias) output = self.linear_method.apply_weights(self, x, bias)
output_bias = self.bias if self.skip_bias_add else None output_bias = self.bias if self.skip_bias_add else None
return output, output_bias return output, output_bias
...@@ -162,6 +184,8 @@ class ColumnParallelLinear(torch.nn.Module): ...@@ -162,6 +184,8 @@ class ColumnParallelLinear(torch.nn.Module):
skip adding bias but instead return it. skip adding bias but instead return it.
params_dtype: Data type for the parameters. params_dtype: Data type for the parameters.
linear_method: (Maybe quantized) linear method. linear_method: (Maybe quantized) linear method.
output_sizes: list of output sizes packed into one output, like for QKV
the list would be size 3.
""" """
def __init__( def __init__(
...@@ -173,6 +197,7 @@ class ColumnParallelLinear(torch.nn.Module): ...@@ -173,6 +197,7 @@ class ColumnParallelLinear(torch.nn.Module):
skip_bias_add: bool = False, skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None, params_dtype: Optional[torch.dtype] = None,
linear_method: Optional[LinearMethodBase] = None, linear_method: Optional[LinearMethodBase] = None,
output_sizes: Optional[List[int]] = None,
): ):
super().__init__() super().__init__()
...@@ -189,14 +214,16 @@ class ColumnParallelLinear(torch.nn.Module): ...@@ -189,14 +214,16 @@ class ColumnParallelLinear(torch.nn.Module):
self.params_dtype = params_dtype self.params_dtype = params_dtype
if linear_method is None: if linear_method is None:
linear_method = UnquantizedLinearMethod() linear_method = UnquantizedLinearMethod()
if output_sizes is None:
output_sizes = [output_size]
self.linear_method = linear_method self.linear_method = linear_method
self.linear_weights = self.linear_method.create_weights( self.linear_method.create_weights(self,
self.input_size, self.output_size_per_partition, self.input_size, self.input_size,
self.output_size, self.params_dtype) [x // tp_size for x in output_sizes],
for name, weight in self.linear_weights.items(): self.input_size,
if isinstance(weight, torch.Tensor): self.output_size,
self.register_parameter(name, weight) self.params_dtype,
set_weight_attrs(weight, {"weight_loader": self.weight_loader}) weight_loader=self.weight_loader)
if bias: if bias:
self.bias = Parameter( self.bias = Parameter(
torch.empty(self.output_size_per_partition, torch.empty(self.output_size_per_partition,
...@@ -228,8 +255,7 @@ class ColumnParallelLinear(torch.nn.Module): ...@@ -228,8 +255,7 @@ class ColumnParallelLinear(torch.nn.Module):
bias = self.bias if not self.skip_bias_add else None bias = self.bias if not self.skip_bias_add else None
# Matrix multiply. # Matrix multiply.
output_parallel = self.linear_method.apply_weights( output_parallel = self.linear_method.apply_weights(self, input_, bias)
self.linear_weights, input_, bias)
if self.gather_output: if self.gather_output:
# All-gather across the partitions. # All-gather across the partitions.
output = tensor_model_parallel_all_gather(output_parallel) output = tensor_model_parallel_all_gather(output_parallel)
...@@ -273,16 +299,19 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -273,16 +299,19 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
self.output_sizes = output_sizes self.output_sizes = output_sizes
tp_size = get_tensor_model_parallel_world_size() tp_size = get_tensor_model_parallel_world_size()
assert all(output_size % tp_size == 0 for output_size in output_sizes) assert all(output_size % tp_size == 0 for output_size in output_sizes)
super().__init__(input_size, sum(output_sizes), bias, gather_output, super().__init__(input_size, sum(output_sizes), bias, gather_output,
skip_bias_add, params_dtype, linear_method) skip_bias_add, params_dtype, linear_method,
self.output_sizes)
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1' self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
def weight_loader(self, def weight_loader(self,
param: Parameter, param: Parameter,
loaded_weight: torch.Tensor, loaded_weight: torch.Tensor,
loaded_shard_id: Optional[int] = None): loaded_shard_id: Optional[int] = None):
param_data = param.data param_data = param.data
output_dim = getattr(param, "output_dim", None) output_dim = getattr(param, "output_dim", None)
is_metadata = getattr(param, "is_metadata", False)
if loaded_shard_id is None: if loaded_shard_id is None:
# Loaded weight is already packed. # Loaded weight is already packed.
if output_dim is None: if output_dim is None:
...@@ -339,6 +368,11 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -339,6 +368,11 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
start_idx = tp_rank * shard_size start_idx = tp_rank * shard_size
loaded_weight = loaded_weight.narrow(output_dim, start_idx, loaded_weight = loaded_weight.narrow(output_dim, start_idx,
shard_size) shard_size)
elif is_metadata:
# metadata indicates fixed size concatenated along dim 0
shard_size = loaded_weight.shape[0]
shard_offset = loaded_shard_id * shard_size
param_data = param_data.narrow(0, shard_offset, shard_size)
else: else:
ignore_warning = getattr(param, "ignore_warning", False) ignore_warning = getattr(param, "ignore_warning", False)
if not ignore_warning: if not ignore_warning:
...@@ -412,8 +446,14 @@ class QKVParallelLinear(ColumnParallelLinear): ...@@ -412,8 +446,14 @@ class QKVParallelLinear(ColumnParallelLinear):
input_size = self.hidden_size input_size = self.hidden_size
output_size = (self.num_heads + output_size = (self.num_heads +
2 * self.num_kv_heads) * tp_size * self.head_size 2 * self.num_kv_heads) * tp_size * self.head_size
output_sizes = [
self.num_heads * tp_size * self.head_size,
self.num_kv_heads * tp_size * self.head_size,
self.num_kv_heads * tp_size * self.head_size
]
super().__init__(input_size, output_size, bias, False, skip_bias_add, super().__init__(input_size, output_size, bias, False, skip_bias_add,
params_dtype, linear_method) params_dtype, linear_method, output_sizes)
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1' self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
def weight_loader(self, def weight_loader(self,
...@@ -422,6 +462,7 @@ class QKVParallelLinear(ColumnParallelLinear): ...@@ -422,6 +462,7 @@ class QKVParallelLinear(ColumnParallelLinear):
loaded_shard_id: Optional[str] = None): loaded_shard_id: Optional[str] = None):
param_data = param.data param_data = param.data
output_dim = getattr(param, "output_dim", None) output_dim = getattr(param, "output_dim", None)
is_metadata = getattr(param, "is_metadata", False)
if loaded_shard_id is None: if loaded_shard_id is None:
# Loaded weight is already packed. # Loaded weight is already packed.
...@@ -493,6 +534,12 @@ class QKVParallelLinear(ColumnParallelLinear): ...@@ -493,6 +534,12 @@ class QKVParallelLinear(ColumnParallelLinear):
start_idx = shard_id * shard_size start_idx = shard_id * shard_size
loaded_weight = loaded_weight.narrow(output_dim, start_idx, loaded_weight = loaded_weight.narrow(output_dim, start_idx,
shard_size) shard_size)
elif is_metadata:
# metadata indicates fixed size concatenated along dim 0
shard_size = loaded_weight.shape[0]
shard_index = ["q", "k", "v"].index(loaded_shard_id)
param_data = param_data.narrow(0, shard_index * shard_size,
shard_size)
else: else:
ignore_warning = getattr(param, "ignore_warning", False) ignore_warning = getattr(param, "ignore_warning", False)
if not ignore_warning: if not ignore_warning:
...@@ -566,13 +613,13 @@ class RowParallelLinear(torch.nn.Module): ...@@ -566,13 +613,13 @@ class RowParallelLinear(torch.nn.Module):
if linear_method is None: if linear_method is None:
linear_method = UnquantizedLinearMethod() linear_method = UnquantizedLinearMethod()
self.linear_method = linear_method self.linear_method = linear_method
self.linear_weights = self.linear_method.create_weights( self.linear_method.create_weights(self,
self.input_size_per_partition, self.output_size, self.input_size, self.input_size_per_partition,
self.output_size, self.params_dtype) [self.output_size],
for name, weight in self.linear_weights.items(): self.input_size,
if isinstance(weight, torch.Tensor): self.output_size,
self.register_parameter(name, weight) self.params_dtype,
set_weight_attrs(weight, {"weight_loader": self.weight_loader}) weight_loader=self.weight_loader)
if not reduce_results and (bias and not skip_bias_add): if not reduce_results and (bias and not skip_bias_add):
raise ValueError("When not reduce the results, adding bias to the " raise ValueError("When not reduce the results, adding bias to the "
...@@ -616,7 +663,7 @@ class RowParallelLinear(torch.nn.Module): ...@@ -616,7 +663,7 @@ class RowParallelLinear(torch.nn.Module):
# Matrix multiply. # Matrix multiply.
output_parallel = self.linear_method.apply_weights( output_parallel = self.linear_method.apply_weights(
self.linear_weights, input_parallel) self, input_parallel)
if self.reduce_results and self.tp_size > 1: if self.reduce_results and self.tp_size > 1:
output_ = tensor_model_parallel_all_reduce(output_parallel) output_ = tensor_model_parallel_all_reduce(output_parallel)
else: else:
......
...@@ -4,8 +4,7 @@ from typing import Optional ...@@ -4,8 +4,7 @@ from typing import Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
from vllm.model_executor.parallel_utils.communication_op import ( from vllm.distributed import tensor_model_parallel_gather
tensor_model_parallel_gather)
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
...@@ -86,8 +85,16 @@ def _apply_logits_processors( ...@@ -86,8 +85,16 @@ def _apply_logits_processors(
) -> torch.Tensor: ) -> torch.Tensor:
logits_row_idx = 0 logits_row_idx = 0
found_logits_processors = False found_logits_processors = False
for seq_ids, sampling_params in sampling_metadata.seq_groups: for i, seq_group in enumerate(sampling_metadata.seq_groups):
seq_ids, sampling_params = seq_group
logits_processors = sampling_params.logits_processors logits_processors = sampling_params.logits_processors
# handle prompt_logprobs by skipping rows in logits added for
# the prompt tokens (prompt logprobs are not processed)
if (i < sampling_metadata.num_prompts
and sampling_params.prompt_logprobs is not None):
assert len(seq_ids) == 1
logits_row_idx += sampling_metadata.prompt_lens[i] - 1
if logits_processors: if logits_processors:
found_logits_processors = True found_logits_processors = True
for seq_id in seq_ids: for seq_id in seq_ids:
...@@ -100,5 +107,6 @@ def _apply_logits_processors( ...@@ -100,5 +107,6 @@ def _apply_logits_processors(
else: else:
logits_row_idx += len(seq_ids) logits_row_idx += len(seq_ids)
if found_logits_processors: if found_logits_processors:
# verifies that no rows in logits were missed unexpectedly
assert logits_row_idx == logits.shape[0] assert logits_row_idx == logits.shape[0]
return logits return logits
...@@ -29,8 +29,8 @@ def _multi_split_sample( ...@@ -29,8 +29,8 @@ def _multi_split_sample(
sampled_tokens_size: Tuple[int, int], sampled_tokens_size: Tuple[int, int],
sampled_logprobs_size: Tuple[int, int], sampled_logprobs_size: Tuple[int, int],
sample_indices: torch.Tensor, sample_indices: torch.Tensor,
logprobs: torch.Tensor,
*, *,
logprobs: Optional[torch.Tensor] = None,
modify_greedy_probs: bool = False, modify_greedy_probs: bool = False,
save_logprobs: bool = False, save_logprobs: bool = False,
): ):
...@@ -167,6 +167,7 @@ def sample( ...@@ -167,6 +167,7 @@ def sample(
sampled_logprobs_size = (0, 0) sampled_logprobs_size = (0, 0)
logprobs = probs logprobs = probs
assert logprobs is not None
if _save_modified_probs: if _save_modified_probs:
sampled_modified_probs_size = sampled_tokens_size sampled_modified_probs_size = sampled_tokens_size
else: else:
......
from typing import Type from typing import Type
from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
from vllm.model_executor.layers.quantization.awq import AWQConfig from vllm.model_executor.layers.quantization.awq import AWQConfig
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.layers.quantization.fp8 import FP8Config
from vllm.model_executor.layers.quantization.gptq import GPTQConfig from vllm.model_executor.layers.quantization.gptq import GPTQConfig
from vllm.model_executor.layers.quantization.marlin import MarlinConfig from vllm.model_executor.layers.quantization.marlin import MarlinConfig
from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig
_QUANTIZATION_CONFIG_REGISTRY = { QUANTIZATION_METHODS = {
"aqlm": AQLMConfig,
"awq": AWQConfig, "awq": AWQConfig,
"fp8": FP8Config,
"gptq": GPTQConfig, "gptq": GPTQConfig,
"squeezellm": SqueezeLLMConfig, "squeezellm": SqueezeLLMConfig,
"marlin": MarlinConfig, "marlin": MarlinConfig,
...@@ -16,12 +20,13 @@ _QUANTIZATION_CONFIG_REGISTRY = { ...@@ -16,12 +20,13 @@ _QUANTIZATION_CONFIG_REGISTRY = {
def get_quantization_config(quantization: str) -> Type[QuantizationConfig]: def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
if quantization not in _QUANTIZATION_CONFIG_REGISTRY: if quantization not in QUANTIZATION_METHODS:
raise ValueError(f"Invalid quantization method: {quantization}") raise ValueError(f"Invalid quantization method: {quantization}")
return _QUANTIZATION_CONFIG_REGISTRY[quantization] return QUANTIZATION_METHODS[quantization]
__all__ = [ __all__ = [
"QuantizationConfig", "QuantizationConfig",
"get_quantization_config", "get_quantization_config",
"QUANTIZATION_METHODS",
] ]
# Supports AQLM compression, see https://github.com/Vahe1994/AQLM
# and https://arxiv.org/pdf/2401.06118.pdf
import math
from typing import Any, Dict, List, Optional
import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from vllm._C import ops
from vllm.model_executor.layers.linear import (LinearMethodBase,
set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
def get_int_dtype(nbits: int) -> torch.dtype:
    """Pick the narrowest signed torch integer dtype at least ``nbits`` wide.

    Args:
        nbits: Number of bits needed per code.

    Returns:
        One of ``torch.int8``, ``torch.int16``, ``torch.int32`` or
        ``torch.int64``.

    Raises:
        ValueError: If ``nbits`` is larger than 64.
    """
    # Ordered narrowest first, so the first width that fits wins.
    width_to_dtype = (
        (8, torch.int8),
        (16, torch.int16),
        (32, torch.int32),
        (64, torch.int64),
    )
    for width, dtype in width_to_dtype:
        if nbits <= width:
            return dtype
    raise ValueError(f"No dtype available for {nbits}-bit codebooks")
@torch.inference_mode()
def unpack_int_data(data: torch.IntTensor, nbits: int) -> torch.IntTensor:
    """Map stored (possibly negative) codes to non-negative indices.

    The input is widened to int64 and reduced modulo ``2**nbits``. The
    result of ``%`` on tensors takes the sign of the divisor, so a negative
    stored value wraps around (e.g. ``-1`` becomes ``2**nbits - 1``).

    Args:
        data: Integer tensor of stored codes.
        nbits: Number of bits per code.

    Returns:
        An int64 tensor with the same shape as ``data``.
    """
    modulus = 2**nbits
    widened = data.to(torch.int64)
    return widened % modulus
def dequantize_weight(codes: torch.Tensor,
                      codebooks: torch.Tensor,
                      scales: Optional[torch.Tensor] = None) -> torch.Tensor:
    """
    Rebuild a dense weight matrix from quantization codes.

    For every (out group, in group) position the selected vector of each
    codebook is looked up and the vectors are summed.

    :param codes: tensor of integer quantization codes, shape
        [*dims, num_out_groups, num_in_groups, num_codebooks]
    :param codebooks: tensor of vectors for each quantization code,
        [num_codebooks, codebook_size, out_group_size, in_group_size]
    :param scales: optional factor multiplied into the group-wise weight;
        must be broadcastable with
        [*dims, num_out_groups, num_in_groups, out_group_size, in_group_size]
    :return: reconstructed weight tensor of shape
        [*dims, num_out_groups * out_group_size,
         num_in_groups * in_group_size]
    """
    leading_dims = list(codes.shape[:-3])
    num_out_groups, num_in_groups, _ = codes.shape[-3:]
    num_codebooks, codebook_size, out_group_size, in_group_size = \
        codebooks.shape

    # All codebooks are stacked into one lookup table below, so the codes of
    # codebook i must be shifted by i * codebook_size.
    table_offsets = torch.arange(start=0,
                                 end=num_codebooks * codebook_size,
                                 step=codebook_size,
                                 device=codes.device)  # [num_codebooks]
    bag_indices = codes.flatten(0, -2) + table_offsets
    lookup_table = codebooks.flatten(0, 1).flatten(-2, -1)

    # One "bag" per (out group, in group) position; mode="sum" adds up the
    # contribution of every codebook.
    # Result: [prod(dims) * num_out_groups * num_in_groups,
    #          out_group_size * in_group_size]
    summed_vectors = F.embedding_bag(bag_indices, lookup_table, mode="sum")

    groupwise_shape = leading_dims + [
        num_out_groups, num_in_groups, out_group_size, in_group_size
    ]
    groupwise_weight = summed_vectors.view(groupwise_shape)
    if scales is not None:
        groupwise_weight = groupwise_weight * scales

    # Put out_group_size next to num_out_groups (and in_group_size next to
    # num_in_groups) so the reshape merges each pair into one dimension.
    dense_shape = leading_dims + [
        num_out_groups * out_group_size, num_in_groups * in_group_size
    ]
    return groupwise_weight.transpose(-3, -2).reshape(dense_shape)
def dequantize_gemm(
    input: torch.Tensor,  # [..., in_features]
    codes: torch.IntTensor,  # [num_out_groups, num_in_groups, num_codebooks]
    codebooks: torch.
    Tensor,  # [num_codebooks, codebook_size, out_group_size, in_group_size]
    scales: torch.Tensor,  # [num_out_groups, 1, 1, 1]
    bias: Optional[torch.Tensor],
) -> torch.Tensor:
    """Dequantize the full weight, then apply it as a linear layer.

    The stored codes are unpacked to non-negative indices, the dense weight
    is rebuilt from the codebooks and scales, and ``F.linear`` is applied to
    ``input`` with the optional ``bias``.
    """
    codebook_size = codebooks.shape[1]
    # For a power-of-two codebook size, bit_length() - 1 is its log2, i.e.
    # the number of bits per code.
    bits_per_code = codebook_size.bit_length() - 1
    indices = unpack_int_data(codes, bits_per_code)
    dense_weight = dequantize_weight(indices, codebooks, scales)
    return F.linear(input, dense_weight, bias)
# Generic dequantization, slow but flexible.
def generic_dequantize_gemm(
    input: torch.Tensor,  # [..., in_features]
    codes: torch.IntTensor,  # [num_out_groups, num_in_groups, num_codebooks]
    codebooks: torch.
    Tensor,  # [num_codebooks, codebook_size, out_group_size, in_group_size]
    scales: torch.Tensor,  # [num_out_groups, 1, 1, 1]
    output_partition_sizes: torch.IntTensor,
    bias: Optional[torch.Tensor],
) -> torch.Tensor:
    """Run a dequantize-and-multiply per output partition.

    The codes, scales and (if given) bias are sliced along dim 0 by each
    entry of ``output_partition_sizes``. The codebooks are sliced into
    equally sized consecutive groups, one per partition. Each partition's
    result is written into its slice of the last output dimension.

    Returns:
        Tensor of shape ``input.shape[:-1] + (scales.shape[0],)`` with the
        dtype and device of ``input``.
    """
    result_shape = input.shape[:-1] + (scales.shape[0], )
    result = torch.empty(result_shape, dtype=input.dtype, device=input.device)
    partition_count = len(output_partition_sizes)

    # Process each partition separately and stitch the outputs together.
    # Per the original author's note, this is faster than doing 3 de-quants
    # followed by 1 big multiply at the end.
    books_per_partition = codebooks.shape[0] // partition_count
    assert (scales.shape[0] == codes.shape[0])
    assert (sum(output_partition_sizes) == scales.shape[0])

    row_start = 0
    book_start = 0
    for partition_rows in output_partition_sizes:
        partition_codes = codes.narrow(0, row_start, partition_rows)
        partition_books = codebooks.narrow(0, book_start,
                                           books_per_partition)
        partition_scales = scales.narrow(0, row_start, partition_rows)
        if bias is None:
            partition_bias = None
        else:
            partition_bias = bias.narrow(0, row_start, partition_rows)

        partition_output = dequantize_gemm(input, partition_codes,
                                           partition_books, partition_scales,
                                           partition_bias)

        # narrow() returns a view, so copy_ writes straight into `result`.
        destination = result.narrow(-1, row_start, partition_rows)
        assert (destination.shape == partition_output.shape)
        destination.copy_(partition_output)

        row_start += partition_rows
        book_start += books_per_partition
    return result
# Optimized dequantize/decompression kernels, supports 1x16 and 2x8
# at 6 and 9 times faster than the generic version above, respectively.
def optimized_dequantize_gemm(
    input: torch.Tensor,  # [..., in_features]
    codes: torch.IntTensor,  # [num_out_groups, num_in_groups, num_codebooks]
    codebooks: torch.
    Tensor,  # [num_codebooks, codebook_size, out_group_size, in_group_size]
    scales: torch.Tensor,  # [num_out_groups, 1, 1, 1]
    output_partition_sizes: torch.IntTensor,
    bias: Optional[torch.Tensor],
) -> torch.Tensor:
    """Dequantize with the custom ``aqlm_dequant`` op, then apply a linear.

    The scales are applied to the output when there is no bias, and to the
    dequantized weights otherwise.
    """
    weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes)

    if bias is not None:
        # F.linear adds the bias, so scaling the output afterwards would
        # scale the bias too; scale the weight rows in place instead.
        row_scales = scales.view(scales.shape[:-3] + (-1, )).expand(
            -1, weights.shape[1])
        weights *= row_scales
        return F.linear(input, weights, bias)

    # Scaling the output is fastest, so we do that when possible.
    output = F.linear(input, weights, bias)
    original_shape = output.shape
    # The 2-D view shares storage with `output`, so the in-place multiply
    # below scales `output` itself.
    output_2d = output.view(-1, output.size(-1))
    scales_row = scales.view(-1, scales.shape[0])
    scales_2d = scales_row.expand(output_2d.shape[0], -1)
    output_2d *= scales_2d
    return output.view(original_shape)
class AQLMConfig(QuantizationConfig):
    """Quantization config for AQLM-compressed weights.

    Reference: https://github.com/Vahe1994/AQLM
    """

    def __init__(
        self,
        in_group_size: int,
        nbits_per_codebook: int,
        num_codebooks: int,
        out_group_size: int,
    ) -> None:
        """Store the AQLM hyper-parameters and derive ``pack_factor``.

        ``out_group_size`` must be 1; any other value trips the assertion.
        """
        self.in_group_size = in_group_size
        self.nbits_per_codebook = nbits_per_codebook
        self.num_codebooks = num_codebooks
        self.out_group_size = out_group_size

        # out_group_size > 1 is untested, and probably won't work as-is.
        assert (self.out_group_size == 1)
        self.pack_factor = (self.in_group_size * self.out_group_size)

    def __repr__(self) -> str:
        fields = ", ".join((
            f"in_group_size={self.in_group_size}",
            f"nbits_per_codebook={self.nbits_per_codebook}",
            f"num_codebooks={self.num_codebooks}",
            f"out_group_size={self.out_group_size}",
        ))
        return f"AQLMConfig({fields})"

    @classmethod
    def get_name(cls) -> str:
        """Return the name of this quantization method."""
        return "aqlm"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        """Return the supported activation dtypes (half precision only)."""
        return [torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum capability value required by this method."""
        # NOTE(review): presumably 70 encodes CUDA compute capability 7.0 --
        # confirm against QuantizationConfig.
        return 70

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        """Return extra config file names to look for (none for AQLM)."""
        return []  # no extra configs.

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "AQLMConfig":
        """Build an ``AQLMConfig`` from a quantization config dict."""
        # Keys are listed in the positional order expected by __init__.
        keys = ("in_group_size", "nbits_per_codebook", "num_codebooks",
                "out_group_size")
        values = [cls.get_from_keys(config, [key]) for key in keys]
        return cls(*values)

    def get_linear_method(self) -> "AQLMLinearMethod":
        """Return the linear method that applies this config."""
        return AQLMLinearMethod(self)

    def get_scaled_act_names(self) -> List[str]:
        """Return the names of scaled activations (none for AQLM)."""
        return []
class AQLMLinearMethod(LinearMethodBase):
    """Linear method that stores and applies AQLM-quantized weights.

    Args:
        quant_config: The AQLM quantization config.
    """

    def __init__(self, quant_config: AQLMConfig):
        self.quant_config = quant_config

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: List[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Create the ``codes``, ``codebooks`` and ``scales`` parameters.

        The three parameters are registered on ``layer`` and each receives
        ``extra_weight_attrs`` as attributes.

        Raises:
            ValueError: If ``params_dtype`` is not ``torch.half``, or if the
                per-partition input/output size is not a multiple of the
                configured in/out group size.
        """
        del output_size  # Unused.
        del input_size  # Unused.

        if params_dtype != torch.half:
            raise ValueError("Only half is currently supported by aqlm")

        cfg = self.quant_config
        if input_size_per_partition % cfg.in_group_size != 0:
            raise ValueError(
                "The input size is not aligned with the quantized "
                "weight shape. This can be caused by too large "
                "tensor parallel size.")

        output_size_per_partition = sum(output_partition_sizes)
        if output_size_per_partition % cfg.out_group_size != 0:
            raise ValueError(
                "The output size is not aligned with the quantized "
                "weight shape. This can be caused by too large "
                "tensor parallel size.")

        # There could actually be two pack factors, one along input and one
        # along output, but out_group_size is not currently supported.
        # NOTE(review): the original comment says only the factor along the
        # output needs "packed_dim" for QKVLinear to work, yet packed_dim is
        # set to 1, which is input_dim here -- confirm which is intended.
        codes_shape = (
            output_size_per_partition,
            input_size_per_partition // cfg.pack_factor,
            cfg.num_codebooks,
        )
        codes = Parameter(
            torch.empty(codes_shape,
                        dtype=get_int_dtype(cfg.nbits_per_codebook)),
            requires_grad=False,
        )
        set_weight_attrs(
            codes,
            {
                "input_dim": 1,
                "output_dim": 0,
                "packed_dim": 1,
                "pack_factor": cfg.pack_factor,
            },
        )

        # One group of codebooks per logical output partition, stacked
        # along dim 0.
        codebooks_shape = (
            cfg.num_codebooks * len(output_partition_sizes),
            2**cfg.nbits_per_codebook,
            cfg.out_group_size,
            cfg.in_group_size,
        )
        codebooks = Parameter(
            torch.empty(codebooks_shape, dtype=params_dtype),
            requires_grad=False,
        )
        set_weight_attrs(
            codebooks,
            {
                # metadata indicates fixed size concatenated along dim 0
                "is_metadata":
                True,
                "output_partition_sizes":
                torch.tensor(output_partition_sizes, device='cpu'),
            },
        )

        scales_shape = (
            output_size_per_partition // cfg.out_group_size,
            1,
            1,
            1,
        )
        scales = Parameter(
            torch.empty(scales_shape, dtype=params_dtype),
            requires_grad=False,
        )
        set_weight_attrs(
            scales,
            {
                "output_dim": 0,
                "packed_dim": 0,
                "pack_factor": cfg.out_group_size
            },
        )

        # Register in a fixed order and attach the caller-supplied
        # attributes to each parameter.
        for param_name, param in (("codes", codes), ("codebooks", codebooks),
                                  ("scales", scales)):
            layer.register_parameter(param_name, param)
            set_weight_attrs(param, extra_weight_attrs)

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Apply the AQLM weights stored on ``layer`` to ``x``.

        Dedicated kernels are used when the in-group size is 8, the
        out-group size is 1, and the format is either two 256-entry
        codebooks or one 65536-entry codebook. Every other format goes
        through ``generic_dequantize_gemm``.
        """
        codebooks = layer.codebooks
        codes = layer.codes
        scales = layer.scales

        output_partition_sizes = getattr(codebooks, "output_partition_sizes",
                                         None)

        codebooks_per_code = codes.shape[2]
        in_group = codebooks.shape[3]
        out_group = codebooks.shape[2]
        codebook_size = codebooks.shape[1]

        # We support these formats with dedicated gemm and decompression
        # kernels.
        has_dedicated_kernel = (in_group == 8 and out_group == 1
                                and (codebook_size, codebooks_per_code)
                                in ((256, 2), (65536, 1)))
        if not has_dedicated_kernel:
            # fall back all unoptimized formats
            return generic_dequantize_gemm(
                x,
                codes,
                codebooks,
                scales,
                output_partition_sizes,
                bias,
            )

        # thresholds determined by timings on an A6000, one GPU
        use_gemv = math.prod(x.shape[:-1]) <= 6
        if use_gemv:
            return ops.aqlm_gemm(
                x,
                codes,
                codebooks,
                scales,
                output_partition_sizes,
                bias,
            )
        return optimized_dequantize_gemm(
            x,
            codes,
            codebooks,
            scales,
            output_partition_sizes,
            bias,
        )
...@@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional ...@@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional
import torch import torch
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
from vllm._C import ops from vllm import _custom_ops as ops
from vllm.model_executor.layers.linear import (LinearMethodBase, from vllm.model_executor.layers.linear import (LinearMethodBase,
set_weight_attrs) set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
...@@ -79,15 +79,18 @@ class AWQLinearMethod(LinearMethodBase): ...@@ -79,15 +79,18 @@ class AWQLinearMethod(LinearMethodBase):
def __init__(self, quant_config: AWQConfig): def __init__(self, quant_config: AWQConfig):
self.quant_config = quant_config self.quant_config = quant_config
def create_weights(self, input_size_per_partition: int, def create_weights(self, layer: torch.nn.Module,
output_size_per_partition: int, input_size: int, input_size_per_partition: int,
output_size: int, output_partition_sizes: List[int], input_size: int,
params_dtype: torch.dtype) -> Dict[str, Any]: output_size: int, params_dtype: torch.dtype,
**extra_weight_attrs):
if input_size_per_partition % self.quant_config.group_size != 0: if input_size_per_partition % self.quant_config.group_size != 0:
raise ValueError( raise ValueError(
"The input size is not aligned with the quantized " "The input size is not aligned with the quantized "
"weight shape. This can be caused by too large " "weight shape. This can be caused by too large "
"tensor parallel size.") "tensor parallel size.")
output_size_per_partition = sum(output_partition_sizes)
if output_size_per_partition % self.quant_config.pack_factor != 0: if output_size_per_partition % self.quant_config.pack_factor != 0:
raise ValueError( raise ValueError(
"The output size is not aligned with the quantized " "The output size is not aligned with the quantized "
...@@ -136,19 +139,21 @@ class AWQLinearMethod(LinearMethodBase): ...@@ -136,19 +139,21 @@ class AWQLinearMethod(LinearMethodBase):
"input_dim": 0, "input_dim": 0,
"output_dim": 1, "output_dim": 1,
}) })
return {
"qweight": qweight, layer.register_parameter("qweight", qweight)
"qzeros": qzeros, set_weight_attrs(qweight, extra_weight_attrs)
"scales": scales, layer.register_parameter("qzeros", qzeros)
} set_weight_attrs(qzeros, extra_weight_attrs)
layer.register_parameter("scales", scales)
set_weight_attrs(scales, extra_weight_attrs)
def apply_weights(self, def apply_weights(self,
weights: Dict[str, Any], layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor: bias: Optional[torch.Tensor] = None) -> torch.Tensor:
qweight = weights["qweight"] qweight = layer.qweight
scales = weights["scales"] scales = layer.scales
qzeros = weights["qzeros"] qzeros = layer.qzeros
pack_factor = self.quant_config.pack_factor pack_factor = self.quant_config.pack_factor
out_shape = (x.shape[:-1] + (qweight.shape[-1] * pack_factor, )) out_shape = (x.shape[:-1] + (qweight.shape[-1] * pack_factor, ))
reshaped_x = x.reshape(-1, x.shape[-1]) reshaped_x = x.reshape(-1, x.shape[-1])
...@@ -163,5 +168,5 @@ class AWQLinearMethod(LinearMethodBase): ...@@ -163,5 +168,5 @@ class AWQLinearMethod(LinearMethodBase):
out = ops.awq_gemm(reshaped_x, qweight, scales, qzeros, out = ops.awq_gemm(reshaped_x, qweight, scales, qzeros,
pack_factor) pack_factor)
if bias is not None: if bias is not None:
out = out + bias out.add_(bias)
return out.reshape(out_shape) return out.reshape(out_shape)
from typing import Any, Dict, List, Optional, Tuple
import torch
from torch.nn import Module
from torch.nn.parameter import Parameter
from vllm.model_executor.layers.linear import (LinearMethodBase,
set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
class FP8Config(QuantizationConfig):
    """Quantization config for FP8.

    The config carries no options of its own: it is built without reading
    any fields from the checkpoint config and only serves to hand out an
    ``Fp8LinearMethod``.
    """

    @classmethod
    def get_name(cls) -> str:
        """Return the identifier of this quantization method."""
        name = "fp8"
        return name

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        """Return the activation dtypes this method accepts."""
        supported = [torch.bfloat16, torch.half]
        return supported

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum GPU capability required by this method."""
        # TODO: PyTorch 2.3.0+ is required to run FP8 on
        # SM 89 (e.g. Ada) GPUs. Specifically, this PR has to
        # be included: https://github.com/pytorch/pytorch/pull/118881
        min_capability = 90
        return min_capability

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        """Return the config files to look for; FP8 needs none."""
        return list()

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "FP8Config":
        """Build the config; ``config`` is accepted but not consulted."""
        return cls()

    def get_linear_method(self) -> "Fp8LinearMethod":
        """Return the linear method bound to this config."""
        linear_method = Fp8LinearMethod(self)
        return linear_method

    def get_scaled_act_names(self) -> List[str]:
        """Return the names of scaled activations; FP8 has none."""
        return list()
class Fp8LinearMethod(LinearMethodBase):
    """Linear method for FP8.

    We now support common FP16/BF16 model checkpoints ONLY. The weight
    scaling factor will be initialized after the model weights are loaded.

    Limitations:
    1. Only support per-tensor quantization due to torch._scaled_mm support.
    2. Only support float8_e4m3fn data type due to the limitation of
       torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856)

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: FP8Config):
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the unquantized weight and its scale slot on ``layer``.

        The weight is allocated in ``params_dtype`` with shape
        ``(sum(output_partition_sizes), input_size_per_partition)``.
        A one-element float32 parameter named ``weight_scaling_factor`` is
        registered next to it; it is left uninitialized here and filled in
        by ``process_weights_after_loading``. ``input_size`` and
        ``output_size`` are not used.
        """
        out_rows = sum(output_partition_sizes)
        weight_storage = torch.empty(out_rows,
                                     input_size_per_partition,
                                     dtype=params_dtype)
        weight = Parameter(weight_storage, requires_grad=False)
        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
        set_weight_attrs(weight, extra_weight_attrs)

        scale_storage = torch.empty(1, dtype=torch.float32)
        scale_param = Parameter(scale_storage, requires_grad=False)
        layer.register_parameter("weight_scaling_factor", scale_param)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Quantize the loaded weight of ``layer`` to FP8 in place.

        Layers without a ``weight_scaling_factor`` attribute are left
        untouched.
        """
        # Although the linear_method is propagated to all layers,
        # only linear layers invoke "create_weights". So the presence of
        # "weight_scaling_factor" tells us whether the layer is a linear
        # layer that requires quantization.
        if hasattr(layer, "weight_scaling_factor"):
            quantized, inv_scale = per_tensor_quantize(layer.weight)
            # torch._scaled_mm requires column-major in the second
            # input (weight), so we transpose the quantized weight.
            layer.weight = Parameter(quantized.t(), requires_grad=False)
            layer.weight_scaling_factor.data.copy_(inv_scale)

    def apply_weights(self,
                      layer: torch.nn.Module,
                      x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Quantize ``x`` per-tensor and multiply it with the FP8 weight.

        The result is produced in ``x``'s dtype; ``bias`` is forwarded to
        ``torch._scaled_mm`` as-is.
        """
        quantized_x, x_inv_scale = per_tensor_quantize(x)
        # Only the first element of the returned pair is used.
        result, _ = torch._scaled_mm(
            quantized_x,
            layer.weight,
            out_dtype=x.dtype,
            scale_a=x_inv_scale,
            scale_b=layer.weight_scaling_factor,
            bias=bias,
        )
        return result
def per_tensor_quantize(tensor: torch.Tensor) -> Tuple[torch.Tensor, float]:
"""Quantize a tensor using per-tensor static scaling factor.
Args:
tensor: The input tensor.
"""
finfo = torch.finfo(torch.float8_e4m3fn)
# Calculate the scale as dtype max divided by absmax.
# Since .abs() creates a new tensor, we use aminmax to get
# the min and max first and then calculate the absmax.
min_val, max_val = tensor.aminmax()
amax = min_val.abs().max(max_val.abs())
scale = finfo.max / amax.clamp(min=1e-12)
# scale and clamp the tensor to bring it to
# the representative range of float8 data type
# (as default cast is unsaturated)
qweight = (tensor * scale).clamp(min=finfo.min, max=finfo.max)
# Return both float8 data and the inverse scale (as float),
# as both required as inputs to torch._scaled_mm
qweight = qweight.to(torch.float8_e4m3fn)
scale = scale.float().reciprocal()
return qweight, scale
...@@ -6,7 +6,7 @@ from typing import Any, Dict, List, Optional ...@@ -6,7 +6,7 @@ from typing import Any, Dict, List, Optional
import torch import torch
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
from vllm._C import ops from vllm import _custom_ops as ops
from vllm.model_executor.layers.linear import (LinearMethodBase, from vllm.model_executor.layers.linear import (LinearMethodBase,
set_weight_attrs) set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
...@@ -89,18 +89,21 @@ class GPTQLinearMethod(LinearMethodBase): ...@@ -89,18 +89,21 @@ class GPTQLinearMethod(LinearMethodBase):
def create_weights( def create_weights(
self, self,
layer: torch.nn.Module,
input_size_per_partition: int, input_size_per_partition: int,
output_size_per_partition: int, output_partition_sizes: List[int],
input_size: int, input_size: int,
output_size: int, output_size: int,
params_dtype: torch.dtype, params_dtype: torch.dtype,
) -> Dict[str, Any]: **extra_weight_attrs,
):
del output_size # Unused. del output_size # Unused.
if input_size_per_partition % self.quant_config.group_size != 0: if input_size_per_partition % self.quant_config.group_size != 0:
raise ValueError( raise ValueError(
"The input size is not aligned with the quantized " "The input size is not aligned with the quantized "
"weight shape. This can be caused by too large " "weight shape. This can be caused by too large "
"tensor parallel size.") "tensor parallel size.")
output_size_per_partition = sum(output_partition_sizes)
if (output_size_per_partition % self.quant_config.pack_factor.numerator if (output_size_per_partition % self.quant_config.pack_factor.numerator
!= 0): != 0):
raise ValueError( raise ValueError(
...@@ -179,37 +182,40 @@ class GPTQLinearMethod(LinearMethodBase): ...@@ -179,37 +182,40 @@ class GPTQLinearMethod(LinearMethodBase):
"input_dim": scale_and_zero_input_dim, "input_dim": scale_and_zero_input_dim,
"output_dim": 1, "output_dim": 1,
}) })
return {
"qweight": qweight, layer.register_parameter("qweight", qweight)
"g_idx": g_idx, set_weight_attrs(qweight, extra_weight_attrs)
"qzeros": qzeros, layer.register_parameter("g_idx", g_idx)
"scales": scales, set_weight_attrs(g_idx, extra_weight_attrs)
"exllama_state": exllama_state, layer.register_parameter("qzeros", qzeros)
} set_weight_attrs(qzeros, extra_weight_attrs)
layer.register_parameter("scales", scales)
set_weight_attrs(scales, extra_weight_attrs)
layer.exllama_state = exllama_state
def apply_weights(self, def apply_weights(self,
weights: Dict[str, Any], layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor: bias: Optional[torch.Tensor] = None) -> torch.Tensor:
qweight = weights["qweight"] qweight = layer.qweight
out_shape = x.shape[:-1] + (qweight.shape[-1], ) out_shape = x.shape[:-1] + (qweight.shape[-1], )
reshaped_x = x.reshape(-1, x.shape[-1]) reshaped_x = x.reshape(-1, x.shape[-1])
# exllama needs to shuffle the weight after the weight is loaded # exllama needs to shuffle the weight after the weight is loaded
# here we do the shuffle on first forward pass # here we do the shuffle on first forward pass
if weights["exllama_state"] == ExllamaState.UNINITIALIZED: if layer.exllama_state == ExllamaState.UNINITIALIZED:
if self.quant_config.desc_act: if self.quant_config.desc_act:
weights["g_idx"] = torch.argsort(weights["g_idx"]).to( layer.g_idx.data = torch.argsort(layer.g_idx).to(torch.int)
torch.int)
else: else:
weights["g_idx"] = torch.empty((1, 1), device="meta") layer.g_idx.data = torch.empty((0, ),
weights["exllama_state"] = ExllamaState.READY device=layer.g_idx.device)
ops.gptq_shuffle(weights["qweight"], weights["g_idx"], layer.exllama_state = ExllamaState.READY
ops.gptq_shuffle(layer.qweight, layer.g_idx,
self.quant_config.weight_bits) self.quant_config.weight_bits)
output = ops.gptq_gemm(reshaped_x, weights["qweight"], output = ops.gptq_gemm(reshaped_x, layer.qweight, layer.qzeros,
weights["qzeros"], weights["scales"], layer.scales, layer.g_idx,
weights["g_idx"], layer.exllama_state == ExllamaState.READY,
weights["exllama_state"] == ExllamaState.READY,
self.quant_config.weight_bits) self.quant_config.weight_bits)
if bias is not None: if bias is not None:
output = output + bias output.add_(bias)
return output.reshape(out_shape) return output.reshape(out_shape)
...@@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional ...@@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional
import torch import torch
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
from vllm._C import ops from vllm import _custom_ops as ops
from vllm.model_executor.layers.linear import (LinearMethodBase, from vllm.model_executor.layers.linear import (LinearMethodBase,
set_weight_attrs) set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
...@@ -91,12 +91,14 @@ class MarlinLinearMethod(LinearMethodBase): ...@@ -91,12 +91,14 @@ class MarlinLinearMethod(LinearMethodBase):
def create_weights( def create_weights(
self, self,
layer: torch.nn.Module,
input_size_per_partition: int, input_size_per_partition: int,
output_size_per_partition: int, output_partition_sizes: List[int],
input_size: int, input_size: int,
output_size: int, output_size: int,
params_dtype: torch.dtype, params_dtype: torch.dtype,
) -> Dict[str, Any]: **extra_weight_attrs,
):
del output_size # Unused. del output_size # Unused.
if params_dtype != torch.float16: if params_dtype != torch.float16:
...@@ -104,6 +106,7 @@ class MarlinLinearMethod(LinearMethodBase): ...@@ -104,6 +106,7 @@ class MarlinLinearMethod(LinearMethodBase):
f"The params dtype must be float16, but got {params_dtype}") f"The params dtype must be float16, but got {params_dtype}")
# Validate output_size_per_partition # Validate output_size_per_partition
output_size_per_partition = sum(output_partition_sizes)
if output_size_per_partition % self.quant_config.min_n_threads != 0: if output_size_per_partition % self.quant_config.min_n_threads != 0:
raise ValueError( raise ValueError(
f"Weight output_size_per_partition = " f"Weight output_size_per_partition = "
...@@ -187,21 +190,22 @@ class MarlinLinearMethod(LinearMethodBase): ...@@ -187,21 +190,22 @@ class MarlinLinearMethod(LinearMethodBase):
dtype=torch.int), dtype=torch.int),
requires_grad=False) requires_grad=False)
return { layer.register_parameter("B", qweight)
"B": qweight, set_weight_attrs(qweight, extra_weight_attrs)
"s": scales, layer.register_parameter("s", scales)
"workspace": workspace, set_weight_attrs(scales, extra_weight_attrs)
} layer.register_parameter("workspace", workspace)
set_weight_attrs(workspace, extra_weight_attrs)
def apply_weights( def apply_weights(
self, self,
weights: Dict[str, Any], layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
bias: Optional[torch.Tensor] = None, bias: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
qweight = weights["B"] qweight = layer.B
scales = weights["s"] scales = layer.s
workspace = weights["workspace"] workspace = layer.workspace
x_2d = x.view(-1, x.shape[-1]) x_2d = x.view(-1, x.shape[-1])
......
"""
This file contains the Pydantic schemas for various quantization-related
parameters. When a relevant quantization technique is specified, these
parameters are loaded in the form of a JSON alongside the model weights
and augment the model with additional information needed for use of that
technique. The format of this JSON should be specified by one or more
schemas contained here.
For example, when the KV cache is quantized to FP8-E4M3 (currently only
possible on ROCm), the model can be optionally augmented with KV cache
scaling factors.
"""
from typing import Dict, Optional
from pydantic import BaseModel, ConfigDict, ValidationInfo, model_validator
class KVCacheQuantSchema(BaseModel):
    """Schema for the KV cache scaling factors loaded from JSON.

    The validators below run after field parsing. Those taking ``info``
    only perform their checks when a validation context is supplied; they
    read ``tp_size``, ``tp_rank`` and ``num_hidden_layers`` from it.

    Failures are raised as ``ValueError`` rather than ``assert`` so the
    checks stay active when Python runs with ``-O``.
    """
    dtype: str
    # Each key is a TP rank. Each value is a dictionary mapping a TP rank's
    # layer indices to their per-tensor KV cache scaling factor.
    # TODO: Consider pulling this and its validation methods out into its
    # own schema class (tricky as its members are variable)
    scaling_factor: Dict[int, Dict[int, float]]

    @model_validator(mode="after")
    def check_is_fp8(self) -> "KVCacheQuantSchema":
        """Reject scaling factors that were not produced for float8_e4m3fn."""
        if self.dtype != "float8_e4m3fn":
            raise ValueError(
                "Loaded scaling factors intended for KV cache dtype = "
                f"{self.dtype} rather than float8_e4m3fn!")
        return self

    @model_validator(mode="after")
    def check_tp_ranks(self, info: ValidationInfo) -> "KVCacheQuantSchema":
        """Check that every TP rank has a complete per-layer scale map."""
        context = info.context
        if context:
            tp_size = context["tp_size"]
            num_hidden_layers = context["num_hidden_layers"]
            if len(self.scaling_factor) != tp_size:
                raise ValueError(
                    f"Loaded dictionary has TP size {len(self.scaling_factor)} "
                    f"but LLM engine is currently running with TP size {tp_size}.")
            for tp_rank, layer_maps in self.scaling_factor.items():
                if len(layer_maps) != num_hidden_layers:
                    raise ValueError(
                        f"KV cache scales map for TP rank {tp_rank} is malformed. "
                        f"Expected {num_hidden_layers} layers, got "
                        f"{len(layer_maps)}.")
            # The size check alone does not guarantee the keys are exactly
            # 0..tp_size-1, so each expected rank is looked up explicitly.
            for i in range(tp_size):
                if i not in self.scaling_factor:
                    raise ValueError(
                        f"KV cache scales map for TP rank {i} not found.")
        return self

    @model_validator(mode="after")
    def check_current_rank(self, info: ValidationInfo) -> "KVCacheQuantSchema":
        """Check that the current TP rank has a scale for every layer index."""
        context = info.context
        if context:
            tp_rank = context["tp_rank"]
            num_hidden_layers = context["num_hidden_layers"]
            layer_scales_map = self.scaling_factor[tp_rank]
            for i in range(num_hidden_layers):
                if i not in layer_scales_map:
                    raise ValueError(
                        f"Could not find KV cache scales for layer {i} in "
                        f"TP rank {tp_rank}.")
        return self
class QuantParamSchema(BaseModel):
    """Top-level schema of the quantization parameters JSON.

    Holds an optional model type together with the KV cache scaling
    factors. The model-type check raises ``ValueError`` rather than using
    ``assert`` so it stays active when Python runs with ``-O``.
    """
    # TODO: Generalize and extend with more fields
    # (e.g. weights/activations params) once functionality is enabled
    model_config = ConfigDict(protected_namespaces=())
    model_type: Optional[str]
    kv_cache: KVCacheQuantSchema

    @model_validator(mode="after")
    def check_model_type(self, info: ValidationInfo) -> "QuantParamSchema":
        """Check the loaded model type against the one given in the context.

        The check is skipped when no validation context is supplied or when
        the context carries no (or a ``None``) ``model_type``.
        """
        context = info.context
        if context:
            model_type = context.get("model_type", None)
            if model_type is not None and model_type != self.model_type:
                raise ValueError(
                    f"Model type is {model_type} but loaded "
                    f"scaling factors belonging to different "
                    f"model type {self.model_type}!")
        return self
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment