Unverified Commit cec98f10 authored by yhyang201's avatar yhyang201 Committed by GitHub
Browse files

[Fix] Incorrect Memory Allocation on CUDA:0 by Non-Zero CUDA Processes in TP/DP (#5745)

parent 8dc4efd0
......@@ -48,6 +48,9 @@ class DictOutput(object):
def __getitem__(self, item):
return self.__dict__[item]
def __contains__(self, key):
return key in self.__dict__
def __setitem__(self, key, value):
self.__dict__[key] = value
......
......@@ -290,6 +290,9 @@ class DictOutput(object):
def __getitem__(self, item):
return self.__dict__[item]
def __contains__(self, key):
return key in self.__dict__
def __setitem__(self, key, value):
self.__dict__[key] = value
......
......@@ -8,6 +8,7 @@ from typing import List, Optional
import numpy as np
import PIL
import torch
from PIL import Image
from transformers import BaseImageProcessorFast
......@@ -89,6 +90,10 @@ class BaseMultimodalProcessor(ABC):
return_tensors="pt",
**kwargs,
)
if "pixel_values" in result and isinstance(
result["pixel_values"], torch.Tensor
):
result["pixel_values"] = result["pixel_values"].to("cpu")
return result
@abstractmethod
......
......@@ -745,6 +745,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
out_cache_loc: torch.Tensor = None # shape: [b], int64
output_ids: torch.Tensor = None # shape: [b], int64
# For multimodal inputs
multimodal_inputs: Optional[List] = None
# The sum of all sequence lengths
seq_lens_sum: int = None
......@@ -1050,6 +1053,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
# Copy prefix and do some basic check
input_embeds = []
extend_input_logprob_token_ids = []
multimodal_inputs = []
for i, (req, seq_len, pre_len) in enumerate(zip(reqs, seq_lens, prefix_lens)):
req.req_pool_idx = req_pool_indices[i]
......@@ -1065,6 +1069,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
# If req.input_embeds is already a list, append its content directly
input_embeds.extend(req.input_embeds) # Use extend to avoid nesting
multimodal_inputs.append(req.multimodal_inputs)
req.cached_tokens += pre_len - req.already_computed
req.already_computed = seq_len
req.is_retracted = False
......@@ -1147,6 +1153,16 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
if input_embeds
else None
)
for mm_input in multimodal_inputs:
if mm_input is None:
continue
for mm_item in mm_input.mm_items:
pixel_values = getattr(mm_item, "pixel_values", None)
if isinstance(pixel_values, torch.Tensor):
mm_item.pixel_values = pixel_values.to(
self.device, non_blocking=True
)
self.multimodal_inputs = multimodal_inputs
self.seq_lens_sum = sum(seq_lens)
if self.return_logprob:
......@@ -1452,6 +1468,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
self.encoder_lens_cpu = [self.encoder_lens_cpu[i] for i in keep_indices]
self.reqs = [self.reqs[i] for i in keep_indices]
self.multimodal_inputs = [self.multimodal_inputs[i] for i in keep_indices]
self.req_pool_indices = self.req_pool_indices[keep_indices_device]
self.seq_lens = self.seq_lens[keep_indices_device]
self.out_cache_loc = None
......@@ -1500,6 +1517,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
self.top_logprobs_nums = [0] * len(self.reqs) + other.top_logprobs_nums
self.token_ids_logprobs = [None] * len(self.reqs) + other.token_ids_logprobs
self.reqs.extend(other.reqs)
self.multimodal_inputs.extend(other.multimodal_inputs)
self.return_logprob |= other.return_logprob
self.has_stream |= other.has_stream
......@@ -1558,7 +1576,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
extend_seq_lens=extend_seq_lens,
extend_prefix_lens=extend_prefix_lens,
extend_logprob_start_lens=extend_logprob_start_lens,
multimodal_inputs=[r.multimodal_inputs for r in self.reqs],
multimodal_inputs=self.multimodal_inputs,
encoder_cached=self.encoder_cached,
encoder_lens=self.encoder_lens,
encoder_lens_cpu=self.encoder_lens_cpu,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment