Unverified Commit f352b793 authored by Lianmin Zheng, committed by GitHub

Minor Optimizations in Schedule Batch (#8724)


Co-authored-by: Suruchi Shah <surshah@linkedin.com>
parent 6642e3a2
@@ -37,6 +37,7 @@ import logging
 import threading
 from enum import Enum, auto
 from http import HTTPStatus
+from itertools import chain
 from typing import TYPE_CHECKING, Any, List, Optional, Set, Tuple, Union
 import numpy as np
@@ -1145,9 +1146,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         req_pool_indices_tensor = torch.tensor(req_pool_indices, dtype=torch.int64).to(
             self.device, non_blocking=True
         )
-        input_ids_tensor = torch.tensor(sum(input_ids, []), dtype=torch.int64).to(
-            self.device, non_blocking=True
-        )
+        input_ids_tensor = torch.tensor(
+            list(chain.from_iterable(input_ids)), dtype=torch.int64
+        ).to(self.device, non_blocking=True)
         seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int64).to(
             self.device, non_blocking=True
         )
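Side note on the change itself: sum(input_ids, []) rebuilds the accumulator list on every addition, so flattening a batch of per-request token-id lists costs roughly quadratic copying, while itertools.chain.from_iterable flattens in a single linear pass. Below is a minimal, self-contained sketch (not part of the commit; the batch shape of 256 requests with 128 token ids each is made up for illustration) comparing the two approaches:

# Sketch: compare list flattening via sum(..., []) vs. chain.from_iterable.
# The data below is hypothetical and only illustrates the relative cost.
import timeit
from itertools import chain

# Hypothetical batch: 256 requests, each with 128 token ids.
input_ids = [[i] * 128 for i in range(256)]

def flatten_sum():
    # Quadratic: each '+' allocates and copies a new list.
    return sum(input_ids, [])

def flatten_chain():
    # Linear: a single pass over all elements.
    return list(chain.from_iterable(input_ids))

# Both produce the same flattened list; only the cost differs.
assert flatten_sum() == flatten_chain()

print("sum(..., []):        ", timeit.timeit(flatten_sum, number=100))
print("chain.from_iterable: ", timeit.timeit(flatten_chain, number=100))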