Unverified Commit 21ba3a88 authored by Ying Sheng's avatar Ying Sheng Committed by GitHub
Browse files

Remove useless variables in infer_batch.py (#651)

parent 9c5cac24
......@@ -270,6 +270,7 @@ class Batch:
prefix_lens: torch.Tensor = None
position_ids_offsets: torch.Tensor = None
out_cache_loc: torch.Tensor = None
extend_num_tokens: int = None
# For processing logprobs
return_logprob: bool = False
......@@ -280,10 +281,6 @@ class Batch:
image_sizes: List[List[int]] = None
image_offsets: List[int] = None
# Other arguments for control
output_ids: torch.Tensor = None
extend_num_tokens: int = None
# Batched sampling params
temperatures: torch.Tensor = None
top_ps: torch.Tensor = None
......@@ -820,6 +817,7 @@ def init_flashinfer_args(
prefix_lens,
flashinfer_decode_wrapper,
):
"""Init auxiliary variables for FlashInfer attention backend."""
num_qo_heads = model_runner.model_config.num_attention_heads // model_runner.tp_size
num_kv_heads = model_runner.model_config.get_num_kv_heads(model_runner.tp_size)
head_dim = model_runner.model_config.head_dim
......@@ -885,6 +883,7 @@ def init_flashinfer_args(
def init_triton_args(forward_mode, seq_lens, prefix_lens):
"""Init auxiliary variables for triton attention backend."""
batch_size = len(seq_lens)
max_seq_len = int(torch.max(seq_lens))
start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment