Unverified commit 07ec07ad, authored by Lianmin Zheng, committed by GitHub

Improve torch compile for fused moe (#2327)

parent 83b340e3
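This commit makes the native fused-MoE path friendlier to torch.compile: the einsum reference kernels switch from GELU to SiLU gating, the compiled functions use `torch.compile(dynamic=False)` so each captured batch size gets its own specialized graph, the capture batch size is threaded through `_to_torch`/`patch_model`, and `FusedMoE` layers only use the native einsum forward under compilation when the batch size is 1. For orientation, below is a minimal, self-contained sketch of that SiLU-gated einsum computation; names and shapes are illustrative, not the exact sglang code.

```python
# Minimal sketch of the SiLU-gated, einsum-based MoE forward used as the
# torch.compile-friendly reference path. Names/shapes are illustrative only.
import torch
import torch.nn.functional as F


@torch.compile(dynamic=False)  # specialize per captured batch size instead of tracing dynamic shapes
def moe_forward_sketch(x, w13, w2, topk_weights, topk_ids):
    # x: [t, i] (tokens, hidden); w13: [e, 2*o, i] (gate/up projections fused);
    # w2: [e, i, o]; topk_weights/topk_ids: [t, a] with a experts per token.
    w13_weights = w13[topk_ids]                                   # [t, a, 2*o, i]
    w1_weights, w3_weights = torch.chunk(w13_weights, 2, dim=2)   # 2 x [t, a, o, i]
    w2_weights = w2[topk_ids]                                     # [t, a, i, o]
    x1 = F.silu(torch.einsum("ti,taoi->tao", x, w1_weights))      # SiLU gate (was GELU)
    x3 = torch.einsum("ti,taoi->tao", x, w3_weights)
    expert_outs = torch.einsum("tao,taio->tai", x1 * x3, w2_weights)
    return torch.einsum("tai,ta->ti", expert_outs, topk_weights.to(expert_outs.dtype))


if __name__ == "__main__":
    t, i, o, e, a = 4, 16, 32, 8, 2
    x = torch.randn(t, i)
    w13 = torch.randn(e, 2 * o, i)
    w2 = torch.randn(e, i, o)
    probs = torch.softmax(torch.randn(t, e), dim=-1)
    topk_weights, topk_ids = torch.topk(probs, a, dim=-1)
    print(moe_forward_sketch(x, w13, w2, topk_weights, topk_ids).shape)  # torch.Size([4, 16])
```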
@@ -6,6 +6,7 @@ from torch.nn import functional as F
 from transformers import AutoConfig
 from sglang.srt.layers.fused_moe_triton.fused_moe import fused_moe as fused_moe_triton
+from sglang.srt.model_executor.cuda_graph_runner import set_torch_compile_config

 def get_model_config(model_name: str, tp_size: int):
@@ -64,7 +65,7 @@ def fused_topk_native(
     return topk_weights, topk_ids


-@torch.compile
+@torch.compile(dynamic=False)
 def fused_moe_torch(
     x,
     w1,
@@ -88,7 +89,8 @@ def fused_moe_torch(
     w13_weights = w1[topk_ids]
     w1_weights, w3_weights = torch.chunk(w13_weights, 2, dim=2)
     w2_weights = w2[topk_ids]
-    x1 = F.gelu(torch.einsum("ti,taoi -> tao", x, w1_weights))
+    x1 = torch.einsum("ti,taoi -> tao", x, w1_weights)
+    x1 = F.silu(x1)
     x3 = torch.einsum("ti, taoi -> tao", x, w3_weights)
     expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights)
     return torch.einsum("tai,ta -> ti", expert_outs, topk_weights.to(expert_outs.dtype))
@@ -174,6 +176,7 @@ def benchmark(batch_size, provider, model_config, use_fp8=False):
     print(f"benchmark {provider} with batch_size={batch_size}")
     torch.set_default_device("cuda")
     torch.cuda.manual_seed_all(0)
+    set_torch_compile_config()

     num_tokens = batch_size
     num_experts = model_config["num_experts"]
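`set_torch_compile_config()` (imported above from `sglang.srt.model_executor.cuda_graph_runner`) is now called once before the compiled path is benchmarked. The sketch below shows the kind of knobs such a helper plausibly raises when many batch sizes are compiled with `dynamic=False`; the exact settings are an assumption, not the actual sglang implementation.

```python
# Hedged sketch of a torch.compile config helper; assumed settings only,
# not the real set_torch_compile_config from sglang.
import torch._dynamo.config
import torch._inductor.config


def set_torch_compile_config_sketch():
    # With dynamic=False, every captured batch size produces its own graph,
    # so the dynamo recompile cache must be large enough to hold all of them.
    torch._dynamo.config.cache_size_limit = 1024
    if hasattr(torch._dynamo.config, "accumulated_cache_size_limit"):
        torch._dynamo.config.accumulated_cache_size_limit = 1024
    # Reuse compiled artifacts across runs to cut warm-up time.
    torch._inductor.config.fx_graph_cache = True
```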
@@ -105,20 +105,29 @@ def fused_moe_forward_native(
     num_expert_group: Optional[int] = None,
     custom_routing_function: Optional[Callable] = None,
 ) -> torch.Tensor:
-    assert custom_routing_function is None
-    topk_weights, topk_ids = select_experts_native(
-        hidden_states=x,
-        router_logits=router_logits,
-        use_grouped_topk=use_grouped_topk,
-        top_k=top_k,
-        renormalize=renormalize,
-        topk_group=topk_group,
-        num_expert_group=num_expert_group,
-    )
+    if use_grouped_topk:
+        assert num_expert_group is not None and topk_group is not None
+        topk_weights, topk_ids = grouped_topk(
+            x,
+            router_logits,
+            top_k,
+            renormalize,
+            num_expert_group,
+            topk_group,
+        )
+    elif custom_routing_function is None:
+        topk_weights, topk_ids = fused_topk_native(x, router_logits, top_k, renormalize)
+    else:
+        topk_weights, topk_ids = custom_routing_function(
+            x, router_logits, top_k, renormalize
+        )
+
     w13_weights = layer.w13_weight[topk_ids]
     w1_weights, w3_weights = torch.chunk(w13_weights, 2, dim=2)
     w2_weights = layer.w2_weight[topk_ids]
-    x1 = F.silu(torch.einsum("ti,taoi -> tao", x, w1_weights))
+    x1 = torch.einsum("ti,taoi -> tao", x, w1_weights)
+    x1 = F.silu(x1)
     x3 = torch.einsum("ti, taoi -> tao", x, w3_weights)
     expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights)
     return torch.einsum("tai,ta -> ti", expert_outs, topk_weights.to(expert_outs.dtype))
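The rewritten forward selects its router in three branches: `grouped_topk` when grouped top-k routing is requested, `fused_topk_native` by default, or a caller-supplied `custom_routing_function`. As a rough reference, here is a hedged plain-PyTorch sketch of a softmax-plus-top-k router with the same call shape as `fused_topk_native(x, router_logits, top_k, renormalize)`; the real implementation may differ in details.

```python
# Hedged sketch of a softmax + top-k router with the fused_topk_native call
# shape; illustrative only, not the exact sglang implementation.
import torch


def topk_router_sketch(x, router_logits, top_k, renormalize):
    # x is unused here; it is kept only to mirror the call shape above.
    probs = torch.softmax(router_logits, dim=-1, dtype=torch.float32)  # [t, e]
    topk_weights, topk_ids = torch.topk(probs, top_k, dim=-1)          # [t, a]
    if renormalize:
        # Make the selected experts' weights sum to 1 for each token.
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids
```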
@@ -36,7 +36,7 @@ if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner


-def _to_torch(model: torch.nn.Module, reverse: bool = False):
+def _to_torch(model: torch.nn.Module, reverse: bool, batch_size: int):
     for sub in model._modules.values():
         if isinstance(sub, CustomOp):
             if reverse:
@@ -45,24 +45,30 @@ def _to_torch(model: torch.nn.Module, reverse: bool = False):
             else:
                 # NOTE: Temporarily workaround MoE
                 if "FusedMoE" in sub.__class__.__name__:
-                    sub._forward_method = fused_moe_forward_native
+                    if batch_size == 1:
+                        # The performance of torch.compile on this layer is not always good when bs > 1,
+                        # so we decide to skip it for now.
+                        sub._forward_method = fused_moe_forward_native
                 else:
                     sub._forward_method = sub.forward_native
                 setattr(sub, "is_torch_compile", True)
         if isinstance(sub, torch.nn.Module):
-            _to_torch(sub, reverse)
+            _to_torch(sub, reverse, batch_size)


 @contextmanager
 def patch_model(
-    model: torch.nn.Module, enable_compile: bool, tp_group: "GroupCoordinator"
+    model: torch.nn.Module,
+    enable_compile: bool,
+    batch_size: int,
+    tp_group: "GroupCoordinator",
 ):
     """Patch the model to make it compatible with with torch.compile"""
     backup_ca_comm = None

     try:
         if enable_compile:
-            _to_torch(model)
+            _to_torch(model, reverse=False, batch_size=batch_size)
             monkey_patch_vllm_all_gather()
             backup_ca_comm = tp_group.ca_comm
             # Use custom-allreduce here.
@@ -70,13 +76,15 @@ def patch_model(
             # even with ENABLE_INTRA_NODE_COMM=1.
             # tp_group.ca_comm = None
             yield torch.compile(
-                torch.no_grad()(model.forward), mode="max-autotune-no-cudagraphs"
+                torch.no_grad()(model.forward),
+                mode="max-autotune-no-cudagraphs",
+                dynamic=False,
             )
         else:
             yield model.forward
     finally:
         if enable_compile:
-            _to_torch(model, reverse=True)
+            _to_torch(model, reverse=True, batch_size=batch_size)
             monkey_patch_vllm_all_gather(reverse=True)
             tp_group.ca_comm = backup_ca_comm
@@ -237,6 +245,7 @@ class CudaGraphRunner:
             with patch_model(
                 self.model_runner.model,
                 bs in self.compile_bs,
+                bs,
                 self.model_runner.tp_group,
             ) as forward:
                 (
@@ -622,7 +622,7 @@ class ModelRunner:
         tic = time.time()
         logger.info("Capture cuda graph begin. This can take up to several minutes.")
         self.cuda_graph_runner = CudaGraphRunner(self)
-        logger.info(f"Capture cuda graph end. Time elapsed: {time.time() - tic:.2f}s")
+        logger.info(f"Capture cuda graph end. Time elapsed: {time.time() - tic:.2f} s")

     def apply_torch_tp(self):
         logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")
@@ -188,7 +188,7 @@ class TestSRTEngine(unittest.TestCase):
         )
         bench_args = BenchArgs(num_prompts=10)
         result = throughput_test(server_args=server_args, bench_args=bench_args)
-        self.assertGreater(result["total_throughput"], 3500)
+        self.assertGreater(result["total_throughput"], 3000)


 if __name__ == "__main__":
@@ -14,7 +14,7 @@ from sglang.test.test_utils import (
 )


-class TestTorchCompile(unittest.TestCase):
+class TestTorchCompileMoe(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
         cls.model = DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST
@@ -23,7 +23,7 @@ class TestTorchCompile(unittest.TestCase):
             cls.model,
             cls.base_url,
             timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
-            other_args=["--enable-torch-compile", "--torch-compile-max-bs", "1"],
+            other_args=["--enable-torch-compile", "--torch-compile-max-bs", "8"],
         )

     @classmethod
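The renamed MoE compile test now exercises torch.compile up to batch size 8 rather than only 1. To try a similar configuration outside the test harness, the same flags can be passed to the server launcher (the model path below is a placeholder):

python -m sglang.launch_server --model-path <moe-model> --enable-torch-compile --torch-compile-max-bs 8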