"benchmark/git@developer.sourcefind.cn:change/sglang.git" did not exist on "3f4cc0aff0166252bf319c94faa68326ad14dffb"
Unverified commit 918a06e2 authored by Joao Gante, committed by GitHub

Generate: add test to check KV format (#23403)


Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent 9cf4a8b4
@@ -1645,6 +1645,77 @@ class GenerationTesterMixin:
        self.assertTrue(no_failures)
    def test_past_key_values_format(self):
        # Test that the KV cache is formatted correctly. Exceptions need to explicitly override this test. Having a
        # standard KV cache format is important for a consistent API (and for advanced generation methods).
        for model_class in self.all_generative_model_classes:
            config, inputs = self.model_tester.prepare_config_and_inputs_for_common()

            # If it doesn't support cache, pass the test
            if not hasattr(config, "use_cache"):
                return

            model = model_class(config).to(torch_device)
            if "use_cache" not in inputs:
                inputs["use_cache"] = True
            outputs = model(**inputs)

            # If "past_key_values" is not returned, pass the test (e.g. RWKV uses a different cache name and format)
            if "past_key_values" not in outputs:
                return

            num_hidden_layers = (
                getattr(config, "decoder_layers", None)
                or getattr(config, "num_decoder_layers", None)
                or config.num_hidden_layers
            )
            num_attention_heads = getattr(config, "decoder_attention_heads", config.num_attention_heads)
            embed_dim = getattr(config, "d_model", config.hidden_size)
            per_head_embed_dim = embed_dim // num_attention_heads

            past_kv = outputs["past_key_values"]
            self.assertEqual(len(past_kv), num_hidden_layers)

            # Encoder-Decoder checks
            if config.is_encoder_decoder:
                encoder_num_attention_heads = config.encoder_attention_heads
                encoder_per_head_embed_dim = embed_dim // encoder_num_attention_heads
                batch_size, seq_length = inputs["decoder_input_ids"].shape
                for i in range(num_hidden_layers):
                    self.assertEqual(len(past_kv[i]), 4)  # K V for the decoder + K V for the encoder = 4
                    self.assertEqual(
                        past_kv[i][0].shape, (batch_size, num_attention_heads, seq_length, per_head_embed_dim)
                    )
                    self.assertEqual(
                        past_kv[i][1].shape, (batch_size, num_attention_heads, seq_length, per_head_embed_dim)
                    )
                    # The sequence length for the encoder K V depends on the model. Since it is not manipulated in
                    # autoregressive generation, the test stays general and does not check the 3rd dim
                    self.assertEqual(
                        (past_kv[i][2].shape[0], past_kv[i][2].shape[1], past_kv[i][2].shape[3]),
                        (batch_size, encoder_num_attention_heads, encoder_per_head_embed_dim),
                    )
                    self.assertEqual(
                        (past_kv[i][3].shape[0], past_kv[i][3].shape[1], past_kv[i][3].shape[3]),
                        (batch_size, encoder_num_attention_heads, encoder_per_head_embed_dim),
                    )

            # Decoder-only checks
            else:
                # TODO: this line is only needed because of imagegpt, where "pixel_values" = "input_ids". Fix the
                # tests in imagegpt such that `prepare_config_and_inputs_for_common` returns the latter (and the other
                # tests use it)
                key = "input_ids" if "input_ids" in inputs else "pixel_values"
                batch_size, seq_length = inputs[key].shape
                for i in range(num_hidden_layers):
                    self.assertEqual(len(past_kv[i]), 2)  # K V for the decoder = 2
                    self.assertEqual(
                        past_kv[i][0].shape, (batch_size, num_attention_heads, seq_length, per_head_embed_dim)
                    )
                    self.assertEqual(
                        past_kv[i][1].shape, (batch_size, num_attention_heads, seq_length, per_head_embed_dim)
                    )
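For reference, a minimal sketch (not part of the diff) of the decoder-only cache layout the new test enforces: one entry per layer, each entry a (key, value) pair shaped (batch_size, num_attention_heads, seq_length, per_head_embed_dim). The sizes below are arbitrary illustration values.

import torch

batch_size, num_heads, seq_length, head_dim, num_layers = 2, 4, 8, 16, 3
past_key_values = tuple(
    (
        torch.zeros(batch_size, num_heads, seq_length, head_dim),  # key states for one layer
        torch.zeros(batch_size, num_heads, seq_length, head_dim),  # value states for one layer
    )
    for _ in range(num_layers)
)

# The same invariants the test above asserts, written as plain checks
assert len(past_key_values) == num_layers
for key_states, value_states in past_key_values:
    assert key_states.shape == (batch_size, num_heads, seq_length, head_dim)
    assert value_states.shape == (batch_size, num_heads, seq_length, head_dim)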
    def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_sequences=1):
        batch_size, seq_length = input_ids.shape
        num_sequences_in_output = batch_size * num_return_sequences
...

@@ -393,6 +393,10 @@ class BloomModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_bloom_weight_initialization(*config_and_inputs)

    @unittest.skip("Bloom has a non-standard KV cache format.")
    def test_past_key_values_format(self):
        pass
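For context only, and not taken from this diff: at the time, Bloom's cache was understood to fuse the batch and head dimensions and to transpose the key tensor, which is why the generic shape assertions are skipped here. The layout below is an assumption used purely for illustration.

import torch

batch_size, num_heads, seq_length, head_dim = 2, 4, 8, 16
# Assumed Bloom-style layout (fused batch/head dims, transposed key); illustration only
bloom_key = torch.zeros(batch_size * num_heads, head_dim, seq_length)
bloom_value = torch.zeros(batch_size * num_heads, seq_length, head_dim)
standard_shape = (batch_size, num_heads, seq_length, head_dim)
assert bloom_key.shape != standard_shape
assert bloom_value.shape != standard_shape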
    @slow
    def test_model_from_pretrained(self):
        for model_name in BLOOM_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
...

@@ -475,6 +475,10 @@ class GPTBigCodeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
    def test_disk_offload(self):
        pass

    @unittest.skip("GPTBigCode has a non-standard KV cache format.")
    def test_past_key_values_format(self):
        pass
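For context only, and not taken from this diff: GPTBigCode uses multi-query attention (a single key/value head), and its per-layer cache is assumed here to be one fused key/value tensor rather than a (key, value) pair, so both the pair-length check and the 4-D shape checks in the generic test would fail. The layout below is an assumption for illustration.

import torch

batch_size, seq_length, head_dim = 2, 8, 16
# Assumed fused multi-query layout: key and value concatenated on the last dim; illustration only
fused_key_value = torch.zeros(batch_size, seq_length, 2 * head_dim)
key_states, value_states = fused_key_value.split(head_dim, dim=-1)
assert fused_key_value.dim() == 3  # not the standard 4-D (batch, heads, seq, head_dim) layout
assert key_states.shape == value_states.shape == (batch_size, seq_length, head_dim)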
    def test_gpt_bigcode_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_gpt_bigcode_model(*config_and_inputs)
...
@@ -831,8 +831,12 @@ class ReformerLSHAttnModelTest(
            [expected_shape] * len(iter_hidden_states),
        )

    @unittest.skip("Fails because the sequence length is not a multiple of 4")
    def test_problem_types(self):
        pass

    @unittest.skip("Fails because the sequence length is not a multiple of 4")
    def test_past_key_values_format(self):
        pass
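For context only: the skip reason above stems from Reformer's chunked (LSH) attention, which effectively requires sequence lengths to be padded up to a multiple of the chunk length (4 in this test setup). The helper below is hypothetical and illustrates only the rounding; it is not a transformers API.

import math

def pad_to_multiple(seq_length: int, chunk_length: int = 4) -> int:
    # Round the sequence length up to the next multiple of chunk_length
    return math.ceil(seq_length / chunk_length) * chunk_length

assert pad_to_multiple(7) == 8  # 7 tokens would be padded to 8
assert pad_to_multiple(8) == 8  # already a multiple of 4, no padding needed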
...