Commit 1693e754 authored by yiqa

Use groupgemm to complete the high-throughput mode adaptation

parent ce363e89
@@ -42,7 +42,7 @@ class SlimQuantCompressedTensorsMarlinConfig(CompressedTensorsConfig):
         sparsity_ignore_list: list[str],
         kv_cache_scheme: Optional[dict[str, Any]] = None,
         config: Optional[dict[str, Any]] = None,
         packed_modules_mapping: Optional[dict[str, list[str]]] = None,
     ):
         super().__init__(
             target_scheme_map,
@@ -52,10 +52,10 @@ class SlimQuantCompressedTensorsMarlinConfig(CompressedTensorsConfig):
             sparsity_ignore_list,
             kv_cache_scheme,
             config,
             packed_modules_mapping,
         )

     @classmethod
     def override_quantization_method(
             cls, hf_quant_cfg, user_quant) -> Optional[str]:
@@ -73,7 +73,7 @@ class SlimQuantCompressedTensorsMarlinConfig(CompressedTensorsConfig):
         prefix: str,
     ) -> Optional["QuantizeMethodBase"]:
         from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE  # Avoid circular import
-        # from sglang.srt.layers.radix_attention import RadixAttention
+        from sglang.srt.layers.radix_attention import RadixAttention
         # Check if the layer is skipped for quantization.
         if should_ignore_layer(prefix,
                                ignore=self.ignore,
@@ -85,8 +85,8 @@ class SlimQuantCompressedTensorsMarlinConfig(CompressedTensorsConfig):
                 return UnquantizedEmbeddingMethod()  # UnquantizedLinearMethod()
             layer.scheme = scheme
             return CompressedTensorsLinearMethod(self)
-        # if isinstance(layer, RadixAttention):
-        #     return CompressedTensorsKVCacheMethod(self)
+        if isinstance(layer, RadixAttention):
+            return CompressedTensorsKVCacheMethod(self)
         if isinstance(layer, FusedMoE):
             return CompressedTensorsMarlinMoEMethod.get_moe_method(self, layer)
         return None
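
The commit message refers to a grouped-GEMM ("groupgemm") path for the high-throughput MoE mode, while the diff above only shows the quantization-method dispatch (enabling CompressedTensorsKVCacheMethod for RadixAttention layers). For background, here is a minimal sketch of the grouped-GEMM idea in plain PyTorch: tokens are sorted by their routed expert so each expert runs one GEMM over a contiguous slice, instead of many small per-token matmuls. The function and parameter names are hypothetical, routing is top-1 only, and this is not the kernel wired up by this commit.

```python
# Illustrative sketch of the grouped-GEMM pattern used for MoE layers.
# Hypothetical names; not the code introduced by this commit.
import torch


def run_grouped_gemm(
    hidden_states: torch.Tensor,   # [num_tokens, hidden_dim]
    expert_weights: torch.Tensor,  # [num_experts, hidden_dim, out_dim]
    expert_ids: torch.Tensor,      # [num_tokens], expert chosen per token (top-1)
) -> torch.Tensor:
    num_experts, _, out_dim = expert_weights.shape
    output = hidden_states.new_zeros(hidden_states.shape[0], out_dim)

    # Sort tokens by expert so each expert's tokens form a contiguous block.
    sorted_ids = torch.argsort(expert_ids)
    sorted_tokens = hidden_states[sorted_ids]
    counts = torch.bincount(expert_ids, minlength=num_experts)

    # One GEMM per expert over its contiguous token slice.
    start = 0
    for e in range(num_experts):
        n = int(counts[e])
        if n == 0:
            continue
        block = sorted_tokens[start:start + n] @ expert_weights[e]
        output[sorted_ids[start:start + n]] = block
        start += n
    return output


if __name__ == "__main__":
    x = torch.randn(16, 32)
    w = torch.randn(4, 32, 64)
    ids = torch.randint(0, 4, (16,))
    print(run_grouped_gemm(x, w, ids).shape)  # torch.Size([16, 64])
```

In a real high-throughput implementation the per-expert loop is replaced by a single fused grouped-GEMM kernel launch, which is where the throughput gain comes from.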