Unverified Commit a437aa99 authored by Cheng Wan, committed by GitHub

[hotfix] fix mixtral with tensor-level compressed-tensor quantization (#8721)

parent 0e612dbf
@@ -23,6 +23,7 @@ from sglang.srt.layers.quantization.utils import (
 from sglang.srt.utils import is_cpu, is_cuda, is_hip, is_npu, set_weight_attrs
 
 if TYPE_CHECKING:
+    from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
     from sglang.srt.layers.moe.topk import TopKOutput
     from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import (
         CompressedTensorsConfig,
@@ -189,7 +190,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
         layer.w13_input_scale = None
         layer.w2_input_scale = None
 
-    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+    def process_weights_after_loading(self, layer: FusedMoE) -> None:
         # Fp8 moe kernels require a single activation scale.
         # We take the max of all the scales in case they differ.
         if self.static_input_scales:
@@ -246,7 +247,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
         assert layer.w13_weight_scale is not None
         shard_size = layer.intermediate_size_per_partition
         max_w13_scales = layer.w13_weight_scale.max(dim=1).values
-        for expert_id in range(layer.local_num_experts):
+        for expert_id in range(layer.num_local_experts):
             start = 0
             for shard_id in range(2):
                 dq_weight = per_tensor_dequantize(
...
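For context on the loop the last hunk touches: fp8 MoE kernels expect a single weight scale per expert, so the fused w1/w3 ("w13") shards are dequantized with their original per-shard scales and re-quantized with the per-expert max scale. The sketch below illustrates that step under simplified assumptions; the shapes and the `per_tensor_quantize`/`per_tensor_dequantize` helpers here are hypothetical stand-ins, not sglang's actual implementation.

```python
import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn


def per_tensor_quantize(w: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in: quantize to fp8 with a fixed per-tensor scale.
    return (w / scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)


def per_tensor_dequantize(q: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in mirroring a per-tensor dequantize helper.
    return q.to(torch.float32) * scale


num_local_experts, shard_size, hidden = 4, 8, 16

# Fused w1/w3 ("w13") weights: two shards per expert, each quantized with its own scale.
w13 = torch.randn(num_local_experts, 2 * shard_size, hidden)
w13_scale = torch.stack(
    [
        w13[:, :shard_size].abs().amax(dim=(1, 2)) / FP8_MAX,
        w13[:, shard_size:].abs().amax(dim=(1, 2)) / FP8_MAX,
    ],
    dim=1,
)  # shape [num_local_experts, 2]
w13_q = torch.cat(
    [
        per_tensor_quantize(w13[:, :shard_size], w13_scale[:, 0:1, None]),
        per_tensor_quantize(w13[:, shard_size:], w13_scale[:, 1:2, None]),
    ],
    dim=1,
)

# Take the max scale over the two shards and re-quantize each shard with that
# shared per-expert scale, mirroring the patched loop in the diff above.
max_w13_scales = w13_scale.max(dim=1).values
for expert_id in range(num_local_experts):
    start = 0
    for shard_id in range(2):
        dq_weight = per_tensor_dequantize(
            w13_q[expert_id, start : start + shard_size],
            w13_scale[expert_id, shard_id],
        )
        w13_q[expert_id, start : start + shard_size] = per_tensor_quantize(
            dq_weight, max_w13_scales[expert_id]
        )
        start += shard_size
```

The loop runs once per expert held on the local rank, which is presumably why the hunk also renames the attribute to `layer.num_local_experts`.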