Commit e661d594 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.5.4' into v0.5.4-dtk24.04.1

parents 6b16ea2e 4db5176d
...@@ -12,7 +12,10 @@ from transformers import PreTrainedTokenizerBase ...@@ -12,7 +12,10 @@ from transformers import PreTrainedTokenizerBase
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
CompletionRequest) CompletionRequest)
from vllm.model_executor.guided_decoding.guided_fields import (
GuidedDecodingRequest)
from vllm.model_executor.guided_decoding.outlines_decoding import ( from vllm.model_executor.guided_decoding.outlines_decoding import (
get_local_outlines_guided_decoding_logits_processor,
get_outlines_guided_decoding_logits_processor) get_outlines_guided_decoding_logits_processor)
from vllm.sampling_params import LogitsProcessor from vllm.sampling_params import LogitsProcessor
...@@ -54,6 +57,42 @@ async def get_lm_format_enforcer_guided_decoding_logits_processor( ...@@ -54,6 +57,42 @@ async def get_lm_format_enforcer_guided_decoding_logits_processor(
return logits_processor return logits_processor
def get_local_lm_format_enforcer_guided_decoding_logits_processor(
        guided_options: GuidedDecodingRequest,
        tokenizer) -> Optional[LogitsProcessor]:
    """
    Build an lm-format-enforcer logits processor for a guided decoding
    request.

    The fields of ``guided_options`` are checked in a fixed order and the
    first truthy one wins: ``guided_json``, ``guided_choice``,
    ``guided_regex``, ``guided_grammar``, ``guided_json_object``.

    Grammar requests are forwarded to the outlines backend. If no field is
    set, ``None`` is returned.
    """
    # The tokenizer data is built up front, before any option is inspected.
    tokenizer_data = _cached_build_vllm_token_enforcer_tokenizer_data(
        tokenizer)

    if guided_options.guided_json:
        json_schema = _normalize_json_schema_object(
            guided_options.guided_json)
        return build_vllm_logits_processor(tokenizer_data,
                                           JsonSchemaParser(json_schema))

    if guided_options.guided_choice:
        choice_parsers = [
            StringParser(option) for option in guided_options.guided_choice
        ]
        return build_vllm_logits_processor(tokenizer_data,
                                           UnionParser(choice_parsers))

    if guided_options.guided_regex:
        regex_parser = RegexParser(guided_options.guided_regex)
        return build_vllm_logits_processor(tokenizer_data, regex_parser)

    if guided_options.guided_grammar:
        # CFG grammar not supported by LMFE, revert to outlines
        return get_local_outlines_guided_decoding_logits_processor(
            guided_options, tokenizer)

    if guided_options.guided_json_object:
        # None means any json object
        return build_vllm_logits_processor(tokenizer_data,
                                           JsonSchemaParser(None))

    return None
def _normalize_json_schema_object(schema: Union[str, dict, BaseModel]) -> dict: def _normalize_json_schema_object(schema: Union[str, dict, BaseModel]) -> dict:
if isinstance(schema, str): if isinstance(schema, str):
return json_loads(schema) return json_loads(schema)
......
...@@ -10,6 +10,8 @@ from transformers import PreTrainedTokenizerBase ...@@ -10,6 +10,8 @@ from transformers import PreTrainedTokenizerBase
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
CompletionRequest) CompletionRequest)
from vllm.model_executor.guided_decoding.guided_fields import (
GuidedDecodingRequest)
from vllm.model_executor.guided_decoding.outlines_logits_processors import ( from vllm.model_executor.guided_decoding.outlines_logits_processors import (
CFGLogitsProcessor, JSONLogitsProcessor, RegexLogitsProcessor) CFGLogitsProcessor, JSONLogitsProcessor, RegexLogitsProcessor)
...@@ -77,8 +79,27 @@ async def get_outlines_guided_decoding_logits_processor( ...@@ -77,8 +79,27 @@ async def get_outlines_guided_decoding_logits_processor(
mode, request.guided_whitespace_pattern) mode, request.guided_whitespace_pattern)
def get_local_outlines_guided_decoding_logits_processor(
    guided_options: GuidedDecodingRequest, tokenizer: PreTrainedTokenizerBase
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor,
           None]:
    """
    Build an outlines logits processor for a guided decoding request.

    The guide and its decoding mode are derived from ``guided_options`` via
    ``_get_guide_and_mode``. If either one is missing, ``None`` is returned.
    Otherwise the processor is created with the request's
    ``guided_whitespace_pattern``.
    """
    guide, mode = _get_guide_and_mode(guided_options)
    if guide and mode:
        whitespace_pattern = guided_options.guided_whitespace_pattern
        return _get_logits_processor(guide, tokenizer, mode,
                                     whitespace_pattern)

    return None
def _get_guide_and_mode( def _get_guide_and_mode(
request: Union[CompletionRequest, ChatCompletionRequest] request: Union[CompletionRequest, ChatCompletionRequest,
GuidedDecodingRequest]
) -> Union[Tuple[str, GuidedDecodingMode], Tuple[None, None]]: ) -> Union[Tuple[str, GuidedDecodingMode], Tuple[None, None]]:
if request.guided_json: if request.guided_json:
...@@ -102,7 +123,8 @@ def _get_guide_and_mode( ...@@ -102,7 +123,8 @@ def _get_guide_and_mode(
return choices_regex, GuidedDecodingMode.CHOICE return choices_regex, GuidedDecodingMode.CHOICE
elif request.guided_grammar: elif request.guided_grammar:
return request.guided_grammar, GuidedDecodingMode.GRAMMAR return request.guided_grammar, GuidedDecodingMode.GRAMMAR
elif (request.response_format is not None elif (not isinstance(request, GuidedDecodingRequest)
and request.response_format is not None
and request.response_format.type == "json_object"): and request.response_format.type == "json_object"):
return JSON_GRAMMAR, GuidedDecodingMode.GRAMMAR return JSON_GRAMMAR, GuidedDecodingMode.GRAMMAR
else: else:
......
...@@ -21,6 +21,8 @@ from functools import lru_cache ...@@ -21,6 +21,8 @@ from functools import lru_cache
from typing import Callable, DefaultDict, Dict, List, Union from typing import Callable, DefaultDict, Dict, List, Union
import torch import torch
from lark import Lark
from outlines import grammars
from outlines.caching import cache from outlines.caching import cache
from outlines.fsm.guide import CFGGuide, Generate, Guide, RegexGuide, Write from outlines.fsm.guide import CFGGuide, Generate, Guide, RegexGuide, Write
from outlines.fsm.json_schema import build_regex_from_schema from outlines.fsm.json_schema import build_regex_from_schema
...@@ -44,6 +46,23 @@ class BaseLogitsProcessor: ...@@ -44,6 +46,23 @@ class BaseLogitsProcessor:
last_seq_id = hash(tuple(input_ids[:-1])) last_seq_id = hash(tuple(input_ids[:-1]))
self._fsm_state[seq_id] = self._guide.get_next_state( self._fsm_state[seq_id] = self._guide.get_next_state(
state=self._fsm_state[last_seq_id], token_id=last_token) state=self._fsm_state[last_seq_id], token_id=last_token)
else:
# Note: this is a hack.
# Lark pickling does not work properly (silent failure),
# which breaks the RPC (which uses Python pickling).
# We need to find a better solution.
# On the first time this is called, we simply re-create
# the Lark object.
if isinstance(self._guide, CFGGuide):
self._guide.parser = Lark(
self._guide.cfg_string,
parser="lalr",
lexer="contextual",
propagate_positions=False,
maybe_placeholders=False,
regex=True,
import_paths=[grammars.GRAMMAR_PATH],
)
instruction = self._guide.get_next_instruction( instruction = self._guide.get_next_instruction(
state=self._fsm_state[seq_id]) state=self._fsm_state[seq_id])
......
...@@ -159,6 +159,19 @@ class QuickGELU(CustomOp): ...@@ -159,6 +159,19 @@ class QuickGELU(CustomOp):
# def forward_xpu(self, x: torch.Tensor) -> torch.Tensor: # def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
class ReLUSquaredActivation(CustomOp):
    """
    Squared ReLU (relu^2) activation: ReLU followed by an elementwise square.

    Introduced in https://arxiv.org/abs/2109.08668v2
    """

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """Compute the activation with plain PyTorch ops."""
        rectified = F.relu(x)
        return torch.square(rectified)

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Compute the activation on CUDA by reusing the native path."""
        return self.forward_native(x)
class ScaledActivation(nn.Module): class ScaledActivation(nn.Module):
"""An activation function with post-scale parameters. """An activation function with post-scale parameters.
...@@ -207,6 +220,7 @@ _ACTIVATION_REGISTRY = { ...@@ -207,6 +220,7 @@ _ACTIVATION_REGISTRY = {
"gelu_new": NewGELU(), "gelu_new": NewGELU(),
"gelu_pytorch_tanh": nn.GELU(approximate="tanh"), "gelu_pytorch_tanh": nn.GELU(approximate="tanh"),
"relu": nn.ReLU(), "relu": nn.ReLU(),
"relu2": ReLUSquaredActivation(),
"quick_gelu": QuickGELU(), "quick_gelu": QuickGELU(),
} }
......
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_experts, fused_moe, fused_topk, get_config_file_name, grouped_topk)
from vllm.model_executor.layers.fused_moe.layer import (FusedMoE, from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
FusedMoEMethodBase) FusedMoEMethodBase)
from vllm.triton_utils import HAS_TRITON
__all__ = [ __all__ = [
"fused_moe",
"fused_topk",
"fused_experts",
"get_config_file_name",
"grouped_topk",
"FusedMoE", "FusedMoE",
"FusedMoEMethodBase", "FusedMoEMethodBase",
] ]
if HAS_TRITON:
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_experts, fused_moe, fused_topk, get_config_file_name,
grouped_topk)
__all__ += [
"fused_moe",
"fused_topk",
"fused_experts",
"get_config_file_name",
"grouped_topk",
]
...@@ -218,12 +218,16 @@ class ReplicatedLinear(LinearBase): ...@@ -218,12 +218,16 @@ class ReplicatedLinear(LinearBase):
self.input_size, self.input_size,
self.output_size, self.output_size,
self.params_dtype, self.params_dtype,
weight_loader=self.weight_loader,
prefix=prefix) prefix=prefix)
if bias: if bias:
self.bias = Parameter( self.bias = Parameter(
torch.empty(self.output_size, dtype=self.params_dtype)) torch.empty(self.output_size, dtype=self.params_dtype))
set_weight_attrs(self.bias, {"output_dim": 0}) set_weight_attrs(self.bias, {
"output_dim": 0,
"weight_loader": self.weight_loader,
})
else: else:
self.register_parameter("bias", None) self.register_parameter("bias", None)
......
...@@ -5,10 +5,12 @@ from typing import Optional ...@@ -5,10 +5,12 @@ from typing import Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
from vllm.distributed import tensor_model_parallel_gather from vllm.distributed import (tensor_model_parallel_all_gather,
tensor_model_parallel_gather)
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding) VocabParallelEmbedding)
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.platforms import current_platform
class LogitsProcessor(nn.Module): class LogitsProcessor(nn.Module):
...@@ -39,6 +41,8 @@ class LogitsProcessor(nn.Module): ...@@ -39,6 +41,8 @@ class LogitsProcessor(nn.Module):
self.org_vocab_size = org_vocab_size or vocab_size self.org_vocab_size = org_vocab_size or vocab_size
# Soft cap the logits. Used in Gemma 2. # Soft cap the logits. Used in Gemma 2.
self.soft_cap = soft_cap self.soft_cap = soft_cap
# Whether to use gather or all-gather to gather the logits.
self.use_gather = not current_platform.is_tpu()
def forward( def forward(
self, self,
...@@ -76,7 +80,15 @@ class LogitsProcessor(nn.Module): ...@@ -76,7 +80,15 @@ class LogitsProcessor(nn.Module):
logits = lm_head.linear_method.apply(lm_head, logits = lm_head.linear_method.apply(lm_head,
hidden_states, hidden_states,
bias=embedding_bias) bias=embedding_bias)
logits = tensor_model_parallel_gather(logits) if self.use_gather:
logits = tensor_model_parallel_gather(logits)
else:
# Gather is not supported for some devices such as TPUs.
# Use all-gather instead.
# NOTE(woosuk): Here, the outputs of every device should not be None
# because XLA requires strict SPMD among all devices. Every device
# should execute the same operations after gathering the logits.
logits = tensor_model_parallel_all_gather(logits)
# Remove paddings in vocab (if any). # Remove paddings in vocab (if any).
if logits is not None: if logits is not None:
logits = logits[:, :self.org_vocab_size] logits = logits[:, :self.org_vocab_size]
......
import math
from typing import Optional, Tuple from typing import Optional, Tuple
import torch import torch
...@@ -6,20 +5,9 @@ import triton ...@@ -6,20 +5,9 @@ import triton
import triton.language as tl import triton.language as tl
from vllm.model_executor.layers.ops.rand import seeded_uniform from vllm.model_executor.layers.ops.rand import seeded_uniform
from vllm.triton_utils.sample import get_num_triton_sampler_splits
_EPS = 1e-6 _EPS: tl.constexpr = 1e-6
# This is a hardcoded limit in Triton (max block size).
MAX_TRITON_N_COLS = 131072
def get_num_triton_sampler_splits(n_cols: int) -> int:
"""Get the number of splits to use for Triton sampling.
Triton has a limit on the number of columns it can handle, so we need to
split the tensor and call the kernel multiple times if it's too large.
"""
return math.ceil(n_cols / MAX_TRITON_N_COLS)
def _multi_split_sample( def _multi_split_sample(
......
...@@ -19,6 +19,7 @@ from vllm.model_executor.layers.quantization.gptq_marlin import ( ...@@ -19,6 +19,7 @@ from vllm.model_executor.layers.quantization.gptq_marlin import (
from vllm.model_executor.layers.quantization.gptq_marlin_24 import ( from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
GPTQMarlin24Config) GPTQMarlin24Config)
from vllm.model_executor.layers.quantization.marlin import MarlinConfig from vllm.model_executor.layers.quantization.marlin import MarlinConfig
from vllm.model_executor.layers.quantization.qqq import QQQConfig
from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig
QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = { QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
...@@ -37,6 +38,7 @@ QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = { ...@@ -37,6 +38,7 @@ QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
"squeezellm": SqueezeLLMConfig, "squeezellm": SqueezeLLMConfig,
"compressed-tensors": CompressedTensorsConfig, "compressed-tensors": CompressedTensorsConfig,
"bitsandbytes": BitsAndBytesConfig, "bitsandbytes": BitsAndBytesConfig,
"qqq": QQQConfig,
} }
......
...@@ -10,11 +10,11 @@ from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, ...@@ -10,11 +10,11 @@ from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.layers.quantization.utils.marlin_utils import ( from vllm.model_executor.layers.quantization.utils.marlin_utils import (
apply_awq_marlin_linear, awq_to_marlin_zero_points, apply_awq_marlin_linear, awq_to_marlin_zero_points, check_marlin_supported,
check_awq_marlin_supported, marlin_make_empty_g_idx, marlin_make_workspace, marlin_make_empty_g_idx, marlin_make_workspace, marlin_permute_scales,
marlin_permute_scales, replace_tensor, verify_awq_marlin_supported, replace_tensor, verify_marlin_supported, verify_marlin_supports_shape)
verify_marlin_supports_shape)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.scalar_type import scalar_types
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -22,20 +22,31 @@ logger = init_logger(__name__) ...@@ -22,20 +22,31 @@ logger = init_logger(__name__)
class AWQMarlinConfig(QuantizationConfig): class AWQMarlinConfig(QuantizationConfig):
"""Config class for AWQ Marlin""" """Config class for AWQ Marlin"""
# num_bits -> type
TYPE_MAP = {
4: scalar_types.uint4,
8: scalar_types.uint8,
}
def __init__(self, weight_bits: int, group_size: int, has_zp: bool, def __init__(self, weight_bits: int, group_size: int, has_zp: bool,
lm_head_quantized: bool) -> None: lm_head_quantized: bool) -> None:
self.weight_bits = weight_bits self.pack_factor = 32 // weight_bits # packed into int32
self.pack_factor = 32 // self.weight_bits # packed into int32
self.group_size = group_size self.group_size = group_size
self.has_zp = has_zp self.has_zp = has_zp
self.lm_head_quantized = lm_head_quantized self.lm_head_quantized = lm_head_quantized
verify_awq_marlin_supported(num_bits=self.weight_bits, if weight_bits not in self.TYPE_MAP:
group_size=self.group_size, raise ValueError(f"Unsupported num_bits = {weight_bits}. "
has_zp=self.has_zp) f"Supported num_bits = {self.TYPE_MAP.keys()}")
self.quant_type = self.TYPE_MAP[weight_bits]
verify_marlin_supported(self.quant_type,
group_size=self.group_size,
has_zp=self.has_zp)
def __repr__(self) -> str: def __repr__(self) -> str:
return (f"AWQMarlinConfig(weight_bits={self.weight_bits}, " return (f"AWQMarlinConfig(quant_type={self.quant_type}, "
f"group_size={self.group_size}, " f"group_size={self.group_size}, "
f"has_zp={self.has_zp}, " f"has_zp={self.has_zp}, "
f"lm_head_quantized={self.lm_head_quantized})") f"lm_head_quantized={self.lm_head_quantized})")
...@@ -69,7 +80,8 @@ class AWQMarlinConfig(QuantizationConfig): ...@@ -69,7 +80,8 @@ class AWQMarlinConfig(QuantizationConfig):
def override_quantization_method(cls, hf_quant_cfg, def override_quantization_method(cls, hf_quant_cfg,
user_quant) -> Optional[str]: user_quant) -> Optional[str]:
can_convert = cls.is_awq_marlin_compatible(hf_quant_cfg) can_convert = cls.is_awq_marlin_compatible(hf_quant_cfg)
is_valid_user_quant = (user_quant is None or user_quant == "marlin") is_valid_user_quant = (user_quant is None or user_quant == "marlin"
or user_quant == "awq_marlin")
if can_convert and is_valid_user_quant: if can_convert and is_valid_user_quant:
msg = ("The model is convertible to {} during runtime." msg = ("The model is convertible to {} during runtime."
...@@ -109,11 +121,13 @@ class AWQMarlinConfig(QuantizationConfig): ...@@ -109,11 +121,13 @@ class AWQMarlinConfig(QuantizationConfig):
if (num_bits is None or group_size is None or has_zp is None): if (num_bits is None or group_size is None or has_zp is None):
return False return False
return check_awq_marlin_supported( if num_bits not in cls.TYPE_MAP:
num_bits=num_bits, return False
group_size=group_size,
has_zp=has_zp, return check_marlin_supported(quant_type=cls.TYPE_MAP[num_bits],
min_capability=cls.get_min_capability()) group_size=group_size,
has_zp=has_zp,
min_capability=cls.get_min_capability())
class AWQMarlinLinearMethod(LinearMethodBase): class AWQMarlinLinearMethod(LinearMethodBase):
...@@ -225,7 +239,7 @@ class AWQMarlinLinearMethod(LinearMethodBase): ...@@ -225,7 +239,7 @@ class AWQMarlinLinearMethod(LinearMethodBase):
layer.qweight, layer.qweight,
size_k=layer.input_size_per_partition, size_k=layer.input_size_per_partition,
size_n=layer.output_size_per_partition, size_n=layer.output_size_per_partition,
num_bits=self.quant_config.weight_bits) num_bits=self.quant_config.quant_type.size_bits)
replace_tensor(layer, "qweight", marlin_qweight) replace_tensor(layer, "qweight", marlin_qweight)
# Permute scales from AWQ format to marlin format. # Permute scales from AWQ format to marlin format.
...@@ -241,7 +255,7 @@ class AWQMarlinLinearMethod(LinearMethodBase): ...@@ -241,7 +255,7 @@ class AWQMarlinLinearMethod(LinearMethodBase):
layer.qzeros, layer.qzeros,
size_k=layer.num_groups, size_k=layer.num_groups,
size_n=layer.output_size_per_partition, size_n=layer.output_size_per_partition,
num_bits=self.quant_config.weight_bits) num_bits=self.quant_config.quant_type.size_bits)
replace_tensor(layer, "qzeros", marlin_zp) replace_tensor(layer, "qzeros", marlin_zp)
# Not-used # Not-used
...@@ -262,7 +276,7 @@ class AWQMarlinLinearMethod(LinearMethodBase): ...@@ -262,7 +276,7 @@ class AWQMarlinLinearMethod(LinearMethodBase):
g_idx=layer.g_idx, g_idx=layer.g_idx,
g_idx_sort_indices=layer.g_idx_sort_indices, g_idx_sort_indices=layer.g_idx_sort_indices,
workspace=layer.workspace, workspace=layer.workspace,
num_bits=self.quant_config.weight_bits, quant_type=self.quant_config.quant_type,
output_size_per_partition=layer.output_size_per_partition, output_size_per_partition=layer.output_size_per_partition,
input_size_per_partition=layer.input_size_per_partition, input_size_per_partition=layer.input_size_per_partition,
bias=bias) bias=bias)
...@@ -15,19 +15,11 @@ class BitsAndBytesConfig(QuantizationConfig): ...@@ -15,19 +15,11 @@ class BitsAndBytesConfig(QuantizationConfig):
Reference: https://arxiv.org/abs/2305.14314 Reference: https://arxiv.org/abs/2305.14314
""" """
def __init__( def __init__(self, ) -> None:
self, pass
adapter_name_or_path: str,
target_modules: List[str],
) -> None:
self.adapter_name_or_path = adapter_name_or_path
self.target_modules = target_modules
def __repr__(self) -> str: def __repr__(self) -> str:
return ( return "BitsAndBytesConfig"
f"BitsAndBytesConfig(adapter_name_or_path={self.adapter_name_or_path}"
)
@classmethod @classmethod
def get_name(self) -> str: def get_name(self) -> str:
...@@ -49,16 +41,7 @@ class BitsAndBytesConfig(QuantizationConfig): ...@@ -49,16 +41,7 @@ class BitsAndBytesConfig(QuantizationConfig):
@classmethod @classmethod
def from_config(cls, config: Dict[str, Any]) -> "BitsAndBytesConfig": def from_config(cls, config: Dict[str, Any]) -> "BitsAndBytesConfig":
adapter_name = cls.get_from_keys(config, ["adapter_name_or_path"]) return cls()
default_target_modules = [
"gate_proj", "down_proj", "up_proj", "q_proj", "k_proj", "v_proj",
"o_proj"
]
if adapter_name == "":
target_modules = default_target_modules
else:
target_modules = cls.get_from_keys(config, ["target_modules"])
return cls(adapter_name, target_modules)
def get_quant_method(self, layer: torch.nn.Module, def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["BitsAndBytesLinearMethod"]: prefix: str) -> Optional["BitsAndBytesLinearMethod"]:
......
...@@ -10,7 +10,8 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( ...@@ -10,7 +10,8 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
W4A16SPARSE24_SUPPORTED_BITS, WNA16_SUPPORTED_BITS, W4A16SPARSE24_SUPPORTED_BITS, WNA16_SUPPORTED_BITS,
CompressedTensorsScheme, CompressedTensorsUnquantized, CompressedTensorsScheme, CompressedTensorsUnquantized,
CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8Fp8, CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8Fp8,
CompressedTensorsW8A8Int8, CompressedTensorsWNA16) CompressedTensorsW8A8Int8, CompressedTensorsW8A16Fp8,
CompressedTensorsWNA16)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
CompressionFormat, QuantizationArgs, QuantizationStrategy, CompressionFormat, QuantizationArgs, QuantizationStrategy,
QuantizationType, find_matched_target, is_activation_quantization_format, QuantizationType, find_matched_target, is_activation_quantization_format,
...@@ -100,14 +101,18 @@ class CompressedTensorsConfig(QuantizationConfig): ...@@ -100,14 +101,18 @@ class CompressedTensorsConfig(QuantizationConfig):
def get_config_filenames(cls) -> List[str]: def get_config_filenames(cls) -> List[str]:
return [] return []
def _check_scheme_supported(self, min_capability: int): def _check_scheme_supported(self,
min_capability: int,
error: bool = True) -> bool:
capability = current_platform.get_device_capability() capability = current_platform.get_device_capability()
capability = capability[0] * 10 + capability[1] capability = capability[0] * 10 + capability[1]
if capability < min_capability: supported = capability >= min_capability
if error and not supported:
raise RuntimeError( raise RuntimeError(
"Quantization scheme is not supported for ", "Quantization scheme is not supported for ",
f"the current GPU. Min capability: {min_capability}. ", f"the current GPU. Min capability: {min_capability}. ",
f"Current capability: {capability}.") f"Current capability: {capability}.")
return supported
def _is_static_tensor_w8a8(self, weight_quant: BaseModel, def _is_static_tensor_w8a8(self, weight_quant: BaseModel,
input_quant: BaseModel) -> bool: input_quant: BaseModel) -> bool:
...@@ -170,6 +175,29 @@ class CompressedTensorsConfig(QuantizationConfig): ...@@ -170,6 +175,29 @@ class CompressedTensorsConfig(QuantizationConfig):
# All conditions satisfied. # All conditions satisfied.
return True return True
def _is_fp8_w8a16(self, weight_quant: BaseModel,
                  input_quant: BaseModel) -> bool:
    """
    Return True if ``weight_quant`` describes an fp8 weight-only scheme.

    The weights must be quantized, of floating point type, symmetric,
    static (not dynamic), and use a per-tensor or per-channel strategy.
    ``input_quant`` is accepted for signature parity but is not inspected.
    """
    # Unquantized weights can never match.
    if weight_quant is None:
        return False

    # Only floating point weight quantization qualifies.
    if weight_quant.type != QuantizationType.FLOAT:
        return False

    # Gather the three weight-scheme requirements before combining them.
    symmetric = weight_quant.symmetric
    static = not weight_quant.dynamic
    tensor_or_channel = weight_quant.strategy in (
        QuantizationStrategy.TENSOR,
        QuantizationStrategy.CHANNEL,
    )
    if not symmetric or not static or not tensor_or_channel:
        return False

    # All conditions satisfied.
    return True
def _is_wNa16_group_channel(self, weight_quant: BaseModel, def _is_wNa16_group_channel(self, weight_quant: BaseModel,
input_quant: BaseModel) -> bool: input_quant: BaseModel) -> bool:
input_quant_none = input_quant is None input_quant_none = input_quant is None
...@@ -204,9 +232,23 @@ class CompressedTensorsConfig(QuantizationConfig): ...@@ -204,9 +232,23 @@ class CompressedTensorsConfig(QuantizationConfig):
# Detect If Activation Quantization. # Detect If Activation Quantization.
if is_activation_quantization_format(self.quant_format): if is_activation_quantization_format(self.quant_format):
if self._is_fp8_w8a8(weight_quant, input_quant): if self._is_fp8_w8a8(weight_quant, input_quant):
return CompressedTensorsW8A8Fp8( is_fp8_w8a8_supported = self._check_scheme_supported(
CompressedTensorsW8A8Fp8.get_min_capability(), error=False)
if is_fp8_w8a8_supported:
return CompressedTensorsW8A8Fp8(
strategy=weight_quant.strategy,
is_static_input_scheme=(not input_quant.dynamic))
else:
return CompressedTensorsW8A16Fp8(
strategy=weight_quant.strategy,
is_static_input_scheme=(input_quant
and not input_quant.dynamic))
if self._is_fp8_w8a16(weight_quant, input_quant):
return CompressedTensorsW8A16Fp8(
strategy=weight_quant.strategy, strategy=weight_quant.strategy,
is_static_input_scheme=(not input_quant.dynamic)) is_static_input_scheme=(input_quant
and not input_quant.dynamic))
if self._is_static_tensor_w8a8(weight_quant, input_quant): if self._is_static_tensor_w8a8(weight_quant, input_quant):
return CompressedTensorsW8A8Int8( return CompressedTensorsW8A8Int8(
...@@ -257,11 +299,10 @@ class CompressedTensorsConfig(QuantizationConfig): ...@@ -257,11 +299,10 @@ class CompressedTensorsConfig(QuantizationConfig):
targets=self.target_scheme_map.keys()) targets=self.target_scheme_map.keys())
# Find the quant_scheme # Find the quant_scheme
scheme = self.target_scheme_map[matched_target] scheme_dict = self.target_scheme_map[matched_target]
scheme = self._get_scheme_from_parts(
return self._get_scheme_from_parts( weight_quant=scheme_dict["weights"],
weight_quant=scheme["weights"], input_quant=scheme_dict["input_activations"])
input_quant=scheme["input_activations"])
# Raise error if device does not support the scheme # Raise error if device does not support the scheme
# (e.g. fp8 needs ada lovelace) # (e.g. fp8 needs ada lovelace)
......
...@@ -4,6 +4,7 @@ from .compressed_tensors_w4a16_24 import (W4A16SPARSE24_SUPPORTED_BITS, ...@@ -4,6 +4,7 @@ from .compressed_tensors_w4a16_24 import (W4A16SPARSE24_SUPPORTED_BITS,
CompressedTensorsW4A16Sparse24) CompressedTensorsW4A16Sparse24)
from .compressed_tensors_w8a8_fp8 import CompressedTensorsW8A8Fp8 from .compressed_tensors_w8a8_fp8 import CompressedTensorsW8A8Fp8
from .compressed_tensors_w8a8_int8 import CompressedTensorsW8A8Int8 from .compressed_tensors_w8a8_int8 import CompressedTensorsW8A8Int8
from .compressed_tensors_w8a16_fp8 import CompressedTensorsW8A16Fp8
from .compressed_tensors_wNa16 import (WNA16_SUPPORTED_BITS, from .compressed_tensors_wNa16 import (WNA16_SUPPORTED_BITS,
CompressedTensorsWNA16) CompressedTensorsWNA16)
...@@ -11,6 +12,7 @@ __all__ = [ ...@@ -11,6 +12,7 @@ __all__ = [
"CompressedTensorsScheme", "CompressedTensorsScheme",
"CompressedTensorsUnquantized", "CompressedTensorsUnquantized",
"CompressedTensorsWNA16", "CompressedTensorsWNA16",
"CompressedTensorsW8A16Fp8",
"CompressedTensorsW4A16Sparse24", "CompressedTensorsW4A16Sparse24",
"CompressedTensorsW8A8Int8", "CompressedTensorsW8A8Int8",
"CompressedTensorsW8A8Fp8", "CompressedTensorsW8A8Fp8",
......
...@@ -12,8 +12,9 @@ class CompressedTensorsScheme(ABC): ...@@ -12,8 +12,9 @@ class CompressedTensorsScheme(ABC):
of different quantization schemes supported by CompressedTensors. of different quantization schemes supported by CompressedTensors.
""" """
@classmethod
@abstractmethod @abstractmethod
def get_min_capability(self) -> int: def get_min_capability(cls) -> int:
""" """
Get minimum device capability. Get minimum device capability.
""" """
......
...@@ -18,7 +18,8 @@ class CompressedTensorsUnquantized(CompressedTensorsScheme): ...@@ -18,7 +18,8 @@ class CompressedTensorsUnquantized(CompressedTensorsScheme):
in a linear transformation. in a linear transformation.
""" """
def get_min_capability(self) -> int: @classmethod
def get_min_capability(cls) -> int:
# volta and up # volta and up
return 70 return 70
......
...@@ -9,9 +9,13 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( ...@@ -9,9 +9,13 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
from vllm.model_executor.layers.quantization.gptq_marlin_24 import ( from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N) GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N)
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.scalar_type import scalar_types
__all__ = ["CompressedTensorsW4A16Sparse24"] __all__ = ["CompressedTensorsW4A16Sparse24"]
W4A16SPARSE24_SUPPORTED_BITS = [4] W4A16SPARSE24_SUPPORTED_TYPES_MAP = {
4: scalar_types.uint4b8,
}
W4A16SPARSE24_SUPPORTED_BITS = list(W4A16SPARSE24_SUPPORTED_TYPES_MAP.keys())
class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme): class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme):
...@@ -22,14 +26,21 @@ class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme): ...@@ -22,14 +26,21 @@ class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme):
group_size: Optional[int] = None): group_size: Optional[int] = None):
self.strategy = strategy self.strategy = strategy
self.group_size = group_size self.group_size = group_size
self.num_bits = num_bits
self.tile_size = 16 self.tile_size = 16
if num_bits not in W4A16SPARSE24_SUPPORTED_TYPES_MAP:
raise ValueError(
f"Unsupported num_bits = {num_bits}. "
f"Supported num_bits = {W4A16SPARSE24_SUPPORTED_BITS}")
self.quant_type = W4A16SPARSE24_SUPPORTED_TYPES_MAP[num_bits]
if self.strategy == "group" and self.group_size is None: if self.strategy == "group" and self.group_size is None:
raise ValueError( raise ValueError(
"group_size must be given when using strategy group") "group_size must be given when using strategy group")
def get_min_capability(self) -> int: @classmethod
def get_min_capability(cls) -> int:
# ampere + up # ampere + up
return 80 return 80
...@@ -42,7 +53,7 @@ class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme): ...@@ -42,7 +53,7 @@ class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme):
params_dtype: torch.dtype, weight_loader: Callable, params_dtype: torch.dtype, weight_loader: Callable,
**kwargs): **kwargs):
pack_factor = 32 // self.num_bits pack_factor = 32 // self.quant_type.size_bits
output_size_per_partition = sum(output_partition_sizes) output_size_per_partition = sum(output_partition_sizes)
qweight = Parameter( qweight = Parameter(
...@@ -137,7 +148,7 @@ class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme): ...@@ -137,7 +148,7 @@ class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme):
size_n = scales.shape[1] size_n = scales.shape[1]
output_2d = ops.gptq_marlin_24_gemm(x_2d, qweight, meta, scales, output_2d = ops.gptq_marlin_24_gemm(x_2d, qweight, meta, scales,
workspace, self.num_bits, size_m, workspace, self.quant_type, size_m,
size_n, size_k) size_n, size_k)
output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], )) output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], ))
......
from typing import Callable, List, Optional
import torch
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
QuantizationStrategy)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
convert_to_channelwise, create_per_channel_scale_param,
create_per_tensor_scale_param)
from vllm.model_executor.utils import set_weight_attrs
__all__ = ["CompressedTensorsW8A16Fp8"]
# Weight-scale strategies this module knows how to build parameters for.
# CompressedTensorsW8A16Fp8.create_weights raises ValueError (and lists these
# values in the message) when configured with any other strategy.
SUPPORTED_STRATEGIES = [
    QuantizationStrategy.CHANNEL, QuantizationStrategy.TENSOR
]
class CompressedTensorsW8A16Fp8(CompressedTensorsScheme):
    """Compressed-tensors scheme that stores weights as float8_e4m3fn and
    runs the linear layer through the FP8 Marlin helpers.

    The weight scale is either per-channel or per-tensor, selected by
    ``strategy``; any other strategy is rejected in ``create_weights``.
    """

    def __init__(self, strategy: str, is_static_input_scheme: bool):
        self.is_static_input_scheme = is_static_input_scheme
        self.strategy = strategy

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum capability value for this scheme
        (80: ampere and up)."""
        min_capability = 80
        return min_capability

    def process_weights_after_loading(self, layer) -> None:
        """Bring the loaded parameters into the layout handed to marlin.

        With the per-tensor strategy, the scales of a fused module
        (QKV, MLP) are first expanded to channelwise form using the
        layer's logical widths, because the kernels only handle
        per-tensor and per-channel cases. The weight is then transposed
        and the layer is passed to ``prepare_fp8_layer_for_marlin``.
        """
        uses_per_tensor_scale = self.strategy == QuantizationStrategy.TENSOR
        if uses_per_tensor_scale:
            expanded_scale = convert_to_channelwise(layer.weight_scale,
                                                    layer.logical_widths)
            layer.weight_scale = torch.nn.Parameter(expanded_scale,
                                                    requires_grad=False)

        # Marlin needs the weight transposed.
        transposed_weight = layer.weight.t()
        layer.weight = torch.nn.Parameter(transposed_weight,
                                          requires_grad=False)

        # The scales are channelwise by this point (see above), hence the
        # fixed "channel" strategy.
        prepare_fp8_layer_for_marlin(layer, strategy="channel")

    def create_weights(self, layer: torch.nn.Module, input_size: int,
                       output_partition_sizes: List[int],
                       input_size_per_partition: int,
                       params_dtype: torch.dtype, weight_loader: Callable,
                       **kwargs):
        """Register the quantized weight and its scale(s) on ``layer``.

        Registers, in order: ``weight`` (float8_e4m3fn, shaped
        output-per-partition by input-per-partition), ``weight_scale``
        (per-channel or per-tensor according to ``self.strategy``) and,
        when the input scheme is static, ``input_scale``. Partition sizes
        and ``params_dtype`` are also recorded as plain attributes on the
        layer.

        Raises:
            ValueError: if ``self.strategy`` is neither CHANNEL nor TENSOR.
        """
        total_output_size = sum(output_partition_sizes)

        # Bookkeeping attributes read again after loading and at apply time.
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = total_output_size
        layer.orig_dtype = params_dtype

        # WEIGHT
        weight_storage = torch.empty(total_output_size,
                                     input_size_per_partition,
                                     dtype=torch.float8_e4m3fn)
        weight = torch.nn.Parameter(weight_storage, requires_grad=False)
        layer.register_parameter("weight", weight)
        weight_attrs = {
            "input_dim": 1,
            "output_dim": 0,
            "weight_loader": weight_loader,
        }
        set_weight_attrs(weight, weight_attrs)

        # WEIGHT SCALE
        layer_kwargs = {"weight_loader": weight_loader}
        if self.strategy == QuantizationStrategy.CHANNEL:
            build_weight_scale = create_per_channel_scale_param
        elif self.strategy == QuantizationStrategy.TENSOR:
            build_weight_scale = create_per_tensor_scale_param
        else:
            raise ValueError(
                f"Unsupported weight strategy={self.strategy}, "
                f"supported strategies are {SUPPORTED_STRATEGIES}")
        weight_scale = build_weight_scale(output_partition_sizes,
                                          **layer_kwargs)
        layer.register_parameter("weight_scale", weight_scale)

        # INPUT SCALE (to deal with converted checkpoints)
        if not self.is_static_input_scheme:
            return
        input_scale = create_per_tensor_scale_param(output_partition_sizes,
                                                    **layer_kwargs)
        layer.register_parameter("input_scale", input_scale)

    def apply_weights(self,
                      layer: torch.nn.Module,
                      x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run ``x`` through the FP8 marlin linear using the layer's
        weight, weight scale, workspace and recorded partition sizes.

        NOTE(review): ``layer.workspace`` is not created in this class;
        presumably ``prepare_fp8_layer_for_marlin`` attaches it - confirm.
        """
        marlin_args = dict(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            workspace=layer.workspace,
            size_n=layer.output_size_per_partition,
            size_k=layer.input_size_per_partition,
            bias=bias,
        )
        return apply_fp8_marlin_linear(**marlin_args)
...@@ -23,7 +23,8 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): ...@@ -23,7 +23,8 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
self.is_static_input_scheme = is_static_input_scheme self.is_static_input_scheme = is_static_input_scheme
self.cutlass_fp8_supported = cutlass_fp8_supported() self.cutlass_fp8_supported = cutlass_fp8_supported()
def get_min_capability(self) -> int: @classmethod
def get_min_capability(cls) -> int:
# lovelace and up # lovelace and up
return 89 return 89
...@@ -77,19 +78,20 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): ...@@ -77,19 +78,20 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
}) })
# WEIGHT SCALE # WEIGHT SCALE
layer_kwargs = {"weight_loader": weight_loader}
if self.strategy == QuantizationStrategy.CHANNEL: if self.strategy == QuantizationStrategy.CHANNEL:
weight_scale = create_per_channel_scale_param( weight_scale = create_per_channel_scale_param(
output_partition_sizes, weight_loader=weight_loader) output_partition_sizes, **layer_kwargs)
else: else:
assert self.strategy == QuantizationStrategy.TENSOR assert self.strategy == QuantizationStrategy.TENSOR
weight_scale = create_per_tensor_scale_param( weight_scale = create_per_tensor_scale_param(
output_partition_sizes, weight_loader=weight_loader) output_partition_sizes, **layer_kwargs)
layer.register_parameter("weight_scale", weight_scale) layer.register_parameter("weight_scale", weight_scale)
# INPUT SCALE # INPUT SCALE
if self.is_static_input_scheme: if self.is_static_input_scheme:
input_scale = create_per_tensor_scale_param( input_scale = create_per_tensor_scale_param(
output_partition_sizes, weight_loader=weight_loader) output_partition_sizes, **layer_kwargs)
layer.register_parameter("input_scale", input_scale) layer.register_parameter("input_scale", input_scale)
def apply_weights(self, def apply_weights(self,
......
...@@ -19,7 +19,8 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme): ...@@ -19,7 +19,8 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
self.strategy = strategy self.strategy = strategy
self.is_static_input_scheme = is_static_input_scheme self.is_static_input_scheme = is_static_input_scheme
def get_min_capability(self) -> int: @classmethod
def get_min_capability(cls) -> int:
# turing and up # turing and up
return 75 return 75
...@@ -68,19 +69,19 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme): ...@@ -68,19 +69,19 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
# WEIGHT SCALE # WEIGHT SCALE
layer_kwargs = {"weight_loader": weight_loader} layer_kwargs = {"weight_loader": weight_loader}
if self.strategy == QuantizationStrategy.CHANNEL: if self.strategy == QuantizationStrategy.CHANNEL:
scale = create_per_channel_scale_param(output_partition_sizes, weight_scale = create_per_channel_scale_param(
**layer_kwargs) output_partition_sizes, **layer_kwargs)
else: else:
assert self.strategy == QuantizationStrategy.TENSOR assert self.strategy == QuantizationStrategy.TENSOR
scale = create_per_tensor_scale_param(output_partition_sizes, weight_scale = create_per_tensor_scale_param(
**layer_kwargs) output_partition_sizes, **layer_kwargs)
layer.register_parameter("weight_scale", scale) layer.register_parameter("weight_scale", weight_scale)
# INPUT SCALE # INPUT SCALE
if self.is_static_input_scheme: if self.is_static_input_scheme:
scale = create_per_tensor_scale_param(output_partition_sizes, input_scale = create_per_tensor_scale_param(
**layer_kwargs) output_partition_sizes, **layer_kwargs)
layer.register_parameter("input_scale", scale) layer.register_parameter("input_scale", input_scale)
def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor, def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor: bias: Optional[torch.Tensor]) -> torch.Tensor:
......
...@@ -8,12 +8,17 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( ...@@ -8,12 +8,17 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme) CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.utils.marlin_utils import ( from vllm.model_executor.layers.quantization.utils.marlin_utils import (
apply_gptq_marlin_linear, marlin_make_empty_g_idx, marlin_make_workspace, apply_gptq_marlin_linear, marlin_make_empty_g_idx, marlin_make_workspace,
marlin_permute_scales, replace_tensor, verify_gptq_marlin_supported, marlin_permute_scales, replace_tensor, verify_marlin_supported,
verify_marlin_supports_shape) verify_marlin_supports_shape)
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.scalar_type import scalar_types
__all__ = ["CompressedTensorsWNA16"] __all__ = ["CompressedTensorsWNA16"]
WNA16_SUPPORTED_BITS = [4, 8] WNA16_SUPPORTED_TYPES_MAP = {
4: scalar_types.uint4b8,
8: scalar_types.uint8b128,
}
WNA16_SUPPORTED_BITS = list(WNA16_SUPPORTED_TYPES_MAP.keys())
class CompressedTensorsWNA16(CompressedTensorsScheme): class CompressedTensorsWNA16(CompressedTensorsScheme):
...@@ -22,8 +27,8 @@ class CompressedTensorsWNA16(CompressedTensorsScheme): ...@@ -22,8 +27,8 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
strategy: str, strategy: str,
num_bits: int, num_bits: int,
group_size: Optional[int] = None): group_size: Optional[int] = None):
self.num_bits = num_bits
self.pack_factor = 32 // self.num_bits self.pack_factor = 32 // num_bits
self.strategy = strategy self.strategy = strategy
self.group_size: int self.group_size: int
...@@ -37,12 +42,19 @@ class CompressedTensorsWNA16(CompressedTensorsScheme): ...@@ -37,12 +42,19 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
else: else:
self.group_size = group_size self.group_size = group_size
if num_bits not in WNA16_SUPPORTED_TYPES_MAP:
raise ValueError(
f"Unsupported num_bits = {num_bits}. "
f"Supported num_bits = {WNA16_SUPPORTED_TYPES_MAP.keys()}")
self.quant_type = WNA16_SUPPORTED_TYPES_MAP[num_bits]
# Verify supported on platform. # Verify supported on platform.
verify_gptq_marlin_supported(num_bits=self.num_bits, verify_marlin_supported(quant_type=self.quant_type,
group_size=self.group_size, group_size=self.group_size)
is_sym=True)
def get_min_capability(self) -> int: @classmethod
def get_min_capability(cls) -> int:
# ampere and up # ampere and up
return 80 return 80
...@@ -54,7 +66,12 @@ class CompressedTensorsWNA16(CompressedTensorsScheme): ...@@ -54,7 +66,12 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
output_size_per_partition = sum(output_partition_sizes) output_size_per_partition = sum(output_partition_sizes)
# If group_size is -1, we are in channelwise case. # If group_size is -1, we are in channelwise case.
group_size = input_size if self.group_size == -1 else self.group_size channelwise = (self.group_size == -1)
group_size = input_size if channelwise else self.group_size
row_parallel = (input_size != input_size_per_partition)
# In the case of channelwise quantization, we need to replicate the
# scales across all gpus.
partition_scales = (row_parallel and not channelwise)
verify_marlin_supports_shape( verify_marlin_supports_shape(
output_size_per_partition=output_size_per_partition, output_size_per_partition=output_size_per_partition,
...@@ -65,8 +82,8 @@ class CompressedTensorsWNA16(CompressedTensorsScheme): ...@@ -65,8 +82,8 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
weight_scale_dim = None weight_scale_dim = None
scales_and_zp_size = input_size // group_size scales_and_zp_size = input_size // group_size
if (input_size != input_size_per_partition if partition_scales:
and self.group_size is not None): assert input_size_per_partition % group_size == 0
weight_scale_dim = 1 weight_scale_dim = 1
scales_and_zp_size = input_size_per_partition // group_size scales_and_zp_size = input_size_per_partition // group_size
...@@ -144,7 +161,7 @@ class CompressedTensorsWNA16(CompressedTensorsScheme): ...@@ -144,7 +161,7 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
perm=layer.g_idx_sort_indices, perm=layer.g_idx_sort_indices,
size_k=layer.input_size_per_partition, size_k=layer.input_size_per_partition,
size_n=layer.output_size_per_partition, size_n=layer.output_size_per_partition,
num_bits=self.num_bits) num_bits=self.quant_type.size_bits)
replace_tensor(layer, "weight_packed", marlin_qweight) replace_tensor(layer, "weight_packed", marlin_qweight)
# Permute scales from compressed-tensors format to marlin format. # Permute scales from compressed-tensors format to marlin format.
...@@ -166,7 +183,7 @@ class CompressedTensorsWNA16(CompressedTensorsScheme): ...@@ -166,7 +183,7 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
g_idx=layer.g_idx, g_idx=layer.g_idx,
g_idx_sort_indices=layer.g_idx_sort_indices, g_idx_sort_indices=layer.g_idx_sort_indices,
workspace=layer.workspace, workspace=layer.workspace,
num_bits=self.num_bits, wtype=self.quant_type,
output_size_per_partition=layer.output_size_per_partition, output_size_per_partition=layer.output_size_per_partition,
input_size_per_partition=layer.input_size_per_partition, input_size_per_partition=layer.input_size_per_partition,
is_k_full=True, is_k_full=True,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment