"git@developer.sourcefind.cn:change/sglang.git" did not exist on "887c2b4575772f8f70e0bc55dd701ae274d3ab32"
Unverified commit 2695ab05, authored by Yun Dai and committed by GitHub

Fix loading KV quantization scale; Enable modelopt kv cache (#4686)


Co-authored-by: qingquansong <ustcsqq@gmail.com>
parent 88d6fd9a
@@ -239,7 +239,7 @@ class ModelConfig:
         # check if is modelopt model -- modelopt doesn't have corresponding field
         # in hf `config.json` but has a standalone `hf_quant_config.json` in the root directory
         # example: https://huggingface.co/nvidia/Llama-3.1-8B-Instruct-FP8/tree/main
-        is_local = os.path.isdir(self.model_path)
+        is_local = os.path.exists(self.model_path)
         modelopt_quant_config = {"quant_method": "modelopt"}
         if not is_local:
             from huggingface_hub import HfApi
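For reference, the detection path above roughly amounts to the sketch below. The `hf_quant_config.json` filename comes from the comment in the hunk; the helper name, the `HfApi.file_exists` call, and the example repo id are illustrative assumptions rather than a copy of the surrounding code.

    import os
    from huggingface_hub import HfApi

    def looks_like_modelopt_checkpoint(model_path: str) -> bool:
        # Local checkout: the quant config sits next to the weights.
        if os.path.exists(model_path):
            return os.path.isfile(os.path.join(model_path, "hf_quant_config.json"))
        # Hub repo id: ask the Hub whether the standalone file exists.
        return HfApi().file_exists(model_path, "hf_quant_config.json")

    # e.g. looks_like_modelopt_checkpoint("nvidia/Llama-3.1-8B-Instruct-FP8")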
@@ -292,6 +292,8 @@ class FlashAttentionBackend(AttentionBackend):
         self.decode_cuda_graph_metadata = {}
         self.target_verify_metadata = {}
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
+        self.kv_cache_dtype = model_runner.kv_cache_dtype
+        self.kv_cache_dtype_str = model_runner.server_args.kv_cache_dtype
         self.page_size = model_runner.page_size
         self.use_mla = (
             model_runner.model_config.attention_arch == AttentionArch.MLA
@@ -520,6 +522,12 @@ class FlashAttentionBackend(AttentionBackend):
             if layer.sliding_window_size is not None
             else (-1, -1)
         )
+        k_descale, v_descale = None, None
+        if self.kv_cache_dtype_str != "auto":
+            descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
+            k_descale = layer.k_scale.expand(descale_shape)
+            v_descale = layer.v_scale.expand(descale_shape)
+            q = q.to(self.kv_cache_dtype)
         causal = not layer.is_cross_attention

         # Check if we should use local attention
@@ -576,8 +584,8 @@ class FlashAttentionBackend(AttentionBackend):
                 causal=causal,
                 window_size=window_size,
                 softcap=layer.logit_cap,
-                k_descale=layer.k_scale,
-                v_descale=layer.v_scale,
+                k_descale=k_descale,
+                v_descale=v_descale,
             )
         else:
             # Do absorbed multi-latent attention
@@ -609,8 +617,8 @@ class FlashAttentionBackend(AttentionBackend):
                 softmax_scale=layer.scaling,
                 causal=True,
                 softcap=layer.logit_cap,
-                k_descale=layer.k_scale,
-                v_descale=layer.v_scale,
+                k_descale=k_descale,
+                v_descale=v_descale,
             )
         return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
@@ -657,6 +665,13 @@ class FlashAttentionBackend(AttentionBackend):
         )
         causal = not layer.is_cross_attention

+        k_descale, v_descale = None, None
+        if self.kv_cache_dtype_str != "auto":
+            descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
+            k_descale = layer.k_scale.expand(descale_shape)
+            v_descale = layer.v_scale.expand(descale_shape)
+            q = q.to(self.kv_cache_dtype)
+
         if not self.use_mla:
             # Do multi-head attention
@@ -694,8 +709,8 @@ class FlashAttentionBackend(AttentionBackend):
                 causal=causal,
                 window_size=window_size,
                 softcap=layer.logit_cap,
-                k_descale=layer.k_scale,
-                v_descale=layer.v_scale,
+                k_descale=k_descale,
+                v_descale=v_descale,
             )
         else:
             # Do absorbed multi-latent attention
@@ -729,8 +744,8 @@ class FlashAttentionBackend(AttentionBackend):
                 softmax_scale=layer.scaling,
                 causal=True,
                 softcap=layer.logit_cap,
-                k_descale=layer.k_scale,
-                v_descale=layer.v_scale,
+                k_descale=k_descale,
+                v_descale=v_descale,
             )
         return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
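The new descale path relies on the per-tensor scale being broadcastable to the (batch, kv-heads) shape the FA3 kernel expects, and on casting q to the fp8 KV-cache dtype. A minimal standalone sketch of that broadcasting and cast; the sizes and the e4m3 dtype here are illustrative assumptions, not values taken from the backend:

    import torch

    # Assumed example sizes; in the backend, layer.k_scale is a 0-dim float32 parameter.
    batch_size, tp_k_head_num = 4, 8
    k_scale = torch.tensor(0.0213, dtype=torch.float32)

    descale_shape = (batch_size, tp_k_head_num)
    k_descale = k_scale.expand(descale_shape)   # broadcast view, no copy
    assert k_descale.shape == descale_shape

    q = torch.randn(16, tp_k_head_num, 128, dtype=torch.bfloat16)
    q_fp8 = q.to(torch.float8_e4m3fn)           # mirrors q.to(self.kv_cache_dtype)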
@@ -82,6 +82,8 @@ class FlashInferAttnBackend(AttentionBackend):
         self.max_context_len = model_runner.model_config.context_len
         self.skip_prefill = skip_prefill
         self.is_multimodal = model_runner.model_config.is_multimodal
+        self.kv_cache_dtype = model_runner.kv_cache_dtype
+        self.kv_cache_dtype_str = model_runner.server_args.kv_cache_dtype

         assert not (
             model_runner.sliding_window_size is not None
@@ -391,6 +393,8 @@ class FlashInferAttnBackend(AttentionBackend):
         forward_batch: ForwardBatch,
         save_kv_cache=True,
     ):
+        k_scale = layer.k_scale_float if self.kv_cache_dtype_str != "auto" else None
+        v_scale = layer.v_scale_float if self.kv_cache_dtype_str != "auto" else None
         prefill_wrapper_paged = self.forward_metadata.prefill_wrappers[
             self._get_wrapper_idx(layer)
         ]
@@ -407,7 +411,7 @@ class FlashInferAttnBackend(AttentionBackend):
                 assert v is not None
                 if save_kv_cache:
                     forward_batch.token_to_kv_pool.set_kv_buffer(
-                        layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                        layer, cache_loc, k, v, k_scale, v_scale
                     )

             o = prefill_wrapper_paged.forward(
@@ -417,8 +421,8 @@ class FlashInferAttnBackend(AttentionBackend):
                 sm_scale=layer.scaling,
                 window_left=layer.sliding_window_size,
                 logits_soft_cap=logits_soft_cap,
-                k_scale=layer.k_scale,
-                v_scale=layer.v_scale,
+                k_scale=k_scale,
+                v_scale=v_scale,
             )
         else:
             o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
@@ -445,7 +449,7 @@ class FlashInferAttnBackend(AttentionBackend):
             if save_kv_cache:
                 forward_batch.token_to_kv_pool.set_kv_buffer(
-                    layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                    layer, cache_loc, k, v, k_scale, v_scale
                 )

         return o.view(-1, layer.tp_q_head_num * layer.head_dim)
@@ -459,6 +463,8 @@ class FlashInferAttnBackend(AttentionBackend):
         forward_batch: ForwardBatch,
         save_kv_cache=True,
     ):
+        k_scale = layer.k_scale_float if self.kv_cache_dtype_str != "auto" else None
+        v_scale = layer.v_scale_float if self.kv_cache_dtype_str != "auto" else None
         decode_wrapper = self.forward_metadata.decode_wrappers[
             self._get_wrapper_idx(layer)
         ]
@@ -472,7 +478,7 @@ class FlashInferAttnBackend(AttentionBackend):
             assert v is not None
             if save_kv_cache:
                 forward_batch.token_to_kv_pool.set_kv_buffer(
-                    layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                    layer, cache_loc, k, v, k_scale, v_scale
                 )

         o = decode_wrapper.forward(
@@ -480,8 +486,8 @@ class FlashInferAttnBackend(AttentionBackend):
             forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
             sm_scale=layer.scaling,
             logits_soft_cap=layer.logit_cap,
-            k_scale=layer.k_scale,
-            v_scale=layer.v_scale,
+            k_scale=k_scale,
+            v_scale=v_scale,
         )
         return o.view(-1, layer.tp_q_head_num * layer.head_dim)
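The FlashInfer wrappers take plain Python floats rather than tensors, which is why this backend reads `layer.k_scale_float` instead of the parameter itself. A tiny sketch of the relationship set up later in `process_weights_after_loading`; the 0.0213 value is an arbitrary stand-in for a checkpoint scale:

    import torch

    k_scale_param = torch.nn.Parameter(
        torch.tensor(-1.0, dtype=torch.float32), requires_grad=False
    )
    k_scale_param.data.copy_(torch.tensor(0.0213))    # overwritten from the checkpoint
    k_scale_float = k_scale_param.to("cpu").tolist()  # plain float handed to the kernel call

    assert isinstance(k_scale_float, float)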
@@ -8,6 +8,7 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
+from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.utils import is_hip

 _is_hip = is_hip()
@@ -17,7 +18,7 @@ logger = logging.getLogger(__name__)

 class BaseKVCacheMethod(QuantizeMethodBase):
     """
-    Quant method that adds `_k_scale` and `_v_scale` attributes to the
+    Quant method that adds `k_scale` and `v_scale` attributes to the
     Attention layer to support loading those scaling factors from checkpoints.
     The k/v_scale will be used to:
     - quantize k/v_cache entries before saving them to the cache
@@ -36,8 +37,12 @@ class BaseKVCacheMethod(QuantizeMethodBase):
         # Initialize the KV cache scales to -1.0, which is an invalid value.
         # If the k/v_scale appears in the checkpoint, it will be
         # overwritten when loading weights.
-        layer.k_scale = torch.nn.Parameter(torch.tensor(-1.0), requires_grad=False)
-        layer.v_scale = torch.nn.Parameter(torch.tensor(-1.0), requires_grad=False)
+        layer.k_scale = torch.nn.Parameter(
+            torch.tensor(-1.0, dtype=torch.float32), requires_grad=False
+        )
+        layer.v_scale = torch.nn.Parameter(
+            torch.tensor(-1.0, dtype=torch.float32), requires_grad=False
+        )

     @classmethod
     def is_fp8_fnuz(cls) -> bool:
@@ -47,52 +52,38 @@ class BaseKVCacheMethod(QuantizeMethodBase):
     def apply(self, layer: torch.nn.Module) -> torch.Tensor:
         raise RuntimeError(f"{self.__class__.__name__}.apply should not be called.")

-    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
-        # If the kv-cache dtype is auto, we enforce the k/v_scale to be 1.0
-        # regardless whether the kv-scale is available in the checkpoint.
-        # No need to process kv scales after loading if we are going to
-        # calculate them on the fly.
-        if layer.kv_cache_dtype != "auto" and not layer.calculate_kv_scales:
-            if layer.k_scale > 0.0 and layer.v_scale > 0.0:
-                # We prefer to use separate k_scale and v_scale if present
-                k_scale = layer.k_scale.to("cpu").tolist()
-                v_scale = layer.v_scale.to("cpu").tolist()
-                if _is_hip and self.is_fp8_fnuz():
-                    k_scale *= 2
-                    v_scale *= 2
-            elif layer.k_scale < 0.0 and layer.v_scale < 0.0:
-                # If no scales were loaded (both scales are invalid negative
-                # values), use the default value of 1.0
-                k_scale = 1.0
-                v_scale = 1.0
-            else:
-                # If we find a single kv_scale in the checkpoint, we remap
-                # kv_scale to k_scale during weight loading, and duplicate
-                # k_scale to v_scale here
-                assert layer.k_scale > 0.0
-                scale_to_duplicate = max(layer.k_scale, layer.v_scale)
-                k_scale = scale_to_duplicate.to("cpu").tolist()
-                v_scale = scale_to_duplicate.to("cpu").tolist()
-                if _is_hip and self.is_fp8_fnuz():
-                    k_scale *= 2
-                    v_scale *= 2
-
-            if not isinstance(k_scale, float) or not isinstance(v_scale, float):
-                raise ValueError(
-                    "Only support per-tensor scaling factor " "for fp8 KV cache"
-                )
-
-            # These are used in the final Attention.forward()
-            layer._k_scale.copy_(k_scale)
-            layer._v_scale.copy_(v_scale)
-            layer._k_scale_float = k_scale
-            layer._v_scale_float = v_scale
-            if k_scale == 1.0 and v_scale == 1.0 and "e5m2" not in layer.kv_cache_dtype:
-                logger.warning(
-                    "Using KV cache scaling factor 1.0 for fp8_e4m3. This "
-                    "may cause accuracy issues. Please make sure k/v_scale "
-                    "scaling factors are available in the fp8 checkpoint."
-                )
-
-        del layer.k_scale
-        del layer.v_scale
+    def process_weights_after_loading(self, layer: RadixAttention) -> None:
+        if layer.k_scale > 0.0 and layer.v_scale > 0.0:
+            # We prefer to use separate k_scale and v_scale if present
+            k_scale = layer.k_scale.to("cpu").tolist()
+            v_scale = layer.v_scale.to("cpu").tolist()
+            if _is_hip and self.is_fp8_fnuz():
+                k_scale *= 2
+                v_scale *= 2
+        elif layer.k_scale < 0.0 and layer.v_scale < 0.0:
+            # If no scales were loaded (both scales are invalid negative
+            # values), use the default value of 1.0
+            k_scale = 1.0
+            v_scale = 1.0
+        else:
+            # If we find a single kv_scale in the checkpoint, we remap
+            # kv_scale to k_scale during weight loading, and duplicate
+            # k_scale to v_scale here
+            assert layer.k_scale > 0.0
+            scale_to_duplicate = max(layer.k_scale, layer.v_scale)
+            k_scale = scale_to_duplicate.to("cpu").tolist()
+            v_scale = scale_to_duplicate.to("cpu").tolist()
+            if _is_hip and self.is_fp8_fnuz():
+                k_scale *= 2
+                v_scale *= 2

+        if not isinstance(k_scale, float) or not isinstance(v_scale, float):
+            raise ValueError(
+                "Only support per-tensor scaling factor " "for fp8 KV cache"
+            )

+        # These are used in the final Attention.forward()
+        layer.k_scale.copy_(k_scale)
+        layer.v_scale.copy_(v_scale)
+        layer.k_scale_float = k_scale
+        layer.v_scale_float = v_scale
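As context for what these per-tensor scales are used for (per the docstring above: quantize K/V entries before writing the cache, de-quantize them on read), here is a standalone fp8 round-trip sketch; the dtype choice, shapes, and tolerance are illustrative assumptions, not code from this repository:

    import torch

    k = torch.randn(16, 8, 128, dtype=torch.bfloat16)   # one tile of K states
    k_scale = (k.abs().amax() / torch.finfo(torch.float8_e4m3fn).max).float()

    # Quantize before storing in the KV cache.
    k_fp8 = (k.float() / k_scale).to(torch.float8_e4m3fn)

    # De-quantize (what k_scale / k_descale enables inside the attention kernel).
    k_restored = k_fp8.to(torch.bfloat16) * k_scale
    assert torch.allclose(k_restored, k.float(), rtol=0.1, atol=0.05)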
@@ -6,7 +6,6 @@ from typing import Any, Dict, List, Optional
 import torch
 from torch.nn.parameter import Parameter

-from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.linear import LinearBase, LinearMethodBase
 from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
 from sglang.srt.layers.quantization.base_config import (
@@ -22,6 +21,7 @@ from sglang.srt.layers.quantization.utils import (
     convert_to_channelwise,
     requantize_with_max_scale,
 )
+from sglang.srt.layers.radix_attention import RadixAttention

 # Initialize logger for the module
 logger = logging.getLogger(__name__)
@@ -33,12 +33,19 @@ ACTIVATION_SCHEMES = ["static"]
 class ModelOptFp8Config(QuantizationConfig):
     """Configuration for ModelOpt FP8 quantization, including serialization and compatibility checks."""

-    def __init__(self, is_checkpoint_fp8_serialized: bool = False) -> None:
+    def __init__(
+        self,
+        is_checkpoint_fp8_serialized: bool = False,
+        kv_cache_quant_method: Optional[str] = None,
+        exclude_modules: Optional[List[str]] = None,
+    ) -> None:
         """
         Args:
             is_checkpoint_fp8_serialized (bool): Indicates if the checkpoint uses serialized FP8 format.
         """
         self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
+        self.kv_cache_quant_method = kv_cache_quant_method
+        self.exclude_modules = exclude_modules
         if is_checkpoint_fp8_serialized:
             logger.warning(
                 "Detected ModelOpt FP8 checkpoint. The format is experimental and subject to change."
@@ -63,6 +70,12 @@ class ModelOptFp8Config(QuantizationConfig):
     @classmethod
     def from_config(cls, config: Dict[str, Any]) -> "ModelOptFp8Config":
         quant_method = cls.get_from_keys(config, ["quantization"]).get("quant_algo")
+        kv_cache_quant_method = cls.get_from_keys(config, ["quantization"]).get(
+            "kv_cache_quant_algo"
+        )
+        exclude_modules = cls.get_from_keys(config, ["quantization"]).get(
+            "exclude_modules"
+        )

         if "FP8" not in quant_method:
             raise ValueError(
@@ -70,15 +83,23 @@ class ModelOptFp8Config(QuantizationConfig):
                 "Check the `hf_quant_config.json` file for your model's configuration."
             )

-        return cls(is_checkpoint_fp8_serialized=True)
+        return cls(
+            is_checkpoint_fp8_serialized=True,
+            kv_cache_quant_method=kv_cache_quant_method,
+            exclude_modules=exclude_modules,
+        )

     def get_quant_method(
         self, layer: torch.nn.Module, prefix: str
     ) -> Optional["QuantizeMethodBase"]:
+        if self.exclude_modules and any(
+            module in prefix for module in self.exclude_modules
+        ):
+            return None
+
         if isinstance(layer, LinearBase):
             return ModelOptFp8LinearMethod(self)
-        if isinstance(layer, AttentionBackend):
+        if self.kv_cache_quant_method and isinstance(layer, RadixAttention):
             return ModelOptFp8KVCacheMethod(self)
         return None
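For orientation, `from_config` above expects the parsed `hf_quant_config.json` to carry a `quantization` block with the three keys it reads. The concrete values below are an illustrative guess modeled on ModelOpt FP8 checkpoints (and the import path is assumed), not copied from a specific file:

    # Import path assumed for this sketch.
    from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp8Config

    # Hypothetical contents of hf_quant_config.json after json.load():
    hf_quant_config = {
        "quantization": {
            "quant_algo": "FP8",
            "kv_cache_quant_algo": "FP8",    # enables ModelOptFp8KVCacheMethod on RadixAttention
            "exclude_modules": ["lm_head"],  # prefixes for which get_quant_method returns None
        }
    }

    quant_config = ModelOptFp8Config.from_config(hf_quant_config)
    assert quant_config.kv_cache_quant_method == "FP8"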
@@ -13,8 +13,12 @@
 # ==============================================================================
 """Radix attention."""

+from typing import Optional
+
 from torch import nn

+from sglang.srt.layers.linear import UnquantizedLinearMethod
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
@@ -34,6 +38,7 @@ class RadixAttention(nn.Module):
         v_head_dim: int = -1,
         sliding_window_size: int = -1,
         is_cross_attention: bool = False,
+        quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
         use_irope: bool = False,
     ):
@@ -49,9 +54,16 @@ class RadixAttention(nn.Module):
         self.logit_cap = logit_cap
         self.sliding_window_size = sliding_window_size or -1
         self.is_cross_attention = is_cross_attention
+        self.use_irope = use_irope
         self.k_scale = None
         self.v_scale = None
-        self.use_irope = use_irope
+        self.k_scale_float = None
+        self.v_scale_float = None
+        self.quant_method = None
+        if quant_config is not None:
+            self.quant_method = quant_config.get_quant_method(self, prefix=prefix)
+        if self.quant_method is not None:
+            self.quant_method.create_weights(self)

     def forward(
         self,
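A minimal usage sketch of the new constructor path, mirroring the per-model changes that follow. The head counts and prefix are made-up values, and `quant_config` stands for a ModelOptFp8Config whose `kv_cache_quant_algo` is set, so `create_weights` attaches the `k_scale`/`v_scale` parameters that weight loading later overwrites:

    # Hypothetical wiring inside a model's attention module:
    self.attn = RadixAttention(
        self.num_heads,                     # e.g. 32
        self.head_dim,                      # e.g. 128
        self.scaling,                       # e.g. head_dim ** -0.5
        num_kv_heads=self.num_kv_heads,
        layer_id=layer_id,
        quant_config=quant_config,          # threaded down from the model's quant config
        prefix=add_prefix("attn", prefix),  # e.g. "model.layers.0.self_attn.attn"
    )
    # With KV-cache quantization enabled, self.attn.k_scale / v_scale now exist as
    # float32 parameters initialized to -1.0 (see BaseKVCacheMethod.create_weights).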
@@ -178,6 +178,7 @@ class BaiChuanAttention(nn.Module):
                 scaling,
                 num_kv_heads=self.num_kv_heads,
                 layer_id=layer_id,
+                quant_config=quant_config,
                 prefix=add_prefix("attn", prefix),
             )
         else:
@@ -194,6 +195,7 @@ class BaiChuanAttention(nn.Module):
                 self.scaling,
                 num_kv_heads=self.num_kv_heads,
                 layer_id=layer_id,
+                quant_config=quant_config,
                 prefix=add_prefix("attn", prefix),
             )
@@ -113,6 +113,7 @@ class GLMAttention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_id,
+            quant_config=quant_config,
             prefix=add_prefix("attn", prefix),
         )
@@ -204,6 +204,7 @@ class CohereAttention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_id,
+            quant_config=quant_config,
             prefix=add_prefix("attn", prefix),
         )
         if self.use_qk_norm:
@@ -249,6 +249,7 @@ class DbrxAttention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_id,
+            quant_config=quant_config,
             prefix=add_prefix("attn", prefix),
         )
@@ -255,6 +255,7 @@ class DeepseekAttention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_id,
+            quant_config=quant_config,
             prefix=add_prefix("attn", prefix),
         )
@@ -489,6 +489,7 @@ class DeepseekV2Attention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_local_heads,
             layer_id=layer_id,
+            quant_config=quant_config,
             prefix=add_prefix("attn", prefix),
         )
@@ -669,6 +670,7 @@ class DeepseekV2AttentionMLA(nn.Module):
             num_kv_heads=1,
             layer_id=layer_id,
             v_head_dim=self.kv_lora_rank,
+            quant_config=quant_config,
             prefix=add_prefix("attn_mqa", prefix),
         )
@@ -679,6 +681,7 @@ class DeepseekV2AttentionMLA(nn.Module):
             num_kv_heads=self.num_local_heads,
             layer_id=layer_id,
             v_head_dim=self.v_head_dim,
+            quant_config=quant_config,
             prefix=add_prefix("attn_mha", prefix),
         )
@@ -155,6 +155,7 @@ class ExaoneAttention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_id,
+            quant_config=quant_config,
         )

     def forward(
@@ -137,6 +137,7 @@ class GemmaAttention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_id,
+            quant_config=quant_config,
             prefix=add_prefix("attn", prefix),
         )
@@ -163,6 +163,7 @@ class Gemma2Attention(nn.Module):
                 if use_sliding_window
                 else None
             ),
+            quant_config=quant_config,
             prefix=add_prefix("attn", prefix),
         )
@@ -193,6 +193,7 @@ class Gemma3Attention(nn.Module):
             # Module must also define `get_attention_sliding_window_size` to correctly initialize
             # attention backend in `ForwardBatch`.
             sliding_window_size=self.sliding_window,
+            quant_config=quant_config,
             prefix=add_prefix("attn", prefix),
         )
@@ -78,6 +78,7 @@ class GPT2Attention(nn.Module):
             scaling=self.scale,
             num_kv_heads=total_num_heads,
             layer_id=layer_id,
+            quant_config=quant_config,
         )

     def forward(
@@ -87,6 +87,7 @@ class GPTBigCodeAttention(nn.Module):
             scaling=self.scale,
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_id,
+            quant_config=quant_config,
             prefix=add_prefix("attn", prefix),
         )
@@ -158,6 +158,7 @@ class GraniteAttention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_id,
+            quant_config=quant_config,
             prefix=add_prefix("attn", prefix),
         )
@@ -215,6 +215,7 @@ class Grok1Attention(nn.Module):
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_id,
             logit_cap=logit_cap,
+            quant_config=quant_config,
         )

     def forward(