Unverified commit 5e0a9b09, authored by Ke Bao, committed by GitHub
Browse files

Apply deepseek cuda rope (#5385)


Co-authored-by: Yineng Zhang <me@zhyncs.com>
parent bdde2375
......@@ -645,7 +645,18 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
cache = torch.cat((cos, sin), dim=-1)
return cache
def forward(
def forward_hip(self, *args, **kwargs):
    """Forward every argument unchanged to ``forward_native``.

    Presumably the entry point used on HIP (ROCm) builds -- confirm against
    the dispatching base class. No HIP-specific kernel is used here; the call
    is answered entirely by the native implementation.

    Args:
        *args: Positional arguments passed straight to ``forward_native``.
        **kwargs: Keyword arguments passed straight to ``forward_native``.

    Returns:
        Whatever ``forward_native`` returns for the same arguments.
    """
    native_impl = self.forward_native
    return native_impl(*args, **kwargs)
def forward(self, *args, **kwargs):
if torch._dynamo.is_compiling:
return self.forward_native(*args, **kwargs)
if _is_cuda_available:
return self.forward_cuda(*args, **kwargs)
else:
return self.forward_native(*args, **kwargs)
def forward_native(
self,
positions: torch.Tensor,
query: torch.Tensor,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment