"...git@developer.sourcefind.cn:2222/OpenDAS/vllm_cscc.git" did not exist on "6724e791641ed3c604bf929dfa46cfc71fa8ba71"
Unverified Commit b95bb692 authored by Eldar Kurtić's avatar Eldar Kurtić Committed by GitHub
Browse files

[kv-cache, ct] Use compressed-tensors as a source of ground-truth for quant strategies (#34254)


Signed-off-by: default avatarYour Name <you@example.com>
Co-authored-by: default avatarYour Name <you@example.com>
parent 39264545
...@@ -951,11 +951,11 @@ class CompressedTensorsKVCacheMethod(BaseKVCacheMethod): ...@@ -951,11 +951,11 @@ class CompressedTensorsKVCacheMethod(BaseKVCacheMethod):
f"received num_bits={num_bits}, type={type_}" f"received num_bits={num_bits}, type={type_}"
) )
# TODO: delegate validation to compressed-tensors library so that we have a strategy = QuantizationStrategy(kv_cache_scheme.get("strategy"))
# single source of truth. Right now this is not possible until the next release supported_strategies = (
# of compressed-tensors. QuantizationStrategy.TENSOR,
strategy = kv_cache_scheme.get("strategy") QuantizationStrategy.ATTN_HEAD,
supported_strategies = ("tensor", "attn_head") )
if strategy not in supported_strategies: if strategy not in supported_strategies:
raise NotImplementedError( raise NotImplementedError(
"Invalid strategy for compressed-tensors KV cache. " "Invalid strategy for compressed-tensors KV cache. "
...@@ -981,9 +981,11 @@ class CompressedTensorsKVCacheMethod(BaseKVCacheMethod): ...@@ -981,9 +981,11 @@ class CompressedTensorsKVCacheMethod(BaseKVCacheMethod):
hasattr(self.quant_config, "kv_cache_scheme") hasattr(self.quant_config, "kv_cache_scheme")
and self.quant_config.kv_cache_scheme is not None and self.quant_config.kv_cache_scheme is not None
): ):
strategy = self.quant_config.kv_cache_scheme["strategy"] strategy = QuantizationStrategy(
self.quant_config.kv_cache_scheme["strategy"]
)
if strategy == "attn_head": if strategy == QuantizationStrategy.ATTN_HEAD:
assert layer.impl.supports_per_head_quant_scales, ( assert layer.impl.supports_per_head_quant_scales, (
f"Layer {layer.__class__.__name__} with implementation " f"Layer {layer.__class__.__name__} with implementation "
f"{layer.impl.__class__.__name__} does not support per-head scales." f"{layer.impl.__class__.__name__} does not support per-head scales."
...@@ -1020,7 +1022,7 @@ class CompressedTensorsKVCacheMethod(BaseKVCacheMethod): ...@@ -1020,7 +1022,7 @@ class CompressedTensorsKVCacheMethod(BaseKVCacheMethod):
# - q_scale is partitioned over query heads. # - q_scale is partitioned over query heads.
# - k/v_scale is partitioned over kv heads when total_kv_heads >= tp_size, # - k/v_scale is partitioned over kv heads when total_kv_heads >= tp_size,
# and replicated when total_kv_heads < tp_size. # and replicated when total_kv_heads < tp_size.
if strategy == "attn_head": if strategy == QuantizationStrategy.ATTN_HEAD:
def _tp_aware_loader( def _tp_aware_loader(
param: torch.Tensor, param: torch.Tensor,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment