"vscode:/vscode.git/clone" did not exist on "078df46bc9a99178a9a744b872899990353769a4"
Unverified commit 9b81f9bd, authored by Xiaoyu Zhang, committed by GitHub

sglang quant module remove vllm dependency (#4507)

parent f81a27f6
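The commit replaces hard vllm imports with a guarded-import pattern: probe the optional dependency once, record a flag, and branch on that flag everywhere else. A minimal standalone sketch of the pattern (the helper name is illustrative, not part of the diff):

try:
    import vllm  # optional dependency

    VLLM_AVAILABLE = True
except ImportError:
    VLLM_AVAILABLE = False


def require_vllm():
    # Hypothetical helper; the diff raises the same error inline where needed.
    if not VLLM_AVAILABLE:
        raise ImportError("vllm is not installed")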
@@ -6,21 +6,41 @@ from copy import deepcopy
from typing import Callable, Dict, Optional, Type, Union

import torch

try:
    from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
    from vllm.model_executor.layers.quantization.awq import AWQConfig
    from vllm.model_executor.layers.quantization.awq_marlin import AWQMarlinConfig
    from vllm.model_executor.layers.quantization.bitsandbytes import BitsAndBytesConfig
    from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (
        CompressedTensorsConfig,
    )
    from vllm.model_executor.layers.quantization.deepspeedfp import DeepSpeedFPConfig
    from vllm.model_executor.layers.quantization.experts_int8 import ExpertsInt8Config
    from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config
    from vllm.model_executor.layers.quantization.gguf import GGUFConfig
    from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
        GPTQMarlin24Config,
    )
    from vllm.model_executor.layers.quantization.marlin import MarlinConfig
    from vllm.model_executor.layers.quantization.qqq import QQQConfig
    from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig

    VLLM_AVAILABLE = True
except ImportError:
    VLLM_AVAILABLE = False

    # Define empty classes as placeholders when vllm is not available
    class DummyConfig:
        pass

    AQLMConfig = AWQConfig = AWQMarlinConfig = BitsAndBytesConfig = (
        CompressedTensorsConfig
    ) = DummyConfig
    DeepSpeedFPConfig = ExpertsInt8Config = FBGEMMFp8Config = GGUFConfig = (
        GPTQMarlin24Config
    ) = DummyConfig
    MarlinConfig = QQQConfig = Int8TpuConfig = DummyConfig

from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.quantization.blockwise_int8 import BlockInt8Config
@@ -30,29 +50,37 @@ from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp8Config
from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config
from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config

# Base quantization methods that don't depend on vllm
BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
    "fp8": Fp8Config,
    "blockwise_int8": BlockInt8Config,
    "modelopt": ModelOptFp8Config,
    "gptq_marlin": GPTQMarlinConfig,
    "gptq": GPTQConfig,
    "w8a8_int8": W8A8Int8Config,
    "w8a8_fp8": W8A8Fp8Config,
}

# Add vllm-dependent methods if available
QUANTIZATION_METHODS = BASE_QUANTIZATION_METHODS.copy()
if VLLM_AVAILABLE:
    VLLM_QUANTIZATION_METHODS = {
        "aqlm": AQLMConfig,
        "awq": AWQConfig,
        "deepspeedfp": DeepSpeedFPConfig,
        "tpu_int8": Int8TpuConfig,
        "fbgemm_fp8": FBGEMMFp8Config,
        "marlin": MarlinConfig,
        "gguf": GGUFConfig,
        "gptq_marlin_24": GPTQMarlin24Config,
        "awq_marlin": AWQMarlinConfig,
        "compressed-tensors": CompressedTensorsConfig,
        "bitsandbytes": BitsAndBytesConfig,
        "qqq": QQQConfig,
        "experts_int8": ExpertsInt8Config,
    }
    QUANTIZATION_METHODS.update(VLLM_QUANTIZATION_METHODS)


def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
    if quantization not in QUANTIZATION_METHODS:
@@ -157,25 +185,31 @@ def get_linear_quant_method(
def gptq_get_quant_method(self, layer, prefix):
    if not VLLM_AVAILABLE:
        return None

    try:
        from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
        from vllm.model_executor.layers.quantization.gptq_marlin import (
            GPTQMarlinLinearMethod,
            GPTQMarlinMoEMethod,
        )

        from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE

        if isinstance(layer, FusedMoE):
            return GPTQMarlinMoEMethod(self)

        if isinstance(self, GPTQConfig):
            return get_linear_quant_method(
                self, layer, prefix=prefix, linear_method_cls=GPTQLinearMethod
            )
        elif isinstance(self, GPTQMarlinConfig):
            return get_linear_quant_method(
                self, layer, prefix=prefix, linear_method_cls=GPTQMarlinLinearMethod
            )
    except ImportError:
        pass

    return None
@@ -187,33 +221,40 @@ def monkey_patch_isinstance_for_vllm_base_layer(reverse: bool = False):
    Patch isinstance so that the `get_quant_method` in vllm's QuantizationConfig
    can recognize sglang layers
    """
    if not VLLM_AVAILABLE:
        return

    if reverse:
        builtins.isinstance = original_isinstance
        return

    try:
        from vllm.model_executor.layers.fused_moe import FusedMoE
        from vllm.model_executor.layers.linear import LinearBase
        from vllm.model_executor.layers.vocab_parallel_embedding import (
            VocabParallelEmbedding,
        )

        from sglang.srt.layers.linear import LinearBase as PatchedLinearBase
        from sglang.srt.layers.moe.fused_moe_triton.layer import (
            FusedMoE as PatchedFusedMoE,
        )
        from sglang.srt.layers.vocab_parallel_embedding import (
            VocabParallelEmbedding as PatchedVocabParallelEmbedding,
        )

        def patched_isinstance(obj, classinfo):
            if classinfo is LinearBase:
                return original_isinstance(obj, PatchedLinearBase)
            if classinfo is FusedMoE:
                return original_isinstance(obj, PatchedFusedMoE)
            if classinfo is VocabParallelEmbedding:
                return original_isinstance(obj, PatchedVocabParallelEmbedding)
            return original_isinstance(obj, classinfo)

        builtins.isinstance = patched_isinstance
    except ImportError:
        return
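The isinstance patch above can be exercised in isolation; a toy sketch with stand-in classes (VendorBase and LocalBase are placeholders for vllm's and sglang's LinearBase, not names from the diff) showing how checks against the vendor base class get redirected:

import builtins

original_isinstance = builtins.isinstance


class VendorBase:  # stands in for vllm's LinearBase
    pass


class LocalBase:  # stands in for sglang's LinearBase
    pass


def patched_isinstance(obj, classinfo):
    if classinfo is VendorBase:
        return original_isinstance(obj, LocalBase)
    return original_isinstance(obj, classinfo)


builtins.isinstance = patched_isinstance
try:
    assert isinstance(LocalBase(), VendorBase)  # local layer passes the vendor check
finally:
    builtins.isinstance = original_isinstance  # mirrors the reverse=True path above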
@@ -221,72 +262,88 @@ def monkey_patch_moe_apply(class_obj: "FusedMoEMethodBase"):
    Monkey patch the apply function of vllm's FusedMoEMethodBase.
    Convert sglang arguments to vllm arguments.
    """
    if not VLLM_AVAILABLE:
        return

    try:
        original_apply = class_obj.apply
        sig = inspect.signature(original_apply)
        param_names = list(sig.parameters.keys())
        has_correction_bias = "e_score_correction_bias" in param_names

        def new_apply(
            self,
            layer: torch.nn.Module,
            x: torch.Tensor,
            router_logits: torch.Tensor,
            top_k: int,
            renormalize: bool,
            use_grouped_topk: bool,
            topk_group: Optional[int] = None,
            num_expert_group: Optional[int] = None,
            custom_routing_function: Optional[Callable] = None,
            correction_bias: Optional[torch.Tensor] = None,
            activation: str = "silu",
            inplace: bool = True,
            no_combine: bool = False,
        ):
            assert activation == "silu"
            assert inplace and not no_combine

            kwargs = {
                "self": self,
                "layer": layer,
                "x": x,
                "router_logits": router_logits,
                "top_k": top_k,
                "renormalize": renormalize,
                "use_grouped_topk": use_grouped_topk,
                "topk_group": topk_group,
                "num_expert_group": num_expert_group,
                "custom_routing_function": custom_routing_function,
            }
            if correction_bias is not None:
                if not has_correction_bias:
                    raise ValueError(
                        "Please increase the version of your vllm. Try `pip install vllm==0.7.2`"
                    )
                kwargs["e_score_correction_bias"] = correction_bias
            return original_apply(**kwargs)

        setattr(class_obj, "apply", new_apply)
    except (ImportError, AttributeError):
        return
def monkey_patch_quant_configs():
    """Apply all monkey patches in one place."""
    if not VLLM_AVAILABLE:
        return

    try:
        from vllm.model_executor.layers.quantization.awq_marlin import AWQMoEMethod
        from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import (
            CompressedTensorsW8A8Fp8MoEMethod,
            CompressedTensorsWNA16MoEMethod,
        )
        from vllm.model_executor.layers.quantization.gptq_marlin import (
            GPTQMarlinMoEMethod,
        )

        setattr(GPTQMarlinConfig, "get_quant_method", gptq_get_quant_method)
        setattr(GPTQConfig, "get_quant_method", gptq_get_quant_method)

        monkey_patch_moe_apply(AWQMoEMethod)
        monkey_patch_moe_apply(GPTQMarlinMoEMethod)
        monkey_patch_moe_apply(CompressedTensorsW8A8Fp8MoEMethod)
        monkey_patch_moe_apply(CompressedTensorsWNA16MoEMethod)
    except ImportError:
        return


# Only apply monkey patches if vllm is available
if VLLM_AVAILABLE:
    monkey_patch_quant_configs()

__all__ = [
...
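With the split registry above, lookups can be written defensively; a short usage sketch (assumes sglang is importable; whether the vllm-only names resolve depends on the local install):

from sglang.srt.layers.quantization import QUANTIZATION_METHODS, get_quantization_config

assert "fp8" in QUANTIZATION_METHODS  # base method, never needs vllm
if "awq" in QUANTIZATION_METHODS:  # present only when vllm imported cleanly
    awq_cls = get_quantization_config("awq")
fp8_cls = get_quantization_config("fp8")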
@@ -5,7 +5,6 @@ from typing import Any, Callable, Dict, List, Optional
import torch
from torch.nn import Module
-from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped

from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.layers.linear import (

@@ -19,6 +18,7 @@ from sglang.srt.layers.quantization.base_config import (
    QuantizeMethodBase,
)
from sglang.srt.layers.quantization.int8_utils import apply_w8a8_block_int8_linear
from sglang.srt.layers.quantization.utils import is_layer_skipped
from sglang.srt.utils import set_weight_attrs

ACTIVATION_SCHEMES = ["static", "dynamic"]
...
@@ -7,20 +7,33 @@ import torch
import torch.nn.functional as F
from torch.nn import Module
from torch.nn.parameter import Parameter
-from vllm import _custom_ops as ops

from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
from sglang.srt.layers.quantization.utils import (
    all_close_1d,
    convert_to_channelwise,
    is_layer_skipped,
    per_tensor_dequantize,
    requantize_with_max_scale,
)

try:
    from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
        apply_fp8_marlin_linear,
        prepare_fp8_layer_for_marlin,
    )

    MARLIN_FP8_AVAILABLE = True
except ImportError:
    MARLIN_FP8_AVAILABLE = False

    def apply_fp8_marlin_linear(*args, **kwargs):
        raise ImportError("vllm is not installed")

    def prepare_fp8_layer_for_marlin(*args, **kwargs):
        raise ImportError("vllm is not installed")


from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.layers.linear import (
    LinearBase,
@@ -46,6 +59,7 @@ from sglang.srt.layers.quantization.fp8_utils import (
)
from sglang.srt.utils import (
    get_bool_env_var,
    is_cuda,
    is_hip,
    permute_weight,
    print_warning_once,
@@ -60,6 +74,13 @@ if _is_hip:
    from aiter.fused_moe_bf16_asm import asm_moe
    from aiter.ops.shuffle import shuffle_weight

_is_cuda = is_cuda()

if _is_cuda:
    from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
else:
    from vllm import _custom_ops as vllm_ops

logger = logging.getLogger(__name__)
@@ -173,7 +194,9 @@ class Fp8LinearMethod(LinearMethodBase):
        # For GPUs that lack FP8 hardware support, we can leverage the Marlin
        # kernel for fast weight-only FP8 quantization
        self.use_marlin = (
            get_bool_env_var("SGLANG_FORCE_FP8_MARLIN") and MARLIN_FP8_AVAILABLE
        )
        # Disable marlin for ROCm
        if _is_hip:
            self.use_marlin = False
@@ -371,9 +394,12 @@
            )

        if self.use_marlin:
            try:
                prepare_fp8_layer_for_marlin(layer)
                # Activations not quantized for marlin.
                del layer.input_scale
            except ImportError:
                self.use_marlin = False

    def apply(
        self,
@@ -383,15 +409,18 @@
    ) -> torch.Tensor:

        if self.use_marlin:
            try:
                return apply_fp8_marlin_linear(
                    input=x,
                    weight=layer.weight,
                    weight_scale=layer.weight_scale,
                    workspace=layer.workspace,
                    size_n=layer.output_size_per_partition,
                    size_k=layer.input_size_per_partition,
                    bias=bias,
                )
            except ImportError:
                self.use_marlin = False

        if self.block_quant:
            return apply_w8a8_block_fp8_linear(
@@ -680,12 +709,20 @@ class Fp8MoEMethod:
                requires_grad=False,
            )
            for expert in range(layer.num_experts):
                if _is_cuda:
                    w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
                        sgl_scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
                    )
                    w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
                        sgl_scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
                    )
                else:
                    w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
                        vllm_ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
                    )
                    w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
                        vllm_ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
                    )

            layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
            layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
...
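The _is_cuda dispatch above selects the FP8 quantization op once at import time; a condensed sketch of the same selection (names mirror the diff; availability of either op depends on the local install, and per the call sites above both return a (quantized_tensor, scale) pair):

from sglang.srt.utils import is_cuda

if is_cuda():
    from sglang.srt.custom_op import scaled_fp8_quant as fp8_quant
else:  # fall back to vllm's kernel when not on CUDA
    from vllm import _custom_ops as vllm_ops

    fp8_quant = vllm_ops.scaled_fp8_quant

# usage, as in Fp8MoEMethod above:
# w_fp8, w_scale = fp8_quant(weight_fp16)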
@@ -28,7 +28,12 @@ if _is_cuda:
    from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_quant_fp8

    if use_vllm_cutlass_w8a8_fp8_kernel:
        try:
            from vllm import _custom_ops as ops

            VLLM_AVAILABLE = True
        except ImportError:
            VLLM_AVAILABLE = False
    else:
        from sgl_kernel import fp8_scaled_mm
@@ -219,90 +224,97 @@ def apply_fp8_linear(
        )

    if cutlass_fp8_supported:
        try:
            if VLLM_AVAILABLE and use_vllm_cutlass_w8a8_fp8_kernel:
                # Fall back to vllm cutlass w8a8 fp8 kernel
                output = ops.cutlass_scaled_mm(
                    qinput,
                    weight,
                    out_dtype=input.dtype,
                    scale_a=x_scale,
                    scale_b=weight_scale,
                    bias=bias,
                )
            else:
                assert (
                    weight_scale.numel() == weight.shape[1]
                ), "cutlass w8a8 fp8 sgl-kernel only supports per-channel scale"
                output = fp8_scaled_mm(
                    qinput,
                    weight,
                    x_scale,
                    weight_scale,
                    out_dtype=input.dtype,
                    bias=bias,
                )
            return output.view(*output_shape)
        except (ImportError, NameError, AttributeError):
            pass

    # torch.scaled_mm supports per tensor weights + activations only
    # so fallback to naive if per channel or per token
    per_tensor_weights = weight_scale.numel() == 1
    per_tensor_activations = x_scale.numel() == 1

    if per_tensor_weights and per_tensor_activations:
        # Fused GEMM_DQ
        output = torch._scaled_mm(
            qinput,
            weight,
            out_dtype=input.dtype,
            scale_a=x_scale,
            scale_b=weight_scale,
            bias=bias,
        )
        # A fix for discrepancy in scaled_mm which returns tuple
        # for torch < 2.5 and a single value in torch >= 2.5
        if type(output) is tuple and len(output) == 2:
            output = output[0]

        return torch.narrow(output, 0, 0, input_2d.shape[0]).view(*output_shape)

    else:
        # Fallback for channelwise case, where we use unfused DQ
        # due to limitations with scaled_mm

        # Symmetric quantized GEMM by definition computes the following:
        #   C = (s_x * X) (s_w * W) + bias
        # This is equivalent to dequantizing the weights and activations
        # before applying a GEMM.
        #
        # In order to compute quantized operands, a quantized kernel
        # will rewrite the above like so:
        #   C = s_w * s_x * (X * W) + bias
        #
        # For the scaled_mm fallback case, we break this down, since it
        # does not support s_w being a vector.

        # Making sure the dummy tensor is on the same device as the weight
        global TORCH_DEVICE_IDENTITY
        if TORCH_DEVICE_IDENTITY.device != weight.device:
            TORCH_DEVICE_IDENTITY = TORCH_DEVICE_IDENTITY.to(weight.device)

        # GEMM
        # This computes C = (X * W).
        # Output in fp32 to allow subsequent ops to happen in-place
        output = torch._scaled_mm(
            qinput,
            weight,
            scale_a=TORCH_DEVICE_IDENTITY,
            scale_b=TORCH_DEVICE_IDENTITY,
            out_dtype=torch.float32,
        )
        # A fix for discrepancy in scaled_mm which returns tuple
        # for torch < 2.5 and a single value in torch >= 2.5
        if type(output) is tuple and len(output) == 2:
            output = output[0]
        # Unpad (undo num_token_padding)
        output = torch.narrow(output, 0, 0, input_2d.shape[0])
        x_scale = torch.narrow(x_scale, 0, 0, input_2d.shape[0])

        # DQ
        # C = sw * sx * (X * W) + bias
        output = output * x_scale * weight_scale.t()
        if bias is not None:
            output = output + bias
        return output.to(dtype=input.dtype).view(*output_shape)
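The unfused fallback above relies on scales commuting with the matmul, i.e. (s_x X)(s_w W) = s_x s_w (X W); a quick CPU check of that identity in plain float32 (illustrative only, no FP8 involved):

import torch

X, W = torch.randn(4, 8), torch.randn(8, 3)
s_x, s_w = 0.02, 0.5  # per-tensor activation / weight scales
fused = (s_x * X) @ (s_w * W)
unfused = s_x * s_w * (X @ W)  # GEMM first, dequantize afterwards
assert torch.allclose(fused, unfused, atol=1e-6)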
@@ -3,11 +3,21 @@ from fractions import Fraction
from typing import Any, Dict, List, Optional, Union

import torch

from sglang.srt.layers.linear import LinearBase
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.quantization.utils import scalar_types
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
from sglang.srt.utils import is_cuda

_is_cuda = is_cuda()

try:
    import vllm

    VLLM_AVAILABLE = True
except ImportError:
    VLLM_AVAILABLE = False

logger = logging.getLogger(__name__)
@@ -110,6 +120,9 @@ class GPTQConfig(QuantizationConfig):
    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["GPTQLinearMethod"]:
        if not VLLM_AVAILABLE:
            raise ImportError("vllm is not installed")

        from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod

        from sglang.srt.layers.quantization import get_linear_quant_method
@@ -263,6 +276,9 @@ class GPTQMarlinConfig(QuantizationConfig):
    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        if not VLLM_AVAILABLE:
            raise ImportError("vllm is not installed")

        from vllm.model_executor.layers.quantization.gptq_marlin import (
            GPTQMarlinLinearMethod,
            GPTQMarlinMoEMethod,
@@ -285,6 +301,9 @@ class GPTQMarlinConfig(QuantizationConfig):
    @classmethod
    def is_gptq_marlin_compatible(cls, quant_config: Dict[str, Any]):
        if not VLLM_AVAILABLE:
            return False

        quant_method = quant_config.get("quant_method", "").lower()
        num_bits = quant_config.get("bits")
        group_size = quant_config.get("group_size")

@@ -294,9 +313,8 @@ class GPTQMarlinConfig(QuantizationConfig):
        from vllm.model_executor.layers.quantization.utils.marlin_utils import (
            check_marlin_supported,
        )

        if not _is_cuda:
            return False

        if quant_method != "gptq":
@@ -407,6 +425,9 @@ class MarlinConfig(QuantizationConfig):
    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["MarlinLinearMethod"]:
        if not VLLM_AVAILABLE:
            raise ImportError("vllm is not installed")

        from vllm.model_executor.layers.quantization.marlin import MarlinLinearMethod

        if isinstance(layer, LinearBase) or (
...
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/kv_cache.py
import logging
import torch
from sglang.srt.layers.quantization.base_config import (
QuantizationConfig,
QuantizeMethodBase,
)
from sglang.srt.utils import is_hip
_is_hip = is_hip()
logger = logging.getLogger(__name__)
class BaseKVCacheMethod(QuantizeMethodBase):
"""
Quant method that adds `_k_scale` and `_v_scale` attributes to the
Attention layer to support loading those scaling factors from checkpoints.
The k/v_scale will be used to:
- quantize k/v_cache entries before saving them to the cache
- dequantize k/v_cache entries before fetching them from the cache
:param quant_config: the appropriate QuantizationConfig
"""
def __init__(self, quant_config: QuantizationConfig):
self.quant_config = quant_config
def create_weights(self, layer: torch.nn.Module):
"""
Create "weight" (aka k_scale and v_scale) for an attention layer.
"""
# Initialize the KV cache scales to -1.0, which is an invalid value.
# If the k/v_scale appears in the checkpoint, it will be
# overwritten when loading weights.
layer.k_scale = torch.nn.Parameter(torch.tensor(-1.0), requires_grad=False)
layer.v_scale = torch.nn.Parameter(torch.tensor(-1.0), requires_grad=False)
@classmethod
def is_fp8_fnuz(cls) -> bool:
# only device 0 is checked, this assumes MI300 platforms are homogeneous
return "gfx94" in torch.cuda.get_device_properties(0).gcnArchName
def apply(self, layer: torch.nn.Module) -> torch.Tensor:
raise RuntimeError(f"{self.__class__.__name__}.apply should not be called.")
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# If the kv-cache dtype is auto, we enforce the k/v_scale to be 1.0
# regardless whether the kv-scale is available in the checkpoint.
# No need to process kv scales after loading if we are going to
# calculate them on the fly.
if layer.kv_cache_dtype != "auto" and not layer.calculate_kv_scales:
if layer.k_scale > 0.0 and layer.v_scale > 0.0:
# We prefer to use separate k_scale and v_scale if present
k_scale = layer.k_scale.to("cpu").tolist()
v_scale = layer.v_scale.to("cpu").tolist()
if _is_hip and self.is_fp8_fnuz():
k_scale *= 2
v_scale *= 2
elif layer.k_scale < 0.0 and layer.v_scale < 0.0:
# If no scales were loaded (both scales are invalid negative
# values), use the default value of 1.0
k_scale = 1.0
v_scale = 1.0
else:
# If we find a single kv_scale in the checkpoint, we remap
# kv_scale to k_scale during weight loading, and duplicate
# k_scale to v_scale here
assert layer.k_scale > 0.0
scale_to_duplicate = max(layer.k_scale, layer.v_scale)
k_scale = scale_to_duplicate.to("cpu").tolist()
v_scale = scale_to_duplicate.to("cpu").tolist()
if _is_hip and self.is_fp8_fnuz():
k_scale *= 2
v_scale *= 2
if not isinstance(k_scale, float) or not isinstance(v_scale, float):
raise ValueError(
"Only support per-tensor scaling factor " "for fp8 KV cache"
)
# These are used in the final Attention.forward()
layer._k_scale.copy_(k_scale)
layer._v_scale.copy_(v_scale)
layer._k_scale_float = k_scale
layer._v_scale_float = v_scale
if k_scale == 1.0 and v_scale == 1.0 and "e5m2" not in layer.kv_cache_dtype:
logger.warning(
"Using KV cache scaling factor 1.0 for fp8_e4m3. This "
"may cause accuracy issues. Please make sure k/v_scale "
"scaling factors are available in the fp8 checkpoint."
)
del layer.k_scale
del layer.v_scale
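The branchy scale handling in process_weights_after_loading boils down to three cases; a condensed restatement (the helper name is mine, and the ROCm fnuz doubling and dtype checks are omitted) that can be sanity-checked directly:

def resolve_kv_scales(k_scale: float, v_scale: float) -> tuple:
    if k_scale > 0.0 and v_scale > 0.0:
        return k_scale, v_scale  # both loaded from the checkpoint
    if k_scale < 0.0 and v_scale < 0.0:
        return 1.0, 1.0  # nothing loaded, fall back to the default
    s = max(k_scale, v_scale)  # single kv_scale remapped to k_scale
    return s, s


assert resolve_kv_scales(-1.0, -1.0) == (1.0, 1.0)
assert resolve_kv_scales(0.25, 0.5) == (0.25, 0.5)
assert resolve_kv_scales(0.5, -1.0) == (0.5, 0.5)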
@@ -5,12 +5,6 @@ from typing import Any, Dict, List, Optional
import torch
from torch.nn.parameter import Parameter
-from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
-from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
-    convert_to_channelwise,
-    cutlass_fp8_supported,
-    requantize_with_max_scale,
-)

from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.linear import LinearBase, LinearMethodBase

@@ -19,7 +13,15 @@ from sglang.srt.layers.quantization.base_config import (
    QuantizationConfig,
    QuantizeMethodBase,
)
from sglang.srt.layers.quantization.fp8_utils import (
    apply_fp8_linear,
    cutlass_fp8_supported,
)
from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
from sglang.srt.layers.quantization.utils import (
    convert_to_channelwise,
    requantize_with_max_scale,
)

# Initialize logger for the module
logger = logging.getLogger(__name__)
...
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/quant_utils.py
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/scalar_type.py
import functools
import struct
from dataclasses import dataclass
from enum import Enum
from types import MappingProxyType
from typing import List, Mapping, Optional, Tuple, Union
import torch
def is_layer_skipped(
prefix: str,
ignored_layers: List[str],
fused_mapping: Mapping[str, List[str]] = MappingProxyType({}),
) -> bool:
# prefix: model.layers.0.self_attn.q_proj
# proj_name: q_proj
proj_name = prefix.split(".")[-1]
# Fused layers like gate_up_proj or qkv_proj will not be fused
# in the safetensors checkpoint. So, we convert the name
# from the fused version to unfused + check to make sure that
# each shard of the fused layer has the same scheme.
if proj_name in fused_mapping:
shard_prefixes = [
prefix.replace(proj_name, shard_proj_name)
for shard_proj_name in fused_mapping[proj_name]
]
is_skipped = None
for shard_prefix in shard_prefixes:
is_shard_skipped = shard_prefix in ignored_layers
if is_skipped is None:
is_skipped = is_shard_skipped
elif is_shard_skipped != is_skipped:
raise ValueError(
f"Detected some but not all shards of {prefix} "
"are quantized. All shards of fused layers "
"to have the same precision."
)
else:
is_skipped = prefix in ignored_layers
assert is_skipped is not None
return is_skipped
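A usage example for is_layer_skipped above (the layer names and mapping are made up; it assumes the helper is importable from sglang.srt.layers.quantization.utils as introduced in this commit):

fused_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}
ignored = [
    "model.layers.0.self_attn.q_proj",
    "model.layers.0.self_attn.k_proj",
    "model.layers.0.self_attn.v_proj",
]
assert is_layer_skipped("model.layers.0.self_attn.qkv_proj", ignored, fused_mapping)
assert not is_layer_skipped("model.layers.0.mlp.down_proj", ignored)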
def per_tensor_dequantize(
tensor: torch.Tensor, inv_scale: Union[float, torch.Tensor]
) -> torch.Tensor:
fake_qweight = tensor.to(torch.float16)
dq_weight = fake_qweight * inv_scale
return dq_weight
def all_close_1d(x: torch.Tensor) -> bool:
assert len(x.shape) == 1
return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0]))
def convert_to_channelwise(
weight_scale: torch.Tensor, logical_widths: List[int]
) -> Tuple[torch.Tensor, torch.Tensor]:
# Create channelwise buffer
weight_scale_channel = torch.empty(
(sum(logical_widths), 1), dtype=torch.float32, device=weight_scale.device
)
# Expand each scale to match the size of each logical matrix.
start = 0
for idx, logical_width in enumerate(logical_widths):
end = start + logical_width
weight_scale_channel[start:end, :] = weight_scale[idx]
start = end
return weight_scale_channel
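convert_to_channelwise above simply repeats each shard's scale across that shard's output rows; a small check with three fused shards:

import torch

shard_scales = torch.tensor([0.5, 1.0, 2.0])  # one scale per fused shard
logical_widths = [2, 3, 1]  # output rows per shard
channelwise = convert_to_channelwise(shard_scales, logical_widths)
assert channelwise.shape == (6, 1)
assert channelwise.squeeze(1).tolist() == [0.5, 0.5, 1.0, 1.0, 1.0, 2.0]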
def requantize_with_max_scale(
weight: torch.Tensor, weight_scale: torch.Tensor, logical_widths: List[int]
) -> Tuple[torch.Tensor, torch.Tensor]:
# Max scale to be used for requantization.
max_w_scale = weight_scale.max()
# QKV / MLP is fused in the on disk checkpoint if any of the
# weight scales are still set to the default since we initialize
# N weight scales for N shards but we only load 1 weight scale
# from disk in this case. Skip requantization in this case (since
# we are already quantized with the single scale).
# * Sample Model: nm-testing/Phi-3-mini-128k-instruct-FP8
unfused_module_in_checkpoint = (
weight_scale[-1] > torch.finfo(torch.float8_e4m3fn).min
)
# If unfused checkpoint, we need to requantize with the single max scale.
if unfused_module_in_checkpoint:
start = 0
for idx, logical_width in enumerate(logical_widths):
end = start + logical_width
weight_dq = per_tensor_dequantize(weight[start:end, :], weight_scale[idx])
# NOTE: `ops` is not imported in this module as shown; an fp8 quant op
# provider (e.g. vllm's _custom_ops, as used elsewhere in this commit) is assumed.
weight[start:end, :], _ = ops.scaled_fp8_quant(weight_dq, max_w_scale)
start = end
return max_w_scale, weight
# Mirrors enum in `core/scalar_type.hpp`
class NanRepr(Enum):
NONE = 0 # nans are not supported
IEEE_754 = 1 # nans are: Exp all 1s, mantissa not all 0s
EXTD_RANGE_MAX_MIN = 2 # nans are: Exp all 1s, mantissa all 1s
# This ScalarType class is a parallel implementation of the C++ ScalarType
# class found in csrc/core/scalar_type.hpp. These two classes should be kept
# in sync until the inductor fully supports custom C++ classes.
@dataclass(frozen=True)
class ScalarType:
"""
ScalarType can represent a wide range of floating point and integer
types, in particular it can be used to represent sub-byte data types
(something that torch.dtype currently does not support). It is also
capable of representing types with a bias, i.e.:
`stored_value = value + bias`,
this is useful for quantized types (e.g. standard GPTQ 4bit uses a bias
of 8). The implementation for this class can be found in
csrc/core/scalar_type.hpp, these type signatures should be kept in sync
with that file.
"""
exponent: int
"""
Number of bits in the exponent if this is a floating point type
(zero if this is an integer type)
"""
mantissa: int
"""
Number of bits in the mantissa if this is a floating point type,
or the number of bits representing an integer (excluding the sign bit) if
this is an integer type.
"""
signed: bool
"If the type is signed (i.e. has a sign bit)"
bias: int
"""
bias used to encode the values in this scalar type
(value = stored_value - bias, default 0) for example if we store the
type as an unsigned integer with a bias of 128 then the value 0 will be
stored as 128 and -1 will be stored as 127 and 1 will be stored as 129.
"""
_finite_values_only: bool = False
"""
Private: use `has_infs()` to check whether infs are supported.
"""
nan_repr: NanRepr = NanRepr.IEEE_754
"""
How NaNs are represented in this scalar type (a NanRepr value).
(not applicable for integer types)
"""
def _floating_point_max_int(self) -> int:
assert (
self.mantissa <= 52 and self.exponent <= 11
), f"Cannot represent max/min as a double for type {self.__str__()}"
max_mantissa = (1 << self.mantissa) - 1
if self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN:
max_mantissa = max_mantissa - 1
max_exponent = (1 << self.exponent) - 2
if self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN or self.nan_repr == NanRepr.NONE:
assert (
self.exponent < 11
), f"Cannot represent max/min as a double for type {self.__str__()}"
max_exponent = max_exponent + 1
# adjust the exponent to match that of a double
# for now we assume the exponent bias is the standard 2^(e-1) -1, (where
# e is the exponent bits), there is some precedent for non-standard
# biases, example `float8_e4m3b11fnuz` here:
# https://github.com/jax-ml/ml_dtypes but to avoid premature over
# complication we are just assuming the standard exponent bias until
# there is a need to support non-standard biases
exponent_bias = (1 << (self.exponent - 1)) - 1
exponent_bias_double = (1 << 10) - 1 # double e = 11
max_exponent_double = max_exponent - exponent_bias + exponent_bias_double
# shift the mantissa and exponent into the proper positions for an
# IEEE double and bitwise-or them together.
return (max_mantissa << (52 - self.mantissa)) | (max_exponent_double << 52)
def _floating_point_max(self) -> float:
double_raw = self._floating_point_max_int()
return struct.unpack("!d", struct.pack("!Q", double_raw))[0]
def _raw_max(self) -> Union[int, float]:
if self.is_floating_point():
return self._floating_point_max()
else:
assert (
self.size_bits < 64 or self.size_bits == 64 and self.is_signed()
), "Cannot represent max as an int"
return (1 << self.mantissa) - 1
def _raw_min(self) -> Union[int, float]:
if self.is_floating_point():
assert (
self.is_signed()
), "We currently assume all floating point types are signed"
sign_bit_double = 1 << 63
max_raw = self._floating_point_max_int()
min_raw = max_raw | sign_bit_double
return struct.unpack("!d", struct.pack("!Q", min_raw))[0]
else:
assert (
not self.is_signed() or self.size_bits <= 64
), "Cannot represent min as a int64_t"
if self.is_signed():
return -(1 << (self.size_bits - 1))
else:
return 0
@functools.cached_property
def id(self) -> int:
"""
Convert the ScalarType to an int which can be passed to pytorch custom
ops. This layout of the int must be kept in sync with the C++
ScalarType's from_id method.
"""
val = 0
offset = 0
def or_and_advance(member, bit_width):
nonlocal val
nonlocal offset
bit_mask = (1 << bit_width) - 1
val = val | (int(member) & bit_mask) << offset
offset = offset + bit_width
or_and_advance(self.exponent, 8)
or_and_advance(self.mantissa, 8)
or_and_advance(self.signed, 1)
or_and_advance(self.bias, 32)
or_and_advance(self._finite_values_only, 1)
or_and_advance(self.nan_repr.value, 8)
assert offset <= 64, f"ScalarType fields too big {offset} to fit into an int64"
return val
@property
def size_bits(self) -> int:
return self.exponent + self.mantissa + int(self.signed)
def min(self) -> Union[int, float]:
"""
Min representable value for this scalar type.
(accounting for bias if there is one)
"""
return self._raw_min() - self.bias
def max(self) -> Union[int, float]:
"""
Max representable value for this scalar type.
(accounting for bias if there is one)
"""
return self._raw_max() - self.bias
def is_signed(self) -> bool:
"""
If the type is signed (i.e. has a sign bit), same as `signed`
added for consistency with:
https://pytorch.org/docs/stable/generated/torch.Tensor.is_signed.html
"""
return self.signed
def is_floating_point(self) -> bool:
"If the type is a floating point type"
return self.exponent != 0
def is_integer(self) -> bool:
"If the type is an integer type"
return self.exponent == 0
def has_bias(self) -> bool:
"If the type has a non-zero bias"
return self.bias != 0
def has_infs(self) -> bool:
"If the type is floating point and supports infinity"
return not self._finite_values_only
def has_nans(self) -> bool:
return self.nan_repr != NanRepr.NONE
def is_ieee_754(self) -> bool:
"""
If the type is a floating point type that follows IEEE 754
conventions
"""
return self.nan_repr == NanRepr.IEEE_754 and not self._finite_values_only
def __str__(self) -> str:
"""
naming generally follows: https://github.com/jax-ml/ml_dtypes
for floating point types (leading f) the scheme is:
`float<size_bits>_e<exponent_bits>m<mantissa_bits>[flags]`
flags:
- no-flags: means it follows IEEE 754 conventions
- f: means finite values only (no infinities)
- n: means nans are supported (non-standard encoding)
for integer types the scheme is:
`[u]int<size_bits>[b<bias>]`
- if bias is not present it means it is zero
"""
if self.is_floating_point():
ret = (
"float"
+ str(self.size_bits)
+ "_e"
+ str(self.exponent)
+ "m"
+ str(self.mantissa)
)
if not self.is_ieee_754():
if self._finite_values_only:
ret = ret + "f"
if self.nan_repr != NanRepr.NONE:
ret = ret + "n"
return ret
else:
ret = ("int" if self.is_signed() else "uint") + str(self.size_bits)
if self.has_bias():
ret = ret + "b" + str(self.bias)
return ret
def __repr__(self) -> str:
return "ScalarType." + self.__str__()
# __len__ needs to be defined (and has to throw TypeError) for pytorch's
# opcheck to work.
def __len__(self) -> int:
raise TypeError
#
# Convenience Constructors
#
@classmethod
def int_(cls, size_bits: int, bias: Optional[int]) -> "ScalarType":
"Create a signed integer scalar type (size_bits includes sign-bit)."
ret = cls(0, size_bits - 1, True, bias if bias else 0)
ret.id # noqa B018: make sure the id is cached
return ret
@classmethod
def uint(cls, size_bits: int, bias: Optional[int]) -> "ScalarType":
"""Create a unsigned integer scalar type."""
ret = cls(0, size_bits, False, bias if bias else 0)
ret.id # noqa B018: make sure the id is cached
return ret
@classmethod
def float_IEEE754(cls, exponent: int, mantissa: int) -> "ScalarType":
"""
Create a standard floating point type
(i.e. follows IEEE 754 conventions).
"""
assert mantissa > 0 and exponent > 0
ret = cls(exponent, mantissa, True, 0)
ret.id # noqa B018: make sure the id is cached
return ret
@classmethod
def float_(
cls, exponent: int, mantissa: int, finite_values_only: bool, nan_repr: NanRepr
) -> "ScalarType":
"""
Create a non-standard floating point type
(i.e. does not follow IEEE 754 conventions).
"""
assert mantissa > 0 and exponent > 0
assert nan_repr != NanRepr.IEEE_754, (
"use `float_IEEE754` constructor for floating point types that "
"follow IEEE 754 conventions"
)
ret = cls(exponent, mantissa, True, 0, finite_values_only, nan_repr)
ret.id # noqa B018: make sure the id is cached
return ret
# naming generally follows: https://github.com/jax-ml/ml_dtypes
# for floating point types (leading f) the scheme is:
# `float<size_bits>_e<exponent_bits>m<mantissa_bits>[flags]`
# flags:
# - no-flags: means it follows IEEE 754 conventions
# - f: means finite values only (no infinities)
# - n: means nans are supported (non-standard encoding)
# for integer types the scheme is:
# `[u]int<size_bits>[b<bias>]`
# - if bias is not present it means it is zero
class scalar_types:
int4 = ScalarType.int_(4, None)
uint4 = ScalarType.uint(4, None)
int8 = ScalarType.int_(8, None)
uint8 = ScalarType.uint(8, None)
float8_e4m3fn = ScalarType.float_(4, 3, True, NanRepr.EXTD_RANGE_MAX_MIN)
float8_e5m2 = ScalarType.float_IEEE754(5, 2)
float16_e8m7 = ScalarType.float_IEEE754(8, 7)
float16_e5m10 = ScalarType.float_IEEE754(5, 10)
# fp6, https://github.com/usyd-fsalab/fp6_llm/tree/main
float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE)
# fp4, https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
float4_e2m1fn = ScalarType.float_(2, 1, True, NanRepr.NONE)
# "gptq" types
uint2b2 = ScalarType.uint(2, 2)
uint3b4 = ScalarType.uint(3, 4)
uint4b8 = ScalarType.uint(4, 8)
uint8b128 = ScalarType.uint(8, 128)
# colloquial names
bfloat16 = float16_e8m7
float16 = float16_e5m10
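A few sanity checks against the definitions above; the numbers follow directly from the ScalarType formulas (uint4b8 is the GPTQ 4-bit type with bias 8, and float8_e4m3fn maxes out at 448):

t = scalar_types.uint4b8
assert (t.min(), t.max()) == (-8, 7)
assert str(t) == "uint4b8"
assert scalar_types.float8_e4m3fn.max() == 448.0
assert scalar_types.float8_e5m2.size_bits == 8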