Unverified Commit 3118f633 authored by sroy745's avatar sroy745 Committed by GitHub
Browse files

[Bugfix] [Encoder-Decoder] Bugfix for encoder specific metadata construction...

[Bugfix] [Encoder-Decoder] Bugfix for encoder specific metadata construction during decode of encoder-decoder models.  (#8545)
parent 4c34ce89
...@@ -273,7 +273,8 @@ def test_prepare_prompt(batch_size): ...@@ -273,7 +273,8 @@ def test_prepare_prompt(batch_size):
"unsupported for encoder/ " "unsupported for encoder/ "
"decoder models") "decoder models")
@pytest.mark.parametrize("batch_size", BATCH_SIZES) @pytest.mark.parametrize("batch_size", BATCH_SIZES)
def test_prepare_decode(batch_size): @pytest.mark.parametrize("multiple_seqs_per_seq_group", [True, False])
def test_prepare_decode(batch_size, multiple_seqs_per_seq_group):
''' '''
Test the ability of the encoder/decoder model runner subclass to Test the ability of the encoder/decoder model runner subclass to
produce decode-phase model inputs & attention metadata. produce decode-phase model inputs & attention metadata.
...@@ -288,6 +289,7 @@ def test_prepare_decode(batch_size): ...@@ -288,6 +289,7 @@ def test_prepare_decode(batch_size):
Arguments: Arguments:
* batch_size * batch_size
* multiple_seqs_per_seq_group
* backend_name: The attention backend under test * backend_name: The attention backend under test
* enforce_eager: Enforce eager mode if True (i.e. no CUDAGraph) * enforce_eager: Enforce eager mode if True (i.e. no CUDAGraph)
''' '''
...@@ -305,22 +307,29 @@ def test_prepare_decode(batch_size): ...@@ -305,22 +307,29 @@ def test_prepare_decode(batch_size):
seq_lens: List[int] = [] seq_lens: List[int] = []
encoder_seq_lens: List[int] = [] encoder_seq_lens: List[int] = []
seq_group_metadata_list: List[SequenceGroupMetadata] = [] seq_group_metadata_list: List[SequenceGroupMetadata] = []
block_tables = {0: [1]} block_tables = {
0: [1],
1: [3]
} if multiple_seqs_per_seq_group else {
0: [1]
}
cross_block_table = [2] cross_block_table = [2]
for i in range(batch_size): for i in range(batch_size):
# make sure all tokens fit into one block # make sure all tokens fit into one block
seq_len = i % (model_runner.block_size - 1) + 1 seq_len = i % (model_runner.block_size - 1) + 1
seq_lens.append(seq_len)
seq_data = SequenceData( seq_data = SequenceData(
array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(seq_len)))) array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(seq_len))))
encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1 encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1
encoder_seq_lens.append(encoder_seq_len)
encoder_seq_data = SequenceData( encoder_seq_data = SequenceData(
array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(encoder_seq_len)))) array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(encoder_seq_len))))
seq_group_metadata = SequenceGroupMetadata( seq_group_metadata = SequenceGroupMetadata(
request_id=f"test_{i}", request_id=f"test_{i}",
is_prompt=False, is_prompt=False,
seq_data={0: seq_data}, seq_data={
0: seq_data,
1: seq_data
} if multiple_seqs_per_seq_group else {0: seq_data},
sampling_params=SamplingParams(temperature=0), sampling_params=SamplingParams(temperature=0),
block_tables=block_tables, block_tables=block_tables,
encoder_seq_data=encoder_seq_data, encoder_seq_data=encoder_seq_data,
...@@ -328,6 +337,10 @@ def test_prepare_decode(batch_size): ...@@ -328,6 +337,10 @@ def test_prepare_decode(batch_size):
) )
assert seq_group_metadata.token_chunk_size == 1 assert seq_group_metadata.token_chunk_size == 1
seq_group_metadata_list.append(seq_group_metadata) seq_group_metadata_list.append(seq_group_metadata)
seq_lens.extend(
[seq_len for _ in range(len(seq_group_metadata.seq_data))])
encoder_seq_lens.extend(
[encoder_seq_len for _ in range(len(seq_group_metadata.seq_data))])
# Build # Build
# * Decoder model inputs # * Decoder model inputs
...@@ -398,19 +411,24 @@ def test_prepare_decode(batch_size): ...@@ -398,19 +411,24 @@ def test_prepare_decode(batch_size):
# Verify block tables are correct for prompts # Verify block tables are correct for prompts
# - Decoder self-attention # - Decoder self-attention
expected = torch.tensor( flattened_block_tables = [
[block_tables[0] for _ in range(len(seq_group_metadata_list))], block_table for block_table in block_tables.values()
dtype=torch.int32, ]
device=model_runner.device) expected = torch.tensor(flattened_block_tables *
len(seq_group_metadata_list),
dtype=torch.int32,
device=model_runner.device)
assert torch.equal( assert torch.equal(
attn_metadata.block_tables, attn_metadata.block_tables,
expected, expected,
) )
# - Encoder/decoder cross-attention # - Encoder/decoder cross-attention
expected = torch.tensor( expected = torch.tensor([
[cross_block_table for _ in range(len(seq_group_metadata_list))], cross_block_table for seq_group_metadata in seq_group_metadata_list
dtype=torch.int32, for _ in range(len(seq_group_metadata.seq_data))
device=model_runner.device) ],
dtype=torch.int32,
device=model_runner.device)
assert torch.equal( assert torch.equal(
attn_metadata.cross_block_tables, attn_metadata.cross_block_tables,
expected, expected,
...@@ -474,7 +492,8 @@ def test_prepare_decode(batch_size): ...@@ -474,7 +492,8 @@ def test_prepare_decode(batch_size):
@pytest.mark.parametrize("batch_size", list(range(1, 257))) @pytest.mark.parametrize("batch_size", list(range(1, 257)))
def test_prepare_decode_cuda_graph(batch_size): @pytest.mark.parametrize("multiple_seqs_per_seq_group", [True, False])
def test_prepare_decode_cuda_graph(batch_size, multiple_seqs_per_seq_group):
""" """
Tests that for encoder-decoder models with CUDA Graph capture and replay Tests that for encoder-decoder models with CUDA Graph capture and replay
enabled, the tensors used during the decode phase are correctly padded enabled, the tensors used during the decode phase are correctly padded
...@@ -489,32 +508,45 @@ def test_prepare_decode_cuda_graph(batch_size): ...@@ -489,32 +508,45 @@ def test_prepare_decode_cuda_graph(batch_size):
enable_chunked_prefill=False, enable_chunked_prefill=False,
enforce_eager=False, enforce_eager=False,
) )
block_tables = {
0: [1],
1: [3]
} if multiple_seqs_per_seq_group else {
0: [1]
}
seq_lens: List[int] = [] seq_lens: List[int] = []
encoder_seq_lens: List[int] = [] encoder_seq_lens: List[int] = []
seq_group_metadata_list: List[SequenceGroupMetadata] = [] seq_group_metadata_list: List[SequenceGroupMetadata] = []
block_tables = {0: [1]}
cross_block_table = [2] cross_block_table = [2]
expanded_batch_size = 0
for i in range(batch_size): for i in range(batch_size):
# make sure all tokens fit into one block # make sure all tokens fit into one block
seq_len = i % (model_runner.block_size - 1) + 1 seq_len = i % (model_runner.block_size - 1) + 1
seq_lens.append(seq_len)
seq_data = SequenceData( seq_data = SequenceData(
array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(seq_len)))) array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(seq_len))))
encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1 encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1
encoder_seq_lens.append(encoder_seq_len)
encoder_seq_data = SequenceData( encoder_seq_data = SequenceData(
array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(encoder_seq_len)))) array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(encoder_seq_len))))
seq_group_metadata = SequenceGroupMetadata( seq_group_metadata = SequenceGroupMetadata(
request_id=f"test_{i}", request_id=f"test_{i}",
is_prompt=False, is_prompt=False,
seq_data={0: seq_data}, seq_data={
0: seq_data,
1: seq_data
} if multiple_seqs_per_seq_group else {0: seq_data},
sampling_params=SamplingParams(temperature=0), sampling_params=SamplingParams(temperature=0),
block_tables=block_tables, block_tables=block_tables,
encoder_seq_data=encoder_seq_data, encoder_seq_data=encoder_seq_data,
cross_block_table=cross_block_table, cross_block_table=cross_block_table,
) )
assert seq_group_metadata.token_chunk_size == 1 assert seq_group_metadata.token_chunk_size == 1
seq_lens.extend(
[seq_len for _ in range(len(seq_group_metadata.seq_data))])
encoder_seq_lens.extend(
[encoder_seq_len for _ in range(len(seq_group_metadata.seq_data))])
expanded_batch_size = expanded_batch_size + len(
seq_group_metadata.seq_data)
seq_group_metadata_list.append(seq_group_metadata) seq_group_metadata_list.append(seq_group_metadata)
model_input = model_runner.prepare_model_input(seq_group_metadata_list) model_input = model_runner.prepare_model_input(seq_group_metadata_list)
...@@ -530,8 +562,8 @@ def test_prepare_decode_cuda_graph(batch_size): ...@@ -530,8 +562,8 @@ def test_prepare_decode_cuda_graph(batch_size):
# With CUDA Graph capture and replay enabled, the decoder and encoder # With CUDA Graph capture and replay enabled, the decoder and encoder
# input sequences will be padded. Create the expected padded tensors # input sequences will be padded. Create the expected padded tensors
# accordingly. # accordingly.
graph_batch_size = _get_graph_batch_size(batch_size) graph_batch_size = _get_graph_batch_size(expanded_batch_size)
cuda_graph_pad_size = graph_batch_size - batch_size cuda_graph_pad_size = graph_batch_size - expanded_batch_size
padded_seq_lens = seq_lens + list(itertools.repeat(1, cuda_graph_pad_size)) padded_seq_lens = seq_lens + list(itertools.repeat(1, cuda_graph_pad_size))
padded_encoder_seq_lens = encoder_seq_lens + list( padded_encoder_seq_lens = encoder_seq_lens + list(
itertools.repeat(1, cuda_graph_pad_size)) itertools.repeat(1, cuda_graph_pad_size))
...@@ -560,10 +592,13 @@ def test_prepare_decode_cuda_graph(batch_size): ...@@ -560,10 +592,13 @@ def test_prepare_decode_cuda_graph(batch_size):
# Verify block tables are correct for prompts # Verify block tables are correct for prompts
# - Decoder self-attention. Pad the block tables as expected. # - Decoder self-attention. Pad the block tables as expected.
expected = [block_tables[0] for _ in range(batch_size)] flattened_block_tables = [
expected.extend([[] for _ in range(cuda_graph_pad_size)]) block_table for _ in range(len(seq_group_metadata_list))
for block_table in block_tables.values()
]
flattened_block_tables.extend([[] for _ in range(cuda_graph_pad_size)])
expected = make_tensor_with_pad( expected = make_tensor_with_pad(
expected, flattened_block_tables,
max_len=64, max_len=64,
pad=0, pad=0,
dtype=torch.int32, dtype=torch.int32,
...@@ -575,7 +610,10 @@ def test_prepare_decode_cuda_graph(batch_size): ...@@ -575,7 +610,10 @@ def test_prepare_decode_cuda_graph(batch_size):
) )
# - Encoder/decoder cross-attention. Pad the cross-attention block tables # - Encoder/decoder cross-attention. Pad the cross-attention block tables
# as expected. # as expected.
expected = [cross_block_table for _ in range(len(seq_group_metadata_list))] expected = [
cross_block_table for seq_group_metadata in seq_group_metadata_list
for _ in range(len(seq_group_metadata.seq_data))
]
expected.extend([[] for _ in range(cuda_graph_pad_size)]) expected.extend([[] for _ in range(cuda_graph_pad_size)])
expected = make_tensor_with_pad( expected = make_tensor_with_pad(
expected, expected,
......
...@@ -435,18 +435,18 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]): ...@@ -435,18 +435,18 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
encoder_input_tokens_tensor = self._empty_long_tensor() encoder_input_tokens_tensor = self._empty_long_tensor()
encoder_input_positions_tensor = self._empty_long_tensor() encoder_input_positions_tensor = self._empty_long_tensor()
cross_slot_mapping_tensor = self._empty_long_tensor() cross_slot_mapping_tensor = self._empty_long_tensor()
# Extract cross-attention block tables & # Extract cross-attention block tables &
# seq len from each sequence group metadata. # seq len from each sequence group metadata.
# Cross-attention block tables are empty # Cross-attention block tables are empty
# during vLLM memory profiling. # during vLLM memory profiling.
cross_block_tables = [] cross_block_tables = []
for seq_group_metadata in seq_group_metadata_list: for seq_group_metadata in seq_group_metadata_list:
encoder_seq_lens.append( for _ in range(len(seq_group_metadata.seq_data)):
seq_group_metadata.encoder_seq_data.get_len()) encoder_seq_lens.append(
cross_block_table = seq_group_metadata.cross_block_table seq_group_metadata.encoder_seq_data.get_len())
cross_block_tables.append([] if ( cross_block_table = seq_group_metadata.cross_block_table
cross_block_table is None) else cross_block_table) cross_block_tables.append([] if (
cross_block_table is None) else cross_block_table)
if (model_input.attn_metadata is not None if (model_input.attn_metadata is not None
and model_input.attn_metadata.use_cuda_graph): and model_input.attn_metadata.use_cuda_graph):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment