"""Dispatch helpers that route KV-cache compression calls by method.

The module exposes a registry keyed by ``CompressionMethod`` and two thin
wrappers that look up the registered class for the method named in the
supplied context and forward the call to it.
"""

from vllm.kvprune.compression.common import (
    BaseCompressionMethod,
    NoCompression,
)
from vllm.kvprune.compression.criticalkv import CriticalAdaKVCompression
from vllm.kvprune.compression.compactor import CompactorCompression
from vllm.kvprune.compression.compression_config import (
    BatchCompressionParams,
    CompressionMethod,
    SequenceCompressionParams,
)
from vllm.kvprune.compression.snapkv import SnapKVCompression

# Maps each supported CompressionMethod member to the class implementing it.
# The wrappers below call scoring methods on these classes directly (no
# instance is created here).
COMPRESSION_REGISTRY: dict[CompressionMethod, type[BaseCompressionMethod]] = {
    CompressionMethod.CRITICALADAKV: CriticalAdaKVCompression,
    CompressionMethod.COMPACTOR: CompactorCompression,
    CompressionMethod.SNAPKV: SnapKVCompression,
    CompressionMethod.NONE: NoCompression,
}


def apply_prerope_compression(q, k, v, context):
    """Run the pre-RoPE scoring step of the configured compression method.

    The method is read from ``context.compression_context.compression_method``
    and resolved through ``COMPRESSION_REGISTRY``.

    Args:
        q: Forwarded unchanged as the first positional argument
            (presumably the query tensor -- confirm against callers).
        k: Forwarded unchanged as the second positional argument
            (presumably the key tensor -- confirm against callers).
        v: Forwarded unchanged as the third positional argument
            (presumably the value tensor -- confirm against callers).
        context: Object exposing ``compression_context.compression_method``;
            also forwarded as the ``context`` keyword argument.

    Returns:
        Whatever the registered class's ``pre_rope_scoring`` returns.

    Raises:
        KeyError: If the configured method has no entry in the registry.
    """
    selected = context.compression_context.compression_method
    compression_cls = COMPRESSION_REGISTRY[selected]
    return compression_cls.pre_rope_scoring(q, k, v, context=context)


def apply_postrope_compression(q, k, v, prerope_scores, context):
    """Run the post-RoPE scoring step of the configured compression method.

    The method is read from ``context.compression_context.compression_method``
    and resolved through ``COMPRESSION_REGISTRY``.

    Args:
        q: Forwarded unchanged as the first positional argument
            (presumably the query tensor -- confirm against callers).
        k: Forwarded unchanged as the second positional argument
            (presumably the key tensor -- confirm against callers).
        v: Forwarded unchanged as the third positional argument
            (presumably the value tensor -- confirm against callers).
        prerope_scores: Forwarded unchanged as the fourth positional
            argument (presumably the result of the pre-RoPE step -- confirm
            against callers).
        context: Object exposing ``compression_context.compression_method``;
            also forwarded as the ``context`` keyword argument.

    Returns:
        Whatever the registered class's ``post_rope_scoring`` returns.

    Raises:
        KeyError: If the configured method has no entry in the registry.
    """
    selected = context.compression_context.compression_method
    compression_cls = COMPRESSION_REGISTRY[selected]
    return compression_cls.post_rope_scoring(
        q,
        k,
        v,
        prerope_scores,
        context=context,
    )


# Names exported by ``from <this module> import *``.
__all__ = [
    "apply_prerope_compression",
    "apply_postrope_compression",
    "CompressionMethod",
    "BatchCompressionParams",
    "SequenceCompressionParams",
    "COMPRESSION_REGISTRY",
]