Unverified commit 136c6e04, authored by Xinyuan Tong, committed by GitHub

fix: Handles input_embeds in GenerateReqInput when n>1 (#7830)


Signed-off-by: Xinyuan Tong <justinning0323@outlook.com>
parent 43e20c06
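
For context, here is a minimal sketch of the case this commit fixes, mirroring the new tests further down (the import path is an assumption based on sglang's source layout):

    # Sketch of the previously-broken case: a single input_embeds with n > 1.
    from sglang.srt.managers.io_struct import GenerateReqInput  # assumed path

    req = GenerateReqInput(
        input_embeds=[[0.1, 0.2]],   # one embedding sequence, no text/input_ids
        sampling_params={"n": 2},    # two parallel samples requested
    )
    req.normalize_batch_and_arguments()

    # With this fix, input_embeds is first wrapped into a batch of one and then
    # expanded alongside the other fields, one copy per parallel sample.
    assert len(req.input_embeds) == 2
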
@@ -200,6 +200,8 @@ class GenerateReqInput:
             self.text = [self.text]
         if self.input_ids is not None:
             self.input_ids = [self.input_ids]
+        if self.input_embeds is not None:
+            self.input_embeds = [self.input_embeds]
 
     def _normalize_single_inputs(self):
         """Normalize inputs for a single example."""
@@ -324,7 +326,9 @@ class GenerateReqInput:
                 new_rids = [f"{self.rid}_{i}" for i in range(num)]
                 self.rid = new_rids
             elif isinstance(self.rid, list):
-                if len(self.rid) != num:
+                # Note: the length of rid shall be the same as the batch_size,
+                # as the rid would be expanded for parallel sampling in tokenizer_manager
+                if len(self.rid) != self.batch_size:
                     raise ValueError(
                         "The specified rids length mismatch with the batch_size for batch processing."
                     )
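
In effect, a caller-supplied rid list must have one entry per batch item, not one per generated sample. An illustrative sketch under the same assumed import (field values follow the new tests; everything else is illustrative):

    from sglang.srt.managers.io_struct import GenerateReqInput  # assumed path

    # With batch_size == 2 and n == 3, two rids are expected; the per-sample
    # expansion to six requests happens later in tokenizer_manager.
    req = GenerateReqInput(
        text=["Prompt 1", "Prompt 2"],   # batch_size == 2
        rid=["id1", "id2"],              # len(rid) == batch_size, not batch_size * n
        sampling_params={"n": 3},
    )
    req.normalize_batch_and_arguments()  # passes the corrected length check
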
@@ -400,6 +404,9 @@ class GenerateReqInput:
         return GenerateReqInput(
             text=self.text[i] if self.text is not None else None,
             input_ids=self.input_ids[i] if self.input_ids is not None else None,
+            input_embeds=(
+                self.input_embeds[i] if self.input_embeds is not None else None
+            ),
             image_data=self.image_data[i],
             audio_data=self.audio_data[i],
             sampling_params=self.sampling_params[i],
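
A hedged sketch of what this hunk enables, assuming the constructor call above sits in GenerateReqInput's per-index accessor (e.g. __getitem__); the import path is likewise an assumption:

    from sglang.srt.managers.io_struct import GenerateReqInput  # assumed path

    batch = GenerateReqInput(
        input_embeds=[[[0.1, 0.2]], [[0.3, 0.4]]],
        sampling_params={"n": 1},
    )
    batch.normalize_batch_and_arguments()

    # Extracting one item from the batch now carries its input_embeds along,
    # instead of silently dropping the field.
    single = batch[0]
    assert single.input_embeds == [[0.1, 0.2]]
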
@@ -67,6 +67,7 @@ suites = {
         TestFile("test_hidden_states.py", 55),
         TestFile("test_int8_kernel.py", 8),
         TestFile("test_input_embeddings.py", 38),
+        TestFile("test_io_struct.py", 8),
         TestFile("test_jinja_template_utils.py", 1),
         TestFile("test_metrics.py", 32),
         TestFile("test_mla.py", 167),
@@ -159,6 +159,7 @@ class TestGenerateReqInputNormalization(CustomTestCase):
         """Test that when some batch items have images and others None, parallel expansion works correctly."""
         req = copy.deepcopy(self.base_req)
         req.text = ["Prompt 1", "Prompt 2", "Prompt 3"]
+        req.rid = ["id1", "id2", "id3"]
         req.image_data = [
             ["image1.jpg"],
             None,
@@ -311,6 +312,71 @@ class TestGenerateReqInputNormalization(CustomTestCase):
         self.assertFalse(req.is_single)
         self.assertEqual(req.batch_size, 2)
 
+    def test_input_embeds_with_parallel_sampling(self):
+        """Test input_embeds normalization with parallel sampling (n > 1)."""
+        # Test single input_embeds with parallel sampling
+        req = GenerateReqInput(
+            input_embeds=[[0.1, 0.2]],  # single embedding vector
+            sampling_params={"n": 2},
+        )
+        req.normalize_batch_and_arguments()
+
+        # Should be converted from single to batch and then expanded
+        self.assertFalse(req.is_single)
+        self.assertEqual(len(req.input_embeds), 2)
+        # Both should be the same input_embeds
+        self.assertEqual(req.input_embeds[0], [[0.1, 0.2]])
+        self.assertEqual(req.input_embeds[1], [[0.1, 0.2]])
+
+        # Test batch input_embeds with parallel sampling
+        req = GenerateReqInput(
+            input_embeds=[[[0.1, 0.2]], [[0.3, 0.4]]], sampling_params={"n": 3}
+        )
+        req.normalize_batch_and_arguments()
+
+        # Should be expanded
+        self.assertFalse(req.is_single)
+        self.assertEqual(len(req.input_embeds), 6)
+        # Check that the expansion is correct
+        expected_embeds = [[[0.1, 0.2]], [[0.3, 0.4]]] * 3
+        self.assertEqual(req.input_embeds, expected_embeds)
+
+        # Test with different n values per sample (should raise error)
+        req = GenerateReqInput(
+            input_embeds=[[[0.1, 0.2]], [[0.3, 0.4]]],
+            sampling_params=[{"n": 2}, {"n": 3}],
+        )
+        with self.assertRaises(ValueError):
+            req.normalize_batch_and_arguments()
+
+    def test_input_embeds_single_to_batch_conversion(self):
+        """Test that single input_embeds are properly converted to batch when using parallel sampling."""
+        # Test the specific case that was fixed: single input_embeds with n > 1
+        req = GenerateReqInput(
+            input_embeds=[[0.1, 0.2, 0.3]],  # Single embedding
+            sampling_params={"n": 2},
+        )
+        req.normalize_batch_and_arguments()
+
+        # Should convert single to batch and then expand
+        self.assertFalse(req.is_single)
+        self.assertEqual(len(req.input_embeds), 2)
+        # Both should be the same single embedding
+        self.assertEqual(req.input_embeds[0], [[0.1, 0.2, 0.3]])
+        self.assertEqual(req.input_embeds[1], [[0.1, 0.2, 0.3]])
+
+        # Test with higher n value
+        req = GenerateReqInput(input_embeds=[[0.1, 0.2, 0.3]], sampling_params={"n": 5})
+        req.normalize_batch_and_arguments()
+
+        self.assertFalse(req.is_single)
+        self.assertEqual(len(req.input_embeds), 5)
+        # All should be the same
+        for i in range(5):
+            self.assertEqual(req.input_embeds[i], [[0.1, 0.2, 0.3]])
+
     def test_lora_path_normalization(self):
         """Test normalization of lora_path."""
         # Test single lora_path with batch input