Unverified commit 180ff5ee, authored by JieXin Liang and committed by GitHub

[fix] recover auto-dispatch for rmsnorm and rope (#6745)

parent 37f15475
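For context: `CustomOp` picks a backend-specific forward once via `dispatch_forward()` and routes `forward()` through `self._forward_method`. The subclasses touched below had hard-coded `forward` overrides that bypassed this dispatch; this commit removes them and, where needed, re-dispatches in `__init__` instead. A minimal sketch of the dispatch idea (a simplification with assumed details, not the exact sglang implementation):

```python
# Minimal sketch of the auto-dispatch pattern this commit restores.
# Assumed simplification, not the exact sglang CustomOp.
import torch
import torch.nn as nn


class CustomOpSketch(nn.Module):
    def __init__(self):
        super().__init__()
        # Decide the backend path once, at construction time.
        self._forward_method = self.dispatch_forward()

    def dispatch_forward(self):
        # Hypothetical platform check; sglang uses its own is_cuda()/is_hip() helpers.
        return self.forward_cuda if torch.cuda.is_available() else self.forward_native

    def forward(self, *args, **kwargs):
        # Subclasses must not override this, otherwise dispatch (and the
        # torch.compile enter/leave swap below) is silently bypassed.
        return self._forward_method(*args, **kwargs)

    def forward_native(self, x):
        return x

    def forward_cuda(self, x):
        return x
```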
@@ -11,7 +11,20 @@ class CustomOp(nn.Module):
         super().__init__()
         self._forward_method = self.dispatch_forward()
 
+        # State for torch.compile
+        self._original_forward_method = None
         self.is_torch_compile = False
 
     def enter_torch_compile(self, num_tokens: int):
+        # Skip if the op has already entered compile mode.
+        # NOTE(alcanderian): Some ops (for example, RotaryEmbedding) are reused
+        # across layers, so `enter_torch_compile` may be called many times.
+        # Prevent `self._original_forward_method` from being overwritten on
+        # any call after the first.
+        if self.is_torch_compile:
+            return
+
+        self._original_forward_method = self._forward_method
         # NOTE: Temporary workaround for MoE
         if "FusedMoE" in self.__class__.__name__:
             if num_tokens == 1:
@@ -27,7 +40,12 @@ class CustomOp(nn.Module):
         self.is_torch_compile = True
 
     def leave_torch_compile(self):
-        self._forward_method = self.forward_cuda
+        # Skip if the op has already exited compile mode.
+        if not self.is_torch_compile:
+            return
+
+        self._forward_method = self._original_forward_method
+        self._original_forward_method = None
         self.is_torch_compile = False
 
     # Please do not override this method, because `self._forward_method` can change while in torch compile mode.
......
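The guard matters because a shared op (for example, one rotary embedding reused by every layer, as the NOTE above says) gets `enter_torch_compile` called once per layer. A toy illustration with hypothetical names, showing that without the check a second call would clobber the saved original and `leave_torch_compile` could no longer restore it:

```python
# Toy illustration (hypothetical class, not the sglang CustomOp) of the
# enter/leave symmetry and the idempotency guard.
class Op:
    def __init__(self):
        self._forward_method = self.forward_cuda  # result of normal dispatch
        self._original_forward_method = None
        self.is_torch_compile = False

    def forward_cuda(self):
        return "cuda"

    def forward_native(self):
        return "native"

    def enter_torch_compile(self):
        if self.is_torch_compile:  # repeated call on a reused op is a no-op
            return
        self._original_forward_method = self._forward_method
        self._forward_method = self.forward_native  # compile-mode path
        self.is_torch_compile = True

    def leave_torch_compile(self):
        if not self.is_torch_compile:
            return
        self._forward_method = self._original_forward_method
        self._original_forward_method = None
        self.is_torch_compile = False


op = Op()
op.enter_torch_compile()
op.enter_torch_compile()  # called again via another layer: must not overwrite the saved original
op.leave_torch_compile()
assert op._forward_method() == "cuda"  # the original dispatch is recovered
```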
@@ -49,16 +49,6 @@ class RMSNorm(CustomOp):
         self.weight = nn.Parameter(torch.ones(hidden_size))
         self.variance_epsilon = eps
 
-    def forward(self, *args, **kwargs):
-        if torch.compiler.is_compiling():
-            return self.forward_native(*args, **kwargs)
-        if _is_cuda:
-            return self.forward_cuda(*args, **kwargs)
-        elif _is_hip:
-            return self.forward_hip(*args, **kwargs)
-        else:
-            return self.forward_native(*args, **kwargs)
-
     def forward_cuda(
         self,
         x: torch.Tensor,
@@ -117,13 +107,9 @@ class GemmaRMSNorm(CustomOp):
         self.weight = nn.Parameter(torch.zeros(hidden_size))
         self.variance_epsilon = eps
 
-    def forward(self, *args, **kwargs):
-        if torch.compiler.is_compiling():
-            return self.forward_native(*args, **kwargs)
-        if _is_cuda:
-            return self.forward_cuda(*args, **kwargs)
-        else:
-            return self.forward_native(*args, **kwargs)
+        # Re-dispatch
+        if _is_hip:
+            self._forward_method = self.forward_native
 
     def forward_native(
         self,
......
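For reference, the native path that the HIP re-dispatch falls back to is plain RMS normalization. A simplified standalone sketch of that computation (not the sglang implementation, which has more options):

```python
import torch


def rmsnorm_native_sketch(
    x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
) -> torch.Tensor:
    # Plain RMSNorm: scale by the reciprocal root-mean-square over the hidden dim.
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    # RMSNorm scales by `weight` (ones-initialized); GemmaRMSNorm instead uses a
    # zeros-initialized weight and scales by `1 + weight`.
    return x * weight
```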
@@ -8,9 +8,10 @@ import torch
 import torch.nn as nn
 
 from sglang.srt.custom_op import CustomOp
-from sglang.srt.utils import is_cuda
+from sglang.srt.utils import is_cuda, is_hip
 
 _is_cuda = is_cuda()
+_is_hip = is_hip()
 
 if _is_cuda:
     from sgl_kernel import apply_rope_with_cos_sin_cache_inplace
@@ -609,6 +610,10 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
             head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
         )
 
+        # Re-dispatch
+        if _is_hip:
+            self._forward_method = self.forward_native
+
     def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
         pos_freqs = self.base ** (
             torch.arange(0, self.rotary_dim, 2, dtype=torch.float, device=self.device)
@@ -650,17 +655,6 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
         cache = torch.cat((cos, sin), dim=-1)
         return cache
 
-    def forward_hip(self, *args, **kwargs):
-        return self.forward_native(*args, **kwargs)
-
-    def forward(self, *args, **kwargs):
-        if torch.compiler.is_compiling():
-            return self.forward_native(*args, **kwargs)
-        if _is_cuda:
-            return self.forward_cuda(*args, **kwargs)
-        else:
-            return self.forward_native(*args, **kwargs)
-
     def forward_native(
         self,
         positions: torch.Tensor,
......
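Behaviorally this is the same fallback as before: the removed `forward_hip` simply delegated to `forward_native`, so deciding once in `__init__` (the "Re-dispatch" block above) keeps HIP on the native path while letting `CustomOp` own the `forward()` routing. A condensed sketch of the resulting pattern, with made-up names that are not the sglang classes:

```python
# Condensed illustration of the rope change: re-dispatch once in __init__
# instead of overriding forward()/forward_hip() in the subclass.
class RopeSketch:
    def __init__(self, on_hip: bool):
        # Base-class-style default, then a one-time platform re-dispatch.
        self._forward_method = self.forward_cuda
        if on_hip:
            self._forward_method = self.forward_native

    def forward_cuda(self, *args, **kwargs):
        return "cuda kernel path"

    def forward_native(self, *args, **kwargs):
        return "pure-PyTorch path"

    def forward(self, *args, **kwargs):
        return self._forward_method(*args, **kwargs)


assert RopeSketch(on_hip=True).forward() == "pure-PyTorch path"
assert RopeSketch(on_hip=False).forward() == "cuda kernel path"
```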