Unverified commit cec98f10, authored by yhyang201, committed by GitHub

[Fix] Incorrect Memory Allocation on CUDA:0 by Non-Zero CUDA Processes in TP/DP (#5745)

parent 8dc4efd0
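In tensor-parallel (TP) and data-parallel (DP) serving, a process whose rank-local GPU is not cuda:0 could still allocate memory on cuda:0, because multimodal preprocessing can hand back pixel_values tensors that already live on the default CUDA device. The changes below address this in three steps: DictOutput gains a __contains__ method so processor outputs support `in` checks, BaseMultimodalProcessor normalizes pixel_values back to CPU right after preprocessing, and ScheduleBatch caches each batch's multimodal_inputs and moves pixel_values onto the rank-local self.device only when the batch is initialized.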
@@ -48,6 +48,9 @@ class DictOutput(object):
     def __getitem__(self, item):
         return self.__dict__[item]
 
+    def __contains__(self, key):
+        return key in self.__dict__
+
     def __setitem__(self, key, value):
         self.__dict__[key] = value
@@ -290,6 +290,9 @@ class DictOutput(object):
     def __getitem__(self, item):
         return self.__dict__[item]
 
+    def __contains__(self, key):
+        return key in self.__dict__
+
     def __setitem__(self, key, value):
         self.__dict__[key] = value
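Both DictOutput definitions get the same two-line method because the processor-side guard below relies on `"pixel_values" in result`. Without __contains__ (and with no __iter__), Python's `in` operator falls back to probing __getitem__ with integer indices 0, 1, 2, ..., and the dict-backed lookup raises KeyError(0). A minimal standalone sketch of the behavior (the __init__ here is hypothetical, just enough to make the example run):

```python
class DictOutput:
    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)

    def __getitem__(self, item):
        return self.__dict__[item]

    def __contains__(self, key):
        return key in self.__dict__

    def __setitem__(self, key, value):
        self.__dict__[key] = value


out = DictOutput(pixel_values=[1, 2, 3])
assert "pixel_values" in out   # works only because __contains__ is defined;
assert "input_ids" not in out  # without it, `in` would probe __getitem__(0)
                               # and raise KeyError(0)
```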
@@ -8,6 +8,7 @@ from typing import List, Optional
 import numpy as np
 import PIL
+import torch
 from PIL import Image
 from transformers import BaseImageProcessorFast

@@ -89,6 +90,10 @@ class BaseMultimodalProcessor(ABC):
             return_tensors="pt",
             **kwargs,
         )
+        if "pixel_values" in result and isinstance(
+            result["pixel_values"], torch.Tensor
+        ):
+            result["pixel_values"] = result["pixel_values"].to("cpu")
         return result
 
     @abstractmethod
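This guard is the first half of the fix: whichever device the underlying HuggingFace processor materialized pixel_values on, the result is normalized to CPU before it leaves preprocessing, so no GPU is touched until the scheduler picks the placement. A minimal sketch of the same idea for a plain dict result (normalize_pixel_values is a hypothetical helper, not part of the patch):

```python
import torch


def normalize_pixel_values(result: dict) -> dict:
    """Pull pixel_values back to CPU if preprocessing put it on a GPU.

    In a rank>0 worker that never called torch.cuda.set_device, an implicit
    CUDA allocation defaults to cuda:0; keeping the tensor on CPU here leaves
    device placement to the batch-construction step on the rank's own device.
    """
    pv = result.get("pixel_values")
    if isinstance(pv, torch.Tensor) and pv.is_cuda:
        result["pixel_values"] = pv.cpu()
    return result
```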
@@ -745,6 +745,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     out_cache_loc: torch.Tensor = None  # shape: [b], int64
     output_ids: torch.Tensor = None  # shape: [b], int64
 
+    # For multimodal inputs
+    multimodal_inputs: Optional[List] = None
+
     # The sum of all sequence lengths
     seq_lens_sum: int = None
@@ -1050,6 +1053,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         # Copy prefix and do some basic check
         input_embeds = []
         extend_input_logprob_token_ids = []
+        multimodal_inputs = []
 
         for i, (req, seq_len, pre_len) in enumerate(zip(reqs, seq_lens, prefix_lens)):
             req.req_pool_idx = req_pool_indices[i]

@@ -1065,6 +1069,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
                 # If req.input_embeds is already a list, append its content directly
                 input_embeds.extend(req.input_embeds)  # Use extend to avoid nesting
 
+            multimodal_inputs.append(req.multimodal_inputs)
+
             req.cached_tokens += pre_len - req.already_computed
             req.already_computed = seq_len
             req.is_retracted = False
@@ -1147,6 +1153,16 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             if input_embeds
             else None
         )
+        for mm_input in multimodal_inputs:
+            if mm_input is None:
+                continue
+            for mm_item in mm_input.mm_items:
+                pixel_values = getattr(mm_item, "pixel_values", None)
+                if isinstance(pixel_values, torch.Tensor):
+                    mm_item.pixel_values = pixel_values.to(
+                        self.device, non_blocking=True
+                    )
+        self.multimodal_inputs = multimodal_inputs
 
         self.seq_lens_sum = sum(seq_lens)
 
         if self.return_logprob:
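This loop is the second half of the fix: pixel_values stay on CPU until the batch is built, then move to self.device, the rank-local GPU, instead of an implicit cuda:0. One caveat worth knowing: non_blocking=True only overlaps the host-to-device copy with compute when the source CPU tensor is in pinned memory; for ordinary pageable memory the copy is effectively synchronous. A standalone sketch (move_to_local_device is an assumed name mirroring the loop above):

```python
import torch


def move_to_local_device(multimodal_inputs, device) -> None:
    """Move each request's pixel_values onto this rank's own GPU."""
    for mm_input in multimodal_inputs:
        if mm_input is None:
            continue
        for mm_item in mm_input.mm_items:
            pv = getattr(mm_item, "pixel_values", None)
            if isinstance(pv, torch.Tensor) and not pv.is_cuda:
                # pin_memory() is optional and not in the patch; it is what
                # lets non_blocking=True actually overlap the H2D transfer.
                mm_item.pixel_values = pv.pin_memory().to(device, non_blocking=True)
```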
@@ -1452,6 +1468,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         self.encoder_lens_cpu = [self.encoder_lens_cpu[i] for i in keep_indices]
         self.reqs = [self.reqs[i] for i in keep_indices]
+        self.multimodal_inputs = [self.multimodal_inputs[i] for i in keep_indices]
         self.req_pool_indices = self.req_pool_indices[keep_indices_device]
         self.seq_lens = self.seq_lens[keep_indices_device]
         self.out_cache_loc = None
@@ -1500,6 +1517,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         self.top_logprobs_nums = [0] * len(self.reqs) + other.top_logprobs_nums
         self.token_ids_logprobs = [None] * len(self.reqs) + other.token_ids_logprobs
         self.reqs.extend(other.reqs)
+        self.multimodal_inputs.extend(other.multimodal_inputs)
 
         self.return_logprob |= other.return_logprob
         self.has_stream |= other.has_stream
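Because multimodal_inputs is now a per-batch list parallel to self.reqs, every operation that reorders requests must reorder it identically: filter_batch selects by keep_indices and merge_batch concatenates, exactly as the one-line additions above do. A hypothetical invariant check makes the contract explicit:

```python
def check_mm_alignment(batch) -> None:
    # Hypothetical debug helper, not part of the patch: if the two parallel
    # lists ever diverge in length, a request would silently pair with
    # another request's multimodal input.
    assert len(batch.multimodal_inputs) == len(batch.reqs)
```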
@@ -1558,7 +1576,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             extend_seq_lens=extend_seq_lens,
             extend_prefix_lens=extend_prefix_lens,
             extend_logprob_start_lens=extend_logprob_start_lens,
-            multimodal_inputs=[r.multimodal_inputs for r in self.reqs],
+            multimodal_inputs=self.multimodal_inputs,
             encoder_cached=self.encoder_cached,
             encoder_lens=self.encoder_lens,
             encoder_lens_cpu=self.encoder_lens_cpu,
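With the batch-level list in place, get_model_worker_batch passes the cached self.multimodal_inputs instead of rebuilding the list from self.reqs on every call, which keeps the worker batch consistent with whatever filtering and merging the batch has been through.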