Unverified commit 319cb1e3, authored by Lukas Geiger, committed by GitHub
Browse files

[Core] Batch multi modal input using pinned memory (#19169)


Signed-off-by: Lukas Geiger <lukas.geiger94@gmail.com>
parent 1efef716
...@@ -680,7 +680,8 @@ class MultiModalKwargs(UserDict[str, NestedTensors]): ...@@ -680,7 +680,8 @@ class MultiModalKwargs(UserDict[str, NestedTensors]):
return self._items_by_modality.keys() return self._items_by_modality.keys()
@staticmethod @staticmethod
def _try_stack(nested_tensors: NestedTensors) -> NestedTensors: def _try_stack(nested_tensors: NestedTensors,
pin_memory: bool = False) -> NestedTensors:
""" """
Stack the inner dimensions that have the same shape in Stack the inner dimensions that have the same shape in
a nested list of tensors. a nested list of tensors.
...@@ -697,7 +698,9 @@ class MultiModalKwargs(UserDict[str, NestedTensors]): ...@@ -697,7 +698,9 @@ class MultiModalKwargs(UserDict[str, NestedTensors]):
if isinstance(nested_tensors, (int, float)): if isinstance(nested_tensors, (int, float)):
return torch.tensor(nested_tensors) return torch.tensor(nested_tensors)
stacked = [MultiModalKwargs._try_stack(t) for t in nested_tensors] stacked = [
MultiModalKwargs._try_stack(t, pin_memory) for t in nested_tensors
]
if not is_list_of(stacked, torch.Tensor, check="all"): if not is_list_of(stacked, torch.Tensor, check="all"):
# Only tensors (not lists) can be stacked. # Only tensors (not lists) can be stacked.
return stacked return stacked
...@@ -713,10 +716,16 @@ class MultiModalKwargs(UserDict[str, NestedTensors]): ...@@ -713,10 +716,16 @@ class MultiModalKwargs(UserDict[str, NestedTensors]):
# The tensors have incompatible shapes and can't be stacked. # The tensors have incompatible shapes and can't be stacked.
return tensors_ return tensors_
return torch.stack(tensors_) outputs = torch.empty(len(tensors_),
*tensors_[0].shape,
dtype=tensors_[0].dtype,
device=tensors_[0].device,
pin_memory=pin_memory)
return torch.stack(tensors_, out=outputs)
@staticmethod @staticmethod
def batch(inputs_list: list["MultiModalKwargs"]) -> BatchedTensorInputs: def batch(inputs_list: list["MultiModalKwargs"],
pin_memory: bool = False) -> BatchedTensorInputs:
""" """
Batch multiple inputs together into a dictionary. Batch multiple inputs together into a dictionary.
...@@ -738,7 +747,7 @@ class MultiModalKwargs(UserDict[str, NestedTensors]): ...@@ -738,7 +747,7 @@ class MultiModalKwargs(UserDict[str, NestedTensors]):
item_lists[k].append(v) item_lists[k].append(v)
return { return {
k: MultiModalKwargs._try_stack(item_list) k: MultiModalKwargs._try_stack(item_list, pin_memory)
for k, item_list in item_lists.items() for k, item_list in item_lists.items()
} }
......
...@@ -962,7 +962,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -962,7 +962,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
encoder_outputs = [] encoder_outputs = []
for grouped_mm_inputs in grouped_mm_inputs_list: for grouped_mm_inputs in grouped_mm_inputs_list:
batched_mm_inputs = MultiModalKwargs.batch(grouped_mm_inputs) batched_mm_inputs = MultiModalKwargs.batch(
grouped_mm_inputs, pin_memory=self.pin_memory)
batched_mm_inputs = MultiModalKwargs.as_kwargs( batched_mm_inputs = MultiModalKwargs.as_kwargs(
batched_mm_inputs, batched_mm_inputs,
device=self.device, device=self.device,
...@@ -1989,7 +1990,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1989,7 +1990,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
).multi_modal_data ).multi_modal_data
batched_dummy_mm_inputs = MultiModalKwargs.batch( batched_dummy_mm_inputs = MultiModalKwargs.batch(
[dummy_mm_kwargs] * max_num_mm_items) [dummy_mm_kwargs] * max_num_mm_items,
pin_memory=self.pin_memory)
batched_dummy_mm_inputs = MultiModalKwargs.as_kwargs( batched_dummy_mm_inputs = MultiModalKwargs.as_kwargs(
batched_dummy_mm_inputs, batched_dummy_mm_inputs,
device=self.device, device=self.device,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment