Unverified commit 21ba3a88 authored by Ying Sheng, committed by GitHub

Remove useless variables in infer_batch.py (#651)

parent 9c5cac24
@@ -270,6 +270,7 @@ class Batch:
     prefix_lens: torch.Tensor = None
     position_ids_offsets: torch.Tensor = None
     out_cache_loc: torch.Tensor = None
+    extend_num_tokens: int = None

     # For processing logprobs
     return_logprob: bool = False
@@ -280,10 +281,6 @@ class Batch:
     image_sizes: List[List[int]] = None
     image_offsets: List[int] = None

-    # Other arguments for control
-    output_ids: torch.Tensor = None
-    extend_num_tokens: int = None
-
     # Batched sampling params
     temperatures: torch.Tensor = None
     top_ps: torch.Tensor = None
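
Taken together, the two hunks above move `extend_num_tokens` up next to the other cache-layout fields and delete the unused `output_ids`. Below is a minimal sketch of the affected portion of the `Batch` dataclass after this commit, reconstructed from the diff context; fields outside the shown hunks are omitted, and the `= None` defaults are kept as they appear in the diff.

```python
from dataclasses import dataclass
from typing import List

import torch


@dataclass
class Batch:
    # Cache/layout bookkeeping; extend_num_tokens now lives here.
    prefix_lens: torch.Tensor = None
    position_ids_offsets: torch.Tensor = None
    out_cache_loc: torch.Tensor = None
    extend_num_tokens: int = None

    # For processing logprobs
    return_logprob: bool = False

    # Multimodal inputs (context lines from the second hunk).
    image_sizes: List[List[int]] = None
    image_offsets: List[int] = None

    # The unused output_ids field and the duplicate
    # extend_num_tokens entry were removed by this commit.

    # Batched sampling params
    temperatures: torch.Tensor = None
    top_ps: torch.Tensor = None
```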
@@ -820,6 +817,7 @@ def init_flashinfer_args(
     prefix_lens,
     flashinfer_decode_wrapper,
 ):
+    """Init auxiliary variables for FlashInfer attention backend."""
     num_qo_heads = model_runner.model_config.num_attention_heads // model_runner.tp_size
     num_kv_heads = model_runner.model_config.get_num_kv_heads(model_runner.tp_size)
     head_dim = model_runner.model_config.head_dim
@@ -885,6 +883,7 @@ def init_flashinfer_args(
 def init_triton_args(forward_mode, seq_lens, prefix_lens):
+    """Init auxiliary variables for triton attention backend."""
     batch_size = len(seq_lens)
     max_seq_len = int(torch.max(seq_lens))
     start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
...
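
The last hunk is truncated inside `init_triton_args`, so the diff does not show how `start_loc` is filled. As a hypothetical aside, not part of this commit, a common way to build such per-sequence start offsets is an exclusive prefix sum over `seq_lens`, so that `start_loc[i]` is the offset of sequence `i`'s first token in the packed batch buffer:

```python
import torch

# Hypothetical illustration (assumed, not taken from this commit):
# per-sequence start offsets as an exclusive prefix sum of seq_lens.
seq_lens = torch.tensor([3, 5, 2], dtype=torch.int32)
start_loc = torch.zeros_like(seq_lens)
start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0)
print(start_loc)  # tensor([0, 3, 8], dtype=torch.int32)
```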