Commit e42a922c authored by maxiao1
Browse files

change tbo about cudagraph

parent 394648be
...@@ -75,8 +75,8 @@ class SiluAndMul(CustomOp): ...@@ -75,8 +75,8 @@ class SiluAndMul(CustomOp):
def forward_native(self, x: torch.Tensor) -> torch.Tensor: def forward_native(self, x: torch.Tensor) -> torch.Tensor:
"""PyTorch-native implementation equivalent to forward().""" """PyTorch-native implementation equivalent to forward()."""
if not torch.compiler.is_compiling(): # 非 capture 阶段 if not torch.compiler.is_compiling() and envs.VLLM_ENABLE_TBO:
return self.forward_cuda(x) # 强制走 fused kernel return self.forward_cuda(x)
else: else:
d = x.shape[-1] // 2 d = x.shape[-1] // 2
return F.silu(x[..., :d]) * x[..., d:] return F.silu(x[..., :d]) * x[..., d:]
......
...@@ -166,10 +166,9 @@ class RMSNorm(CustomOp): ...@@ -166,10 +166,9 @@ class RMSNorm(CustomOp):
residual: Optional[torch.Tensor] = None, residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
if not torch.compiler.is_compiling(): # 非 capture 阶段 if not torch.compiler.is_compiling() and envs.VLLM_ENABLE_TBO:
return self.forward_cuda(x, residual) # 强制走 fused kernel return self.forward_cuda(x, residual)
else: else:
# 否则fallback到原始实现
orig_dtype = x.dtype orig_dtype = x.dtype
x = x.to(torch.float32) x = x.to(torch.float32)
if residual is not None: if residual is not None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment