Unverified Commit 45473d4b authored by Lianmin Zheng's avatar Lianmin Zheng Committed by GitHub
Browse files

Make input_ids a torch.Tensor (#1568)

parent 114bbc86
...@@ -514,9 +514,10 @@ class ScheduleBatch: ...@@ -514,9 +514,10 @@ class ScheduleBatch:
pt += req.extend_input_len pt += req.extend_input_len
# Set fields # Set fields
self.input_ids = sum(input_ids, []) with out_cache_loc.device:
self.req_pool_indices = torch.tensor(req_pool_indices, device="cuda") self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32)
self.seq_lens = torch.tensor(seq_lens, device="cuda") self.req_pool_indices = torch.tensor(req_pool_indices)
self.seq_lens = torch.tensor(seq_lens)
self.extend_num_tokens = extend_num_tokens self.extend_num_tokens = extend_num_tokens
self.out_cache_loc = out_cache_loc self.out_cache_loc = out_cache_loc
...@@ -536,7 +537,7 @@ class ScheduleBatch: ...@@ -536,7 +537,7 @@ class ScheduleBatch:
req.fill_ids = req.origin_input_ids + req.output_ids req.fill_ids = req.origin_input_ids + req.output_ids
req.extend_input_len = 1 req.extend_input_len = 1
input_ids = self.input_ids + running_batch.input_ids input_ids = torch.cat([self.input_ids, running_batch.input_ids])
out_cache_loc = torch.cat([self.out_cache_loc, running_batch.out_cache_loc]) out_cache_loc = torch.cat([self.out_cache_loc, running_batch.out_cache_loc])
extend_num_tokens = self.extend_num_tokens + running_bs extend_num_tokens = self.extend_num_tokens + running_bs
...@@ -722,7 +723,9 @@ class ScheduleBatch: ...@@ -722,7 +723,9 @@ class ScheduleBatch:
for r in self.reqs for r in self.reqs
] ]
self.input_ids = input_ids self.input_ids = torch.tensor(
input_ids, dtype=torch.int32, device=self.seq_lens.device
)
self.seq_lens.add_(1) self.seq_lens.add_(1)
# Alloc mem # Alloc mem
...@@ -824,7 +827,7 @@ class ModelWorkerBatch: ...@@ -824,7 +827,7 @@ class ModelWorkerBatch:
# The forward mode # The forward mode
forward_mode: ForwardMode forward_mode: ForwardMode
# The input ids # The input ids
input_ids: List[int] input_ids: torch.Tensor
# The indices of requests in the req_to_token_pool # The indices of requests in the req_to_token_pool
req_pool_indices: torch.Tensor req_pool_indices: torch.Tensor
# The sequence length # The sequence length
......
...@@ -30,6 +30,7 @@ class ReqToTokenPool: ...@@ -30,6 +30,7 @@ class ReqToTokenPool:
def __init__(self, size: int, max_context_len: int, device: str): def __init__(self, size: int, max_context_len: int, device: str):
self.size = size self.size = size
self.max_context_len = max_context_len self.max_context_len = max_context_len
self.device = device
self.free_slots = list(range(size)) self.free_slots = list(range(size))
self.req_to_token = torch.empty( self.req_to_token = torch.empty(
(size, max_context_len), dtype=torch.int32, device=device (size, max_context_len), dtype=torch.int32, device=device
......
...@@ -123,7 +123,7 @@ class ForwardBatch: ...@@ -123,7 +123,7 @@ class ForwardBatch:
ret = cls( ret = cls(
forward_mode=batch.forward_mode, forward_mode=batch.forward_mode,
batch_size=len(batch.seq_lens), batch_size=len(batch.seq_lens),
input_ids=torch.tensor(batch.input_ids, dtype=torch.int32, device=device), input_ids=batch.input_ids,
req_pool_indices=batch.req_pool_indices, req_pool_indices=batch.req_pool_indices,
seq_lens=batch.seq_lens, seq_lens=batch.seq_lens,
out_cache_loc=batch.out_cache_loc, out_cache_loc=batch.out_cache_loc,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment