Unverified commit 48ed24c5, authored by Lunwen He, committed by GitHub

Remove size check between attn_weights and kv_seq_len for phi3 (#32339)

* Remove size check between attn_weights and kv_seq_len

* add unit tests
parent e234061c
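
Why the check has to go: with a StaticCache, key_states is read from a buffer pre-allocated to max_cache_len, so the last dimension of attn_weights equals the cache capacity rather than kv_seq_len (the number of tokens actually seen so far), and the old size check raises spuriously. A minimal sketch of the mismatch, with illustrative shapes (the batch size, head count, and head_dim below are arbitrary, not Phi-3's real configuration):

    import math

    import torch

    bsz, num_heads, head_dim = 1, 2, 8
    q_len = 1            # decoding one token at a time
    kv_seq_len = 5       # tokens actually processed so far
    max_cache_len = 16   # a StaticCache pre-allocates keys/values to this length

    query_states = torch.randn(bsz, num_heads, q_len, head_dim)
    key_states = torch.randn(bsz, num_heads, max_cache_len, head_dim)  # full cache width

    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(head_dim)
    # Last dim is max_cache_len (16), not kv_seq_len (5), so the removed
    # `attn_weights.size() != (bsz, num_heads, q_len, kv_seq_len)` check would raise here.
    print(attn_weights.shape)  # torch.Size([1, 2, 1, 16])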
@@ -453,12 +453,6 @@ class Phi3Attention(nn.Module):

         attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

-        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
-            raise ValueError(
-                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
-                f" {attn_weights.size()}"
-            )
-
         if attention_mask is not None:
             causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
             attn_weights += causal_mask
...
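
The surviving path needs no such guard because the causal mask is sliced to the keys' true length, key_states.shape[-2], before being added, so the shapes agree by construction for both dynamic and static caches. Continuing the illustrative tensors from the sketch above (the mask construction here is a demonstration assumption, not the model's real mask code):

    # A mask built for the full cache width, with unused slots pushed to -inf.
    attention_mask = torch.zeros(bsz, 1, q_len, max_cache_len)
    attention_mask[..., kv_seq_len:] = torch.finfo(torch.float32).min

    # Slicing to key_states.shape[-2] keeps mask and weights aligned by construction.
    causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
    attn_weights = attn_weights + causal_mask  # broadcasts over heads; shapes always match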
@@ -19,7 +19,7 @@ import unittest

 from parameterized import parameterized

-from transformers import Phi3Config, is_torch_available, set_seed
+from transformers import Phi3Config, StaticCache, is_torch_available, set_seed
 from transformers.testing_utils import (
     require_torch,
     slow,
@@ -43,6 +43,55 @@ if is_torch_available():
         Phi3Model,
     )

+    end_of_text_token = 32000
+
+    class Phi3MiniWithStaticCache(torch.nn.Module):
+        def __init__(self, model: Phi3ForCausalLM, max_batch_size: int, max_seq_len: int):
+            super().__init__()
+            self.model = model
+            self.cache = StaticCache(
+                config=model.config,
+                max_batch_size=max_batch_size,
+                max_cache_len=max_seq_len,
+                device=self.model.device,
+                dtype=self.model.dtype,
+            )
+
+        def forward(
+            self,
+            input_ids: torch.LongTensor = None,
+        ) -> torch.FloatTensor:
+            return self.model.forward(
+                input_ids=input_ids,
+                use_cache=True,
+                return_dict=True,
+                past_key_values=self.cache,
+            ).logits
+
+        @staticmethod
+        def generate(model: Phi3ForCausalLM, prompt_tokens: torch.LongTensor, max_seq_len: int) -> list[int]:
+            model = Phi3MiniWithStaticCache(model, 1, max_seq_len + prompt_tokens.shape[-1])
+
+            response_tokens = []
+
+            for input_pos in range(prompt_tokens.shape[-1]):
+                result = model.forward(
+                    input_ids=prompt_tokens[:, input_pos : input_pos + 1],
+                )
+                response_tokens.append(prompt_tokens[0][input_pos].item())
+
+            current_token = torch.argmax(result[:, -1, :], dim=-1).item()
+            response_tokens.append(current_token)
+
+            while current_token != end_of_text_token and len(response_tokens) < max_seq_len:
+                result = model.forward(
+                    input_ids=torch.tensor([[current_token]], dtype=torch.long),
+                )
+                current_token = torch.argmax(result[:, -1, :], dim=-1).item()
+                response_tokens.append(current_token)
+
+            return response_tokens
+

 class Phi3ModelTester:
     def __init__(
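
The wrapper above feeds the model one token at a time against a pre-allocated StaticCache; the new integration tests below exercise it end to end. Outside the test suite, roughly the same behavior is available through generate()'s built-in static-cache option; a hedged sketch (the prompt and generation arguments are illustrative):

    import torch

    from transformers import AutoTokenizer, Phi3ForCausalLM

    model = Phi3ForCausalLM.from_pretrained("microsoft/phi-3-mini-4k-instruct")
    tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-3-mini-4k-instruct")

    inputs = tokenizer("Ways to eat bananas and dragonfruit:", return_tensors="pt")
    with torch.no_grad():
        # generate() allocates the StaticCache itself when asked for the
        # "static" cache implementation, hitting the same attention path.
        outputs = model.generate(**inputs, max_new_tokens=32, cache_implementation="static")
    print(tokenizer.batch_decode(outputs))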
@@ -429,7 +478,30 @@ class Phi3IntegrationTest(unittest.TestCase):
         output_text = tokenizer.batch_decode(outputs)
         EXPECTED_OUTPUT = [
-            "<s><|system|> You are a helpful digital assistant. Please provide safe, ethical and accurate information to the user.<|end|><|user|> Can you provide ways to eat combinations of bananas and dragonfruits?<|end|><|assistant|> Absolutely! Bananas and dragonfruits are both delicious fruits that can be combined in various ways to create tasty and nutrit"
+            "<|system|> You are a helpful digital assistant. Please provide safe, ethical and accurate information to the user.<|end|><|user|> Can you provide ways to eat combinations of bananas and dragonfruits?<|end|><|assistant|> Certainly! Bananas and dragonfruits can be combined in various delicious ways. Here are some ideas for incorporating these fruits into your"
         ]
         self.assertListEqual(output_text, EXPECTED_OUTPUT)

+    def test_phi3_mini_4k_instruct_with_static_cache(self):
+        model = Phi3ForCausalLM.from_pretrained("microsoft/phi-3-mini-4k-instruct")
+        tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-3-mini-4k-instruct")
+
+        messages = [
+            {
+                "role": "system",
+                "content": "You are a helpful digital assistant. Please provide safe, ethical and accurate information to the user.",
+            },
+            {"role": "user", "content": "Can you provide ways to eat combinations of bananas and dragonfruits?"},
+        ]
+        inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt")
+
+        response_tokens = Phi3MiniWithStaticCache.generate(model, inputs, 64)
+
+        output_text = tokenizer.batch_decode(torch.tensor([response_tokens], dtype=torch.long, device=torch_device))
+
+        EXPECTED_OUTPUT = [
+            "<|system|> You are a helpful digital assistant. Please provide safe, ethical and accurate information to the user.<|end|><|user|> Can you provide ways to eat combinations of bananas and dragonfruits?<|end|><|assistant|> Certainly! Bananas and dragonfruits can be combined in various delicious ways. Here are some"
+        ]
+
+        self.assertListEqual(output_text, EXPECTED_OUTPUT)
@@ -467,7 +539,30 @@ class Phi3IntegrationTest(unittest.TestCase):
         output_text = tokenizer.batch_decode(outputs)
         EXPECTED_OUTPUT = [
-            "<s><|system|> You are a helpful digital assistant. Please provide safe, ethical and accurate information to the user.<|end|><|user|> Can you provide ways to eat combinations of bananas and dragonfruits?<|end|><|assistant|> Certainly! Bananas and dragonfruits can be combined in various delicious and healthy ways. Here are some ideas:\n\n1."
+            "<|system|> You are a helpful digital assistant. Please provide safe, ethical and accurate information to the user.<|end|><|user|> Can you provide ways to eat combinations of bananas and dragonfruits?<|end|><|assistant|> Certainly! Bananas and dragonfruits can be combined in various delicious and nutritious ways. Here are some creative and healthy"
         ]
         self.assertListEqual(output_text, EXPECTED_OUTPUT)

+    def test_phi3_mini_128k_instruct_with_static_cache(self):
+        model = Phi3ForCausalLM.from_pretrained("microsoft/phi-3-mini-128k-instruct")
+        tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-3-mini-128k-instruct")
+
+        messages = [
+            {
+                "role": "system",
+                "content": "You are a helpful digital assistant. Please provide safe, ethical and accurate information to the user.",
+            },
+            {"role": "user", "content": "Can you provide ways to eat combinations of bananas and dragonfruits?"},
+        ]
+        inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt")
+
+        response_tokens = Phi3MiniWithStaticCache.generate(model, inputs, 64)
+
+        output_text = tokenizer.batch_decode(torch.tensor([response_tokens], dtype=torch.long, device=torch_device))
+
+        EXPECTED_OUTPUT = [
+            "<|system|> You are a helpful digital assistant. Please provide safe, ethical and accurate information to the user.<|end|><|user|> Can you provide ways to eat combinations of bananas and dragonfruits?<|end|><|assistant|> Certainly! Bananas and dragonfruits can be combined in various delicious and nutritious ways"
+        ]
+
+        self.assertListEqual(output_text, EXPECTED_OUTPUT)