Unverified commit 2bd18e2d authored by Yang Zheng, committed by GitHub

Memory pool: Minor optimize to avoid to (#2901)

parent 83452dbb
@@ -668,7 +668,7 @@ class ScheduleBatch:
                 or len(req.prefix_indices) >= im.num_image_tokens
             )
-        self.encoder_lens = torch.tensor(self.encoder_lens_cpu, dtype=torch.int32).to(
+        self.encoder_lens = torch.tensor(self.encoder_lens_cpu, dtype=torch.int64).to(
             self.device, non_blocking=True
         )
@@ -702,7 +702,7 @@ class ScheduleBatch:
         self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
             self.device, non_blocking=True
         )
-        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32).to(
+        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64).to(
             self.device, non_blocking=True
         )
@@ -778,10 +778,10 @@ class ScheduleBatch:
         self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
             self.device, non_blocking=True
         )
-        self.req_pool_indices = torch.tensor(req_pool_indices, dtype=torch.int32).to(
+        self.req_pool_indices = torch.tensor(req_pool_indices, dtype=torch.int64).to(
             self.device, non_blocking=True
         )
-        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32).to(
+        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64).to(
             self.device, non_blocking=True
         )
         self.input_embeds = (
@@ -1014,9 +1014,9 @@ class ScheduleBatch:
     def prepare_for_idle(self):
         self.forward_mode = ForwardMode.IDLE
         self.input_ids = torch.empty(0, dtype=torch.int32, device=self.device)
-        self.seq_lens = torch.empty(0, dtype=torch.int32, device=self.device)
+        self.seq_lens = torch.empty(0, dtype=torch.int64, device=self.device)
         self.out_cache_loc = torch.empty(0, dtype=torch.int32, device=self.device)
-        self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device)
+        self.req_pool_indices = torch.empty(0, dtype=torch.int64, device=self.device)
         self.seq_lens_sum = 0
         self.extend_num_tokens = 0
         self.sampling_info = SamplingBatchInfo.from_schedule_batch(
@@ -1084,7 +1084,7 @@ class ScheduleBatch:
         self.encoder_lens_cpu = [self.encoder_lens_cpu[i] for i in keep_indices]
         self.reqs = [self.reqs[i] for i in keep_indices]
-        new_indices = torch.tensor(keep_indices, dtype=torch.int32).to(
+        new_indices = torch.tensor(keep_indices, dtype=torch.int64).to(
             self.device, non_blocking=True
         )
         self.req_pool_indices = self.req_pool_indices[new_indices]
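
The pattern across all five hunks is the same: tensors that are later used as indices or lengths (encoder_lens, seq_lens, req_pool_indices, new_indices) are now created as int64 instead of int32, so they can be used for indexing as-is rather than being converted to int64 at the point of use. A minimal sketch of the effect, assuming a simplified stand-in pool (the name and shape below are illustrative, not the project's actual memory-pool API):

```python
import torch

# Hypothetical stand-in for a request-to-token memory pool;
# the real pool in sglang has a different shape and API.
req_to_token = torch.zeros((64, 2048), dtype=torch.int32)

# If the row indices are created as int32, they must be converted to
# int64 (long) before indexing, paying for an extra copy on every lookup.
req_pool_indices_i32 = torch.tensor([0, 3, 7], dtype=torch.int32)
rows = req_to_token[req_pool_indices_i32.long()]

# Creating them as int64 up front lets them be used directly as indices.
req_pool_indices_i64 = torch.tensor([0, 3, 7], dtype=torch.int64)
rows = req_to_token[req_pool_indices_i64]
```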