Unverified Commit 180ff5ee authored by JieXin Liang, committed by GitHub

[fix] recover auto-dispatch for rmsnorm and rope (#6745)

parent 37f15475
@@ -11,7 +11,20 @@ class CustomOp(nn.Module):
         super().__init__()
         self._forward_method = self.dispatch_forward()
+        # States for torch.compile
+        self._original_forward_method = None
+        self.is_torch_compile = False
 
     def enter_torch_compile(self, num_tokens: int):
+        # Skip if the Op has already entered compile mode.
+        # NOTE(alcanderian): Some Ops (for example RotaryEmbedding) will be reused
+        # among layers and `enter_torch_compile` will be called many times.
+        # We should prevent `self._original_forward_method` from being overridden
+        # when it is not the first time `enter_torch_compile` is called.
+        if self.is_torch_compile:
+            return
+
+        self._original_forward_method = self._forward_method
+
         # NOTE: Temporarily workaround MoE
         if "FusedMoE" in self.__class__.__name__:
             if num_tokens == 1:
@@ -27,7 +40,12 @@ class CustomOp(nn.Module):
         self.is_torch_compile = True
 
     def leave_torch_compile(self):
-        self._forward_method = self.forward_cuda
+        # Skip if the Op has already exited compile mode.
+        if not self.is_torch_compile:
+            return
+
+        self._forward_method = self._original_forward_method
+        self._original_forward_method = None
         self.is_torch_compile = False
 
     # Please do not override this method, because `self._forward_method` can change when in torch compile mode
...
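The CustomOp change turns `enter_torch_compile`/`leave_torch_compile` into a snapshot/restore pair: the method chosen by `dispatch_forward()` is saved on the first `enter_torch_compile` call and put back on exit, instead of unconditionally resetting to `forward_cuda`. Below is a minimal, self-contained sketch of that pattern; the class name, the `dispatch_forward` body, and the dummy `forward_*` methods are illustrative stand-ins, not sglang's actual implementation.

```python
import torch
import torch.nn as nn


class TinyCustomOp(nn.Module):
    """Illustrative stand-in for sglang's CustomOp dispatch machinery."""

    def __init__(self):
        super().__init__()
        self._forward_method = self.dispatch_forward()
        # States for torch.compile
        self._original_forward_method = None
        self.is_torch_compile = False

    def dispatch_forward(self):
        # Hypothetical auto-dispatch: prefer a CUDA path when one is available.
        return self.forward_cuda if torch.cuda.is_available() else self.forward_native

    def enter_torch_compile(self):
        # Guard: an op shared across layers receives this call once per layer,
        # so the snapshot must only be taken the first time.
        if self.is_torch_compile:
            return
        self._original_forward_method = self._forward_method
        self._forward_method = self.forward_native  # compile-friendly path
        self.is_torch_compile = True

    def leave_torch_compile(self):
        if not self.is_torch_compile:
            return
        # Restore whatever dispatch_forward() originally chose, instead of
        # hard-coding forward_cuda (which broke non-CUDA backends).
        self._forward_method = self._original_forward_method
        self._original_forward_method = None
        self.is_torch_compile = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self._forward_method(x)

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        return x * 2.0

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        return x * 2.0


op = TinyCustomOp()
before = op._forward_method
op.enter_torch_compile()
op.enter_torch_compile()  # repeat call is a no-op; the snapshot is not clobbered
op.leave_torch_compile()
assert op._forward_method == before
```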
@@ -49,16 +49,6 @@ class RMSNorm(CustomOp):
         self.weight = nn.Parameter(torch.ones(hidden_size))
         self.variance_epsilon = eps
 
-    def forward(self, *args, **kwargs):
-        if torch.compiler.is_compiling():
-            return self.forward_native(*args, **kwargs)
-        if _is_cuda:
-            return self.forward_cuda(*args, **kwargs)
-        elif _is_hip:
-            return self.forward_hip(*args, **kwargs)
-        else:
-            return self.forward_native(*args, **kwargs)
-
     def forward_cuda(
         self,
         x: torch.Tensor,
@@ -117,13 +107,9 @@ class GemmaRMSNorm(CustomOp):
         self.weight = nn.Parameter(torch.zeros(hidden_size))
         self.variance_epsilon = eps
 
-    def forward(self, *args, **kwargs):
-        if torch.compiler.is_compiling():
-            return self.forward_native(*args, **kwargs)
-        if _is_cuda:
-            return self.forward_cuda(*args, **kwargs)
-        else:
-            return self.forward_native(*args, **kwargs)
+        # Re-dispatch
+        if _is_hip:
+            self._forward_method = self.forward_native
 
     def forward_native(
         self,
...
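With the per-class `forward()` overrides removed, RMSNorm and GemmaRMSNorm fall back to the method that CustomOp dispatched at construction, and GemmaRMSNorm simply re-points it at `forward_native` on HIP. For reference, the sketch below shows the math the native paths compute; the signatures are simplified (no residual fusion or in-place variants), and the Gemma `(1 + weight)` scaling is inferred from the `torch.zeros(hidden_size)` initialization visible in the diff rather than shown in it.

```python
import torch


def rmsnorm_native(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # RMSNorm: scale by the reciprocal root-mean-square over the hidden dimension.
    orig_dtype = x.dtype
    x = x.float()
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    return (x * weight.float()).to(orig_dtype)


def gemma_rmsnorm_native(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Gemma stores the weight as an offset from 1.0, which is consistent with
    # the torch.zeros(hidden_size) initialization in the diff above.
    orig_dtype = x.dtype
    x = x.float()
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    return (x * (1.0 + weight.float())).to(orig_dtype)


hidden_size = 8
x = torch.randn(2, hidden_size, dtype=torch.float16)
out = rmsnorm_native(x, torch.ones(hidden_size), eps=1e-6)
```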
@@ -8,9 +8,10 @@ import torch
 import torch.nn as nn
 
 from sglang.srt.custom_op import CustomOp
-from sglang.srt.utils import is_cuda
+from sglang.srt.utils import is_cuda, is_hip
 
 _is_cuda = is_cuda()
+_is_hip = is_hip()
 
 if _is_cuda:
     from sgl_kernel import apply_rope_with_cos_sin_cache_inplace
@@ -609,6 +610,10 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
             head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
         )
 
+        # Re-dispatch
+        if _is_hip:
+            self._forward_method = self.forward_native
+
     def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
         pos_freqs = self.base ** (
             torch.arange(0, self.rotary_dim, 2, dtype=torch.float, device=self.device)
@@ -650,17 +655,6 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
         cache = torch.cat((cos, sin), dim=-1)
         return cache
 
-    def forward_hip(self, *args, **kwargs):
-        return self.forward_native(*args, **kwargs)
-
-    def forward(self, *args, **kwargs):
-        if torch.compiler.is_compiling():
-            return self.forward_native(*args, **kwargs)
-        if _is_cuda:
-            return self.forward_cuda(*args, **kwargs)
-        else:
-            return self.forward_native(*args, **kwargs)
-
     def forward_native(
         self,
         positions: torch.Tensor,
...
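The rotary-embedding change follows the same idea: rather than overriding `forward()` (and keeping an ad-hoc `forward_hip`) to re-check the backend on every call, `DeepseekScalingRotaryEmbedding` re-dispatches once in its constructor when running on ROCm. A rough sketch of that constructor-time re-dispatch follows, with illustrative names and an `is_hip()` stand-in for `sglang.srt.utils.is_hip()`.

```python
import torch
import torch.nn as nn


def is_hip() -> bool:
    # Stand-in for sglang.srt.utils.is_hip(); torch.version.hip is None on
    # non-ROCm builds of PyTorch.
    return torch.version.hip is not None


_is_hip = is_hip()


class ToyRotaryEmbedding(nn.Module):
    """Illustrative subclass that has no HIP kernel of its own."""

    def __init__(self):
        super().__init__()
        # Assume the base class would normally dispatch to a CUDA kernel here.
        self._forward_method = self.forward_cuda

        # Re-dispatch: point the already-dispatched method at the native
        # implementation once, at construction time, instead of overriding
        # forward() and re-checking the backend on every call.
        if _is_hip:
            self._forward_method = self.forward_native

    def forward(self, q: torch.Tensor) -> torch.Tensor:
        return self._forward_method(q)

    def forward_native(self, q: torch.Tensor) -> torch.Tensor:
        return q  # placeholder for the pure-PyTorch rotary math

    def forward_cuda(self, q: torch.Tensor) -> torch.Tensor:
        return q  # placeholder for the fused-kernel path
```

Because the compile-mode snapshot in CustomOp now round-trips whatever `_forward_method` the constructor settled on, this backend choice also survives `enter_torch_compile`/`leave_torch_compile` cycles.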