"docs/basic_usage/native_api.ipynb" did not exist on "be7986e00544a28832841c916c07793173fd512c"
Unverified commit e05e29d1, authored by fzyzcjy, committed by GitHub

Refactor CustomOp to avoid confusing bugs (#5382)

parent a2cb5913
-from typing import Optional
-
-import torch
 from torch import nn
 
 from sglang.srt.utils import is_cuda, is_hip
@@ -14,6 +11,26 @@ class CustomOp(nn.Module):
         super().__init__()
         self._forward_method = self.dispatch_forward()
 
+    def enter_torch_compile(self, num_tokens: int):
+        # NOTE: Temporarily workaround MoE
+        if "FusedMoE" in self.__class__.__name__:
+            if num_tokens == 1:
+                from sglang.srt.layers.moe.fused_moe_native import (
+                    fused_moe_forward_native,
+                )
+
+                # The performance of torch.compile on this layer is not always good when bs > 1,
+                # so we decide to only use torch.compile when bs=1
+                self._forward_method = fused_moe_forward_native
+        else:
+            self._forward_method = self.forward_native
+        self.is_torch_compile = True
+
+    def leave_torch_compile(self):
+        self._forward_method = self.forward_cuda
+        self.is_torch_compile = False
+
+    # Please do not override this method, because `self._forward_method` can change when in torch compile mode
     def forward(self, *args, **kwargs):
         return self._forward_method(*args, **kwargs)
...
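For context: CustomOp routes every call through self._forward_method, which dispatch_forward() selects at construction time and which the new enter_torch_compile / leave_torch_compile methods now swap in place, so the MoE special-casing lives next to the dispatch logic instead of in the graph-capture code. Below is a minimal, self-contained sketch of that dispatch pattern; ToyCustomOp, ToySiluAndMul, and their backend methods are illustrative stand-ins, not sglang's actual classes, and the toy enter_torch_compile omits the FusedMoE branch shown in the diff.

import torch
from torch import nn


class ToyCustomOp(nn.Module):
    """Minimal stand-in for sglang's CustomOp dispatch pattern (illustrative only)."""

    def __init__(self):
        super().__init__()
        self._forward_method = self.dispatch_forward()
        self.is_torch_compile = False

    def dispatch_forward(self):
        # The real CustomOp picks forward_cuda / forward_hip / forward_native
        # based on the detected platform; this toy only checks for CUDA.
        return self.forward_cuda if torch.cuda.is_available() else self.forward_native

    def enter_torch_compile(self, num_tokens: int):
        # Switch to the torch-native implementation so torch.compile can trace it.
        # (The real method additionally special-cases FusedMoE layers, as in the diff above.)
        self._forward_method = self.forward_native
        self.is_torch_compile = True

    def leave_torch_compile(self):
        self._forward_method = self.forward_cuda
        self.is_torch_compile = False

    # Do not override: `self._forward_method` changes when entering torch compile mode.
    def forward(self, *args, **kwargs):
        return self._forward_method(*args, **kwargs)


class ToySiluAndMul(ToyCustomOp):
    """Example op with a native and a 'CUDA' implementation (here identical)."""

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        gate, up = x.chunk(2, dim=-1)
        return torch.nn.functional.silu(gate) * up

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        # A production op would call a fused CUDA kernel here.
        return self.forward_native(x)


op = ToySiluAndMul()
x = torch.randn(4, 8)
op.enter_torch_compile(num_tokens=1)   # forward() now routes to forward_native
y_native = op(x)
op.leave_torch_compile()               # forward() routes back to forward_cuda
assert torch.allclose(y_native, op(x))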
@@ -28,7 +28,6 @@ from sglang.srt.custom_op import CustomOp
 from sglang.srt.distributed import get_tensor_model_parallel_rank
 from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.layers.moe.fused_moe_native import fused_moe_forward_native
 from sglang.srt.layers.torchao_utils import save_gemlite_cache
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import (
@@ -60,18 +59,9 @@ def _to_torch(model: torch.nn.Module, reverse: bool, num_tokens: int):
     for sub in model._modules.values():
         if isinstance(sub, CustomOp):
             if reverse:
-                sub._forward_method = sub.forward_cuda
-                setattr(sub, "is_torch_compile", False)
+                sub.leave_torch_compile()
             else:
-                # NOTE: Temporarily workaround MoE
-                if "FusedMoE" in sub.__class__.__name__:
-                    if num_tokens == 1:
-                        # The performance of torch.compile on this layer is not always good when bs > 1,
-                        # so we decide to only use torch.compile when bs=1
-                        sub._forward_method = fused_moe_forward_native
-                else:
-                    sub._forward_method = sub.forward_native
-                setattr(sub, "is_torch_compile", True)
+                sub.enter_torch_compile(num_tokens=num_tokens)
         if isinstance(sub, torch.nn.Module):
             _to_torch(sub, reverse, num_tokens)
...
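After the refactor, _to_torch simply toggles each CustomOp while recursing through the module tree; the MoE branch that used to be duplicated here is gone. The sketch below shows how such a traversal is typically wrapped around torch.compile or graph capture; the duck-typed check and the torch_native_mode context manager are hypothetical conveniences for illustration, not sglang's actual API (sglang checks isinstance(sub, CustomOp), and its capture wrapper is not part of this diff).

import contextlib

import torch


def _to_torch(model: torch.nn.Module, reverse: bool, num_tokens: int):
    # Walk the module tree; CustomOp-like submodules are toggled, and every
    # submodule is recursed into so nested ops are reached as well.
    for sub in model._modules.values():
        # sglang checks isinstance(sub, CustomOp); this sketch duck-types so it
        # stays standalone (it works with the ToyCustomOp classes sketched above).
        if hasattr(sub, "enter_torch_compile") and hasattr(sub, "leave_torch_compile"):
            if reverse:
                sub.leave_torch_compile()
            else:
                sub.enter_torch_compile(num_tokens=num_tokens)
        if isinstance(sub, torch.nn.Module):
            _to_torch(sub, reverse, num_tokens)


@contextlib.contextmanager
def torch_native_mode(model: torch.nn.Module, num_tokens: int):
    # Hypothetical convenience wrapper (not sglang API): switch every custom op
    # to its torch-native forward for compilation / graph capture, then restore
    # the CUDA forwards on exit.
    _to_torch(model, reverse=False, num_tokens=num_tokens)
    try:
        yield model
    finally:
        _to_torch(model, reverse=True, num_tokens=num_tokens)


# Usage, e.g. with a model whose submodules include ToySiluAndMul from above:
#
#     with torch_native_mode(model, num_tokens=1):
#         compiled = torch.compile(model)
#         compiled(example_input)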