Unverified Commit 392f2863 authored by Lianmin Zheng's avatar Lianmin Zheng Committed by GitHub
Browse files

Add dtype for more operations (#1705)

parent 6d0fa73e
...@@ -537,8 +537,8 @@ class ScheduleBatch: ...@@ -537,8 +537,8 @@ class ScheduleBatch:
# Set fields # Set fields
with out_cache_loc.device: with out_cache_loc.device:
self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32) self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32)
self.req_pool_indices = torch.tensor(req_pool_indices) self.req_pool_indices = torch.tensor(req_pool_indices, dtype=torch.int32)
self.seq_lens = torch.tensor(seq_lens) self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32)
self.extend_num_tokens = extend_num_tokens self.extend_num_tokens = extend_num_tokens
self.out_cache_loc = out_cache_loc self.out_cache_loc = out_cache_loc
......
...@@ -145,8 +145,9 @@ class ForwardBatch: ...@@ -145,8 +145,9 @@ class ForwardBatch:
], ],
axis=0, axis=0,
), ),
dtype=torch.int64,
device=device, device=device,
).to(torch.int64) )
ret.image_inputs = batch.image_inputs ret.image_inputs = batch.image_inputs
ret.extend_seq_lens = torch.tensor(batch.extend_seq_lens, device=device) ret.extend_seq_lens = torch.tensor(batch.extend_seq_lens, device=device)
......
...@@ -57,7 +57,7 @@ class SamplingBatchInfo: ...@@ -57,7 +57,7 @@ class SamplingBatchInfo:
[r.sampling_params.top_p for r in reqs], dtype=torch.float [r.sampling_params.top_p for r in reqs], dtype=torch.float
) )
top_ks = torch.tensor( top_ks = torch.tensor(
[r.sampling_params.top_k for r in reqs], dtype=torch.int [r.sampling_params.top_k for r in reqs], dtype=torch.int32
) )
min_ps = torch.tensor( min_ps = torch.tensor(
[r.sampling_params.min_p for r in reqs], dtype=torch.float [r.sampling_params.min_p for r in reqs], dtype=torch.float
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment