Unverified Commit 2ac4d5e2 authored by Woosuk Kwon's avatar Woosuk Kwon Committed by GitHub
Browse files

Replace DtypeTensor (#1123)

parent 3302f0ae
...@@ -228,15 +228,25 @@ class Worker: ...@@ -228,15 +228,25 @@ class Worker:
input_positions = _pad_to_alignment(input_positions, multiple_of=8) input_positions = _pad_to_alignment(input_positions, multiple_of=8)
# Convert to tensors. # Convert to tensors.
tokens_tensor = torch.cuda.LongTensor(input_tokens) tokens_tensor = torch.tensor(input_tokens,
positions_tensor = torch.cuda.LongTensor(input_positions) dtype=torch.long,
slot_mapping_tensor = torch.cuda.IntTensor(slot_mapping) device="cuda")
context_lens_tensor = torch.cuda.IntTensor(context_lens) positions_tensor = torch.tensor(input_positions,
dtype=torch.long,
device="cuda")
slot_mapping_tensor = torch.tensor(slot_mapping,
dtype=torch.int,
device="cuda")
context_lens_tensor = torch.tensor(context_lens,
dtype=torch.int,
device="cuda")
padded_block_tables = [ padded_block_tables = [
_pad_to_max(block_table, max_num_blocks_per_seq) _pad_to_max(block_table, max_num_blocks_per_seq)
for block_table in generation_block_tables for block_table in generation_block_tables
] ]
block_tables_tensor = torch.cuda.IntTensor(padded_block_tables) block_tables_tensor = torch.tensor(padded_block_tables,
dtype=torch.int,
device="cuda")
seq_data: Dict[int, SequenceData] = {} seq_data: Dict[int, SequenceData] = {}
for seq_group_metadata in seq_group_metadata_list: for seq_group_metadata in seq_group_metadata_list:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment