Unverified Commit 0759ec49 authored by Nicolas Patry's avatar Nicolas Patry Committed by GitHub
Browse files

Hotfixing qwen2 and starcoder2 (which also get clamping). (#2167)

parent 963b6c6f
@@ -368,7 +368,7 @@ class Qwen2ForCausalLM(torch.nn.Module):
         elif self.max_past is not None:
             # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
             # kernel requires the true values
-            input_lengths = torch.clamp(input_lengths, max=self.max_past_tensor)
+            input_lengths = input_lengths.clamp(max=self.max_past_tensor)
         hidden_states = self.model(
             input_ids,
...
@@ -534,7 +534,7 @@ class FlashStarcoder2ForCausalLM(torch.nn.Module):
         elif self.max_past is not None:
             # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
             # kernel requires the true values
-            input_lengths = torch.clamp(input_lengths, max=self.max_past_tensor)
+            input_lengths = input_lengths.clamp(max=self.max_past_tensor)
         hidden_states = self.model(
             input_ids,
...
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment