"docs/vscode:/vscode.git/clone" did not exist on "c6d0dff4a39137ff206af76b655f7bcf3cadaf32"
Unverified Commit f352b793 authored by Lianmin Zheng's avatar Lianmin Zheng Committed by GitHub
Browse files

Minor Optimizations in Schedule Batch (#8724)


Co-authored-by: default avatarSuruchi Shah <surshah@linkedin.com>
parent 6642e3a2
...@@ -37,6 +37,7 @@ import logging ...@@ -37,6 +37,7 @@ import logging
import threading import threading
from enum import Enum, auto from enum import Enum, auto
from http import HTTPStatus from http import HTTPStatus
from itertools import chain
from typing import TYPE_CHECKING, Any, List, Optional, Set, Tuple, Union from typing import TYPE_CHECKING, Any, List, Optional, Set, Tuple, Union
import numpy as np import numpy as np
...@@ -1145,9 +1146,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): ...@@ -1145,9 +1146,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
req_pool_indices_tensor = torch.tensor(req_pool_indices, dtype=torch.int64).to( req_pool_indices_tensor = torch.tensor(req_pool_indices, dtype=torch.int64).to(
self.device, non_blocking=True self.device, non_blocking=True
) )
input_ids_tensor = torch.tensor(sum(input_ids, []), dtype=torch.int64).to( input_ids_tensor = torch.tensor(
self.device, non_blocking=True list(chain.from_iterable(input_ids)), dtype=torch.int64
) ).to(self.device, non_blocking=True)
seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int64).to( seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int64).to(
self.device, non_blocking=True self.device, non_blocking=True
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment