Unverified commit bc3e20dc, authored by Arthur, committed by GitHub

[`Llama`] remove prompt and fix prefix finetuning (#25565)

* nit

* update

* make sure use_default_system_prompt is saved

* update checkpointing

* consistency

* use_default_system_prompt for test
parent 30b3c46f
src/transformers/models/llama/modeling_llama.py

@@ -683,7 +683,7 @@ class LlamaModel(LlamaPreTrainedModel):
                 def create_custom_forward(module):
                     def custom_forward(*inputs):
                         # None for past_key_value
-                        return module(*inputs, output_attentions, None)
+                        return module(*inputs, past_key_value, output_attentions)

                     return custom_forward
@@ -692,7 +692,6 @@ class LlamaModel(LlamaPreTrainedModel):
                     hidden_states,
                     attention_mask,
                     position_ids,
-                    None,
                 )
             else:
                 layer_outputs = decoder_layer(
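This change is what fixes prefix finetuning: under gradient checkpointing, the old custom forward passed `output_attentions` in the `past_key_value` slot and hard-coded the prefix to `None`, so trainable prefix key/values were dropped when the layer was recomputed. The toy sketch below is not the transformers source (`ToyLayer` and the tensor shapes are made up for illustration); it only shows the pattern the fix restores, with the prefix closed over and re-supplied on recompute so its gradients survive.

```python
import torch
from torch.utils.checkpoint import checkpoint


class ToyLayer(torch.nn.Module):
    def __init__(self, dim=4):
        super().__init__()
        self.proj = torch.nn.Linear(dim, dim)

    def forward(self, hidden_states, past_key_value=None, output_attentions=False):
        out = self.proj(hidden_states)
        if past_key_value is not None:
            # stand-in for attending over trainable prefix key/values
            out = out + past_key_value.mean()
        return (out,)


layer = ToyLayer()
hidden = torch.randn(1, 3, 4, requires_grad=True)
prefix = torch.randn(1, 2, 4, requires_grad=True)  # trainable "past" states


def create_custom_forward(module):
    def custom_forward(*inputs):
        # As in the fix: forward the real past_key_value instead of None,
        # so output_attentions no longer lands in the past_key_value slot.
        return module(*inputs, past_key_value=prefix, output_attentions=False)

    return custom_forward


(out,) = checkpoint(create_custom_forward(layer), hidden, use_reentrant=False)
out.sum().backward()
print(prefix.grad is not None)  # True: gradients reach the prefix again
```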
src/transformers/models/llama/tokenization_llama.py

@@ -113,6 +113,7 @@ class LlamaTokenizer(PreTrainedTokenizer):
         add_bos_token=True,
         add_eos_token=False,
         clean_up_tokenization_spaces=False,
+        use_default_system_prompt=True,
         spaces_between_special_tokens=False,
         legacy=None,
         **kwargs,
@@ -131,6 +132,7 @@ class LlamaTokenizer(PreTrainedTokenizer):
             add_eos_token=add_eos_token,
             sp_model_kwargs=self.sp_model_kwargs,
             clean_up_tokenization_spaces=clean_up_tokenization_spaces,
+            use_default_system_prompt=use_default_system_prompt,
             spaces_between_special_tokens=spaces_between_special_tokens,
             legacy=legacy,
             **kwargs,
@@ -149,8 +151,9 @@ class LlamaTokenizer(PreTrainedTokenizer):
         self.vocab_file = vocab_file
         self.add_bos_token = add_bos_token
         self.add_eos_token = add_eos_token
+        self.use_default_system_prompt = use_default_system_prompt
         self.sp_model = self.get_spm_processor()
         self.unk_token_length = len(self.sp_model.encode(str(self.unk_token)))

     # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.get_spm_processor
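For context, a small usage sketch (not from the PR): because the flag is forwarded to `super().__init__`, it ends up in the tokenizer's init kwargs and survives a save/reload round trip, which is what the "make sure use_default_system_prompt is saved" commit is about. The checkpoint name matches the test below; it is gated, so substitute any Llama checkpoint you can access.

```python
from transformers import LlamaTokenizer

tok = LlamaTokenizer.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf", use_default_system_prompt=False
)
tok.save_pretrained("./llama-tok")

# The flag was passed through to super().__init__, so it is serialized with
# the tokenizer config and restored on reload.
reloaded = LlamaTokenizer.from_pretrained("./llama-tok")
print(reloaded.use_default_system_prompt)  # False
```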
src/transformers/models/llama/tokenization_llama.py (continued)

@@ -390,16 +393,20 @@ class LlamaTokenizer(PreTrainedTokenizer):
             `List[int]`:
                 Input ids for the conversation.
         """
-        if len(conversation.past_user_inputs) > 0:
-            if not conversation.past_user_inputs[0].startswith(B_SYS) or E_SYS not in conversation.past_user_inputs[0]:
-                conversation.past_user_inputs[0] = (
-                    B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS + conversation.past_user_inputs[0]
-                )
-        elif conversation.new_user_input:
-            if not conversation.new_user_input.startswith(B_SYS) or E_SYS not in conversation.new_user_input:
-                conversation.new_user_input = B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS + conversation.new_user_input
-        else:
-            raise ValueError("Last message must be from user")
+        if self.use_default_system_prompt:
+            if len(conversation.past_user_inputs) > 0:
+                if (
+                    not conversation.past_user_inputs[0].startswith(B_SYS)
+                    or E_SYS not in conversation.past_user_inputs[0]
+                ):
+                    conversation.past_user_inputs[0] = (
+                        B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS + conversation.past_user_inputs[0]
+                    )
+            elif conversation.new_user_input:
+                if not conversation.new_user_input.startswith(B_SYS) or E_SYS not in conversation.new_user_input:
+                    conversation.new_user_input = B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS + conversation.new_user_input
+            else:
+                raise ValueError("Last message must be from user")
         dialogue = list(conversation.iter_texts())
         if not all([is_user for is_user, msg in dialogue[::2]]) or not all(
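To make the new gating easier to follow, here is the same logic distilled into a standalone function. This is a sketch, not library code: the `DEFAULT_SYSTEM_PROMPT` text is a placeholder, while `B_SYS`/`E_SYS` are the Llama 2 delimiters the file actually defines.

```python
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
DEFAULT_SYSTEM_PROMPT = "You are a helpful assistant."  # placeholder text


def maybe_wrap_first_turn(text: str, use_default_system_prompt: bool) -> str:
    """Mirror of the gating above, applied to a single user turn."""
    if not use_default_system_prompt:
        # New behavior: opt out entirely and leave the turn untouched.
        return text
    if text.startswith(B_SYS) and E_SYS in text:
        # The caller already supplied an explicit system block.
        return text
    return B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS + text


print(maybe_wrap_first_turn("Hi!", use_default_system_prompt=False))  # Hi!
print(maybe_wrap_first_turn("Hi!", use_default_system_prompt=True).startswith(B_SYS))  # True
```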
src/transformers/models/llama/tokenization_llama_fast.py

@@ -110,6 +110,7 @@ class LlamaTokenizerFast(PreTrainedTokenizerFast):
         eos_token="</s>",
         add_bos_token=True,
         add_eos_token=False,
+        use_default_system_prompt=True,
         **kwargs,
     ):
         super().__init__(
@@ -119,12 +120,13 @@ class LlamaTokenizerFast(PreTrainedTokenizerFast):
             unk_token=unk_token,
             bos_token=bos_token,
             eos_token=eos_token,
+            use_default_system_prompt=use_default_system_prompt,
             **kwargs,
         )
         self._add_bos_token = add_bos_token
         self._add_eos_token = add_eos_token
         self.update_post_processor()
-
+        self.use_default_system_prompt = use_default_system_prompt
         self.vocab_file = vocab_file
         self.can_save_slow_tokenizer = False if not self.vocab_file else True
@@ -212,16 +214,20 @@ class LlamaTokenizerFast(PreTrainedTokenizerFast):
             `List[int]`:
                 Input ids for the conversation.
         """
-        if len(conversation.past_user_inputs) > 0:
-            if not conversation.past_user_inputs[0].startswith(B_SYS) or E_SYS not in conversation.past_user_inputs[0]:
-                conversation.past_user_inputs[0] = (
-                    B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS + conversation.past_user_inputs[0]
-                )
-        elif conversation.new_user_input:
-            if not conversation.new_user_input.startswith(B_SYS) or E_SYS not in conversation.new_user_input:
-                conversation.new_user_input = B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS + conversation.new_user_input
-        else:
-            raise ValueError("Last message must be from user")
+        if self.use_default_system_prompt:
+            if len(conversation.past_user_inputs) > 0:
+                if (
+                    not conversation.past_user_inputs[0].startswith(B_SYS)
+                    or E_SYS not in conversation.past_user_inputs[0]
+                ):
+                    conversation.past_user_inputs[0] = (
+                        B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS + conversation.past_user_inputs[0]
+                    )
+            elif conversation.new_user_input:
+                if not conversation.new_user_input.startswith(B_SYS) or E_SYS not in conversation.new_user_input:
+                    conversation.new_user_input = B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS + conversation.new_user_input
+            else:
+                raise ValueError("Last message must be from user")
         dialogue = list(conversation.iter_texts())
         if not all([is_user for is_user, msg in dialogue[::2]]) or not all(
tests/pipelines/test_pipelines_conversational.py

@@ -220,7 +220,7 @@ class ConversationalPipelineTests(unittest.TestCase):
     @require_torch
     @slow
     def test_integration_torch_conversation_llama2_input_ids(self):
-        tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
+        tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf", use_default_system_prompt=True)
         conversation = Conversation(
             "What is so great about #1?",
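Roughly what the updated test exercises (a sketch; it requires access to the gated Llama 2 checkpoint, and `_build_conversation_input_ids` is the private hook the conversational pipeline calls internally): with `use_default_system_prompt=True`, the default system block is injected when the conversation is converted to input ids.

```python
from transformers import AutoTokenizer, Conversation

tokenizer = AutoTokenizer.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf", use_default_system_prompt=True
)
conversation = Conversation("What is so great about #1?")
input_ids = tokenizer._build_conversation_input_ids(conversation)

# The decoded prompt should open with the [INST] <<SYS>> header wrapping
# the default system prompt.
print(tokenizer.decode(input_ids)[:60])
```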