Unverified commit 05ea7b79, authored by Tom Aarsen and committed by GitHub
Browse files

Refactor: Use Llama RoPE implementation for Falcon (#26933)

* Use Llama RoPE implementation for Falcon

+ Add copy functionalities

* Use standard cache format for Falcon

* Simplify apply_rotary_pos_emb, copy from Llama

* Remove unnecessary cache conversion test

We don't need to convert any caches anymore!

* Resolve copy complaint
parent e9a6c72b
......@@ -340,24 +340,6 @@ class FalconModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
def test_cache_conversions(self):
    """Check that the cache survives a round trip through the RW cache format.

    A forward pass with ``use_cache=True`` yields ``past_key_values``. That
    cache is converted with ``_convert_to_rw_cache`` and then converted back
    with ``_convert_cache_to_standard_format``. For the first two tensors of
    every layer the test asserts that the RW tensor is 3-dimensional, the
    tensor returned by the model is 4-dimensional, and the round-tripped
    tensor is element-wise equal to the one returned by the model.
    """
    config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
    token_ids = inputs["input_ids"]

    causal_lm = FalconForCausalLM(config)
    causal_lm.to(torch_device)
    causal_lm.eval()

    outputs = causal_lm(token_ids, use_cache=True)
    reference_cache = outputs.past_key_values

    # The batch size is needed to restore the standard layout from the RW one.
    num_sequences = token_ids.shape[0]
    rw_cache = causal_lm._convert_to_rw_cache(reference_cache)
    round_tripped = causal_lm._convert_cache_to_standard_format(rw_cache, num_sequences)

    for layer_idx, rw_layer in enumerate(rw_cache):
        for kv_idx in (0, 1):
            self.assertTrue(rw_layer[kv_idx].ndim == 3)
            reference_tensor = reference_cache[layer_idx][kv_idx]
            self.assertTrue(reference_tensor.ndim == 4)
            self.assertTrue(
                torch.all(reference_tensor == round_tripped[layer_idx][kv_idx])
            )
def test_falcon_sequence_classification_model_for_multi_label(self):
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.num_labels = 3
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment