Unverified commit 4a63bc32, authored by Baizhou Zhang, committed by GitHub

[Fix] Add torch compile for torch.clamp back (#4936)

parent a303325f
@@ -39,6 +39,7 @@
 import triton
 import triton.language as tl
 from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
+from sglang.srt.utils import get_compiler_backend

 if TYPE_CHECKING:
     from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
@@ -299,7 +300,7 @@ class ForwardBatch:
         # Init position information
         if ret.forward_mode.is_decode():
             if ret.positions is None:
-                ret.positions = torch.clamp((batch.seq_lens - 1), min=0).to(torch.int64)
+                ret.positions = clamp_position(batch.seq_lens)
         else:
             ret.extend_seq_lens = torch.tensor(
                 batch.extend_seq_lens, dtype=torch.int32
@@ -519,3 +520,8 @@ def compute_position_torch(
     extend_start_loc = torch.zeros_like(extend_seq_lens)
     extend_start_loc[1:] = torch.cumsum(extend_seq_lens[:-1], dim=0)
     return positions.to(torch.int64), extend_start_loc
+
+
+@torch.compile(dynamic=True, backend=get_compiler_backend())
+def clamp_position(seq_lens):
+    return torch.clamp((seq_lens - 1), min=0).to(torch.int64)
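For context, a minimal standalone sketch of the new helper, with the string "inductor" standing in for sglang's get_compiler_backend() (that substitution is an assumption; the real helper selects a backend for the current environment):

import torch

# Sketch of the re-added helper; "inductor" replaces
# get_compiler_backend() here and is an assumption.
@torch.compile(dynamic=True, backend="inductor")
def clamp_position(seq_lens: torch.Tensor) -> torch.Tensor:
    # Decode positions point at the last token: seq_len - 1,
    # floored at 0 for empty sequences, cast to int64 for indexing.
    return torch.clamp(seq_lens - 1, min=0).to(torch.int64)

seq_lens = torch.tensor([0, 1, 5, 128])
print(clamp_position(seq_lens))  # tensor([0, 0, 4, 127])

Passing dynamic=True marks the tensor shapes as dynamic, so torch.compile does not recompile the fused subtract/clamp/cast kernel each time the decode batch size changes.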