Unverified Commit abfc4f33 authored by Woosuk Kwon's avatar Woosuk Kwon Committed by GitHub
Browse files

[Misc] Use dataclass for InputMetadata (#3452)


Co-authored-by: default avataryoukaichao <youkaichao@126.com>
parent 6b78837b
...@@ -2,7 +2,6 @@ import contextlib ...@@ -2,7 +2,6 @@ import contextlib
import io import io
import os import os
import re import re
import shutil
import subprocess import subprocess
import warnings import warnings
from pathlib import Path from pathlib import Path
......
from dataclasses import dataclass
from typing import Optional from typing import Optional
import torch import torch
@dataclass
class InputMetadata: class InputMetadata:
"""Metadata for input sequences. Used in PagedAttention. """Metadata for input sequences. Used in PagedAttention.
...@@ -15,40 +17,17 @@ class InputMetadata: ...@@ -15,40 +17,17 @@ class InputMetadata:
kv_cache_dtype: Data type to store kv cache. kv_cache_dtype: Data type to store kv cache.
""" """
def __init__( is_prompt: bool
self, slot_mapping: torch.Tensor
is_prompt: bool, prompt_lens: Optional[torch.Tensor]
slot_mapping: torch.Tensor, max_seq_len: Optional[int]
prompt_lens: Optional[torch.Tensor], start_loc: Optional[torch.Tensor]
max_seq_len: Optional[int], max_context_len: Optional[int]
start_loc: Optional[torch.Tensor], context_lens: Optional[torch.Tensor]
max_context_len: Optional[int], block_tables: Optional[torch.Tensor]
context_lens: Optional[torch.Tensor], use_cuda_graph: bool
block_tables: Optional[torch.Tensor], kv_cache_dtype: str
use_cuda_graph: bool,
kv_cache_dtype: str,
) -> None:
self.is_prompt = is_prompt
self.prompt_lens = prompt_lens
self.max_seq_len = max_seq_len
self.start_loc = start_loc
self.max_context_len = max_context_len
self.slot_mapping = slot_mapping
self.context_lens = context_lens
self.block_tables = block_tables
self.use_cuda_graph = use_cuda_graph
self.kv_cache_dtype = kv_cache_dtype
# Set during the execution of the first attention op. def __post_init__(self):
# FIXME(woosuk): This is a hack. # will not appear in the __repr__ and __init__
self.attn_bias = None self.attn_bias = None
def __repr__(self) -> str:
return ("InputMetadata("
f"is_prompt={self.is_prompt}, "
f"max_context_len={self.max_context_len}, "
f"slot_mapping={self.slot_mapping}, "
f"context_lens={self.context_lens}, "
f"block_tables={self.block_tables}, "
f"use_cuda_graph={self.use_cuda_graph}, "
f"kv_cache_dtype={self.kv_cache_dtype})")
import contextlib import contextlib
import dataclasses
import time import time
from typing import Dict, List, Optional, Tuple, Set, Union from typing import Dict, List, Optional, Tuple, Set, Union
...@@ -521,45 +522,27 @@ class ModelRunner: ...@@ -521,45 +522,27 @@ class ModelRunner:
metadata_dict = { metadata_dict = {
"input_tokens": input_tokens, "input_tokens": input_tokens,
"input_positions": input_positions, "input_positions": input_positions,
"is_prompt": input_metadata.is_prompt,
"slot_mapping": input_metadata.slot_mapping,
"prompt_lens": input_metadata.prompt_lens,
"max_seq_len": input_metadata.max_seq_len,
"start_loc": input_metadata.start_loc,
"max_context_len": input_metadata.max_context_len,
"context_lens": input_metadata.context_lens,
"block_tables": input_metadata.block_tables,
"use_cuda_graph": input_metadata.use_cuda_graph,
"kv_cache_dtype": input_metadata.kv_cache_dtype,
"selected_token_indices": "selected_token_indices":
sampling_metadata.selected_token_indices, sampling_metadata.selected_token_indices,
"lora_requests": lora_requests, "lora_requests": lora_requests,
"lora_mapping": lora_mapping, "lora_mapping": lora_mapping,
} }
metadata_dict.update(dataclasses.asdict(input_metadata))
broadcast_tensor_dict(metadata_dict, src=0) broadcast_tensor_dict(metadata_dict, src=0)
else: else:
metadata_dict = broadcast_tensor_dict(src=0) metadata_dict = broadcast_tensor_dict(src=0)
input_tokens = metadata_dict["input_tokens"] input_tokens = metadata_dict.pop("input_tokens")
input_positions = metadata_dict["input_positions"] input_positions = metadata_dict.pop("input_positions")
lora_mapping = metadata_dict["lora_mapping"] selected_token_indices = metadata_dict.pop(
lora_requests = metadata_dict["lora_requests"] "selected_token_indices")
input_metadata = InputMetadata( lora_mapping = metadata_dict.pop("lora_mapping")
is_prompt=metadata_dict["is_prompt"], lora_requests = metadata_dict.pop("lora_requests")
slot_mapping=metadata_dict["slot_mapping"], input_metadata = InputMetadata(**metadata_dict)
prompt_lens=metadata_dict["prompt_lens"],
max_seq_len=metadata_dict["max_seq_len"],
start_loc=metadata_dict["start_loc"],
max_context_len=metadata_dict["max_context_len"],
context_lens=metadata_dict["context_lens"],
block_tables=metadata_dict["block_tables"],
use_cuda_graph=metadata_dict["use_cuda_graph"],
kv_cache_dtype=metadata_dict["kv_cache_dtype"],
)
sampling_metadata = SamplingMetadata( sampling_metadata = SamplingMetadata(
seq_groups=None, seq_groups=None,
seq_data=None, seq_data=None,
prompt_lens=None, prompt_lens=None,
selected_token_indices=metadata_dict["selected_token_indices"], selected_token_indices=selected_token_indices,
categorized_sample_indices=None, categorized_sample_indices=None,
generators=None, generators=None,
perform_sampling=False, perform_sampling=False,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment