"python/vscode:/vscode.git/clone" did not exist on "8153168c96c76cdc77eabcbe03b167f9f3b4385f"
Unverified Commit ac964d2e authored by fzyzcjy, committed by GitHub

Support global scale in addition to per expert scale for cutedsl moe (#10270)

parent fa46e2bd
@@ -39,7 +39,7 @@ from sglang.srt.layers.quantization.utils import (
     requantize_with_max_scale,
 )
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.utils import is_cuda, next_power_of_2
+from sglang.srt.utils import get_bool_env_var, is_cuda, next_power_of_2
 
 if TYPE_CHECKING:
     from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
@@ -74,6 +74,10 @@ except ImportError:
 # Initialize logger for the module
 logger = logging.getLogger(__name__)
 
+CUTEDSL_MOE_SCALAR_INPUT_SCALE = get_bool_env_var(
+    "SGLANG_CUTEDSL_MOE_SCALAR_INPUT_SCALE", "true"
+)
+
 # Supported activation schemes for the current configuration
 ACTIVATION_SCHEMES = ["static"]
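
Because the flag is assigned at module level, it is read once when the module is imported, and the `"true"` second argument suggests it defaults to enabled. A minimal usage sketch for opting out (the environment variable name comes from the diff; everything else is illustrative):

```python
import os

# Opt back into the original per-expert input-scale path. Since the
# module-level flag above is evaluated at import time, the variable must
# be set in the environment before sglang is imported.
os.environ["SGLANG_CUTEDSL_MOE_SCALAR_INPUT_SCALE"] = "false"

# ... import sglang / launch the server after this point ...
```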
@@ -1190,7 +1194,19 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
             w13_input_scale = layer.w13_input_scale.max().to(torch.float32)
             w2_input_scale = layer.w2_input_scale.max().to(torch.float32)
         elif self.enable_flashinfer_cutedsl_moe:
-            w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(torch.float32)
+            # Using a single input scale shared by all experts is mathematically
+            # different from the default per-expert input scale, so a flag lets
+            # users switch between the two paths for thorough testing.
+            if CUTEDSL_MOE_SCALAR_INPUT_SCALE:
+                w13_input_scale = (
+                    layer.w13_input_scale.max()
+                    .to(torch.float32)
+                    .repeat(layer.w13_input_scale.shape[0])
+                )
+            else:
+                w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(
+                    torch.float32
+                )
             w2_input_scale = layer.w2_input_scale
 
         def _slice_scale(w):
...
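
For intuition, here is a standalone sketch (not sglang code) of the two scale layouts the branch above produces; the `[num_experts, 2]` shape of the dummy tensor is an assumption made for illustration. Both paths yield a tensor of shape `[num_experts]`, but the scalar path repeats the single global maximum while the default path keeps one maximum per expert:

```python
import torch

num_experts = 4
# Hypothetical per-expert input scales, shape [num_experts, 2].
w13_input_scale = torch.tensor(
    [[0.5, 0.7], [0.9, 0.4], [0.6, 0.8], [0.3, 1.1]]
)

# Global scalar path: one max over all experts, broadcast back to
# shape [num_experts] via repeat, mirroring the new branch above.
global_scale = (
    w13_input_scale.max()
    .to(torch.float32)
    .repeat(w13_input_scale.shape[0])
)

# Default path: one max per expert (max over dim=1).
per_expert_scale = w13_input_scale.max(dim=1).values.to(torch.float32)

print(global_scale)      # tensor([1.1000, 1.1000, 1.1000, 1.1000])
print(per_expert_scale)  # tensor([0.7000, 0.9000, 0.8000, 1.1000])
```

Since the global scale uses the maximum over all experts, it never underestimates any expert's scale, but experts with smaller activations lose quantization resolution relative to the per-expert path, which is why the commit keeps both behind a flag for testing.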