Unverified Commit ac964d2e authored by fzyzcjy's avatar fzyzcjy Committed by GitHub
Browse files

Support global scale in addition to per expert scale for cutedsl moe (#10270)

parent fa46e2bd
......@@ -39,7 +39,7 @@ from sglang.srt.layers.quantization.utils import (
requantize_with_max_scale,
)
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.utils import is_cuda, next_power_of_2
from sglang.srt.utils import get_bool_env_var, is_cuda, next_power_of_2
if TYPE_CHECKING:
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
......@@ -74,6 +74,10 @@ except ImportError:
# Initialize logger for the module
logger = logging.getLogger(__name__)
CUTEDSL_MOE_SCALAR_INPUT_SCALE = get_bool_env_var(
"SGLANG_CUTEDSL_MOE_SCALAR_INPUT_SCALE", "true"
)
# Supported activation schemes for the current configuration
ACTIVATION_SCHEMES = ["static"]
......@@ -1190,7 +1194,19 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
w13_input_scale = layer.w13_input_scale.max().to(torch.float32)
w2_input_scale = layer.w2_input_scale.max().to(torch.float32)
elif self.enable_flashinfer_cutedsl_moe:
w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(torch.float32)
# All-expert-one-input-scale is mathematically different from default per-expert-input-scale
# Thus we allow users to switch the flag to do thorough testing
if CUTEDSL_MOE_SCALAR_INPUT_SCALE:
w13_input_scale = (
layer.w13_input_scale.max()
.to(torch.float32)
.repeat(layer.w13_input_scale.shape[0])
)
else:
w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(
torch.float32
)
w2_input_scale = layer.w2_input_scale
def _slice_scale(w):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment