# assert self.batch_size_per_gpu == 1, "Llava currently does not support batched generation. See https://github.com/haotian-liu/LLaVA/issues/754. HF Llava also has this issue."
if accelerator.num_processes > 1:
    assert accelerator.distributed_type in [
        DistributedType.FSDP,
        DistributedType.MULTI_GPU,
        DistributedType.DEEPSPEED,
    ], "Unsupported distributed type provided. Only DDP and FSDP are supported."
    # If you want to use DistributedType.DEEPSPEED, you have to run `accelerate config` before using the model.
    # You also have to select ZeRO stage 0 (equivalent to DDP) in order to make prepare(model) work.
    # I tried setting different parameters in the kwargs to make the default ZeRO stage 2 work, but it didn't.
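    # For illustration only, a minimal sketch (not part of this codebase) of how one
    # might request ZeRO stage 0 programmatically instead of via `accelerate config`;
    # the plugin arguments below are assumptions, not a verified configuration:
    #
    #     from accelerate import Accelerator
    #     from accelerate.utils import DeepSpeedPlugin
    #
    #     accelerator = Accelerator(deepspeed_plugin=DeepSpeedPlugin(zero_stage=0))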
# encode, pad, and truncate contexts for this batch
if visuals:
    image_tensor = process_images(visuals, self._image_processor, self._config)
    if isinstance(image_tensor, list):
        image_tensor = [_image.to(dtype=torch.float16, device=self.device) for _image in image_tensor]
    else:
        image_tensor = image_tensor.to(dtype=torch.float16, device=self.device)
else:
    image_tensor = None
# prompts_input = contexts[0]
question_input = []
for visual, context in zip(visuals, contexts):
    if image_tensor is not None and len(image_tensor) != 0 and DEFAULT_IMAGE_TOKEN not in context:
"""
Three senarios:
1. No image, and there for, no image token should be added.
2. image token is already specified in the context, so we don't need to add it.
3. image token is not specified in the context and there is image inputs, so we need to add it. In this case, we add the image token at the beginning of the context and add a new line.
"""
        image_tokens = [DEFAULT_IMAGE_TOKEN] * len(visual) if isinstance(visual, list) else [DEFAULT_IMAGE_TOKEN]
        image_tokens = " ".join(image_tokens)
        question = image_tokens + "\n" + context
    else:
        question = context

    conv = conv_templates[self.conv_template].copy()
    conv.append_message(conv.roles[0], question)
    conv.append_message(conv.roles[1], None)
    prompt_question = conv.get_prompt()
    question_input.append(prompt_question)
# The above for loop has a bug: when there are no visuals (e.g. pure text),
# the loop body never executes, leaving question_input empty.
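# A minimal sketch of one possible fix (an assumption, not the upstream patch):
# when visuals is empty, iterate over contexts alone so question_input is still populated.
#
#     if not visuals:
#         for context in contexts:
#             conv = conv_templates[self.conv_template].copy()
#             conv.append_message(conv.roles[0], context)
#             conv.append_message(conv.roles[1], None)
#             question_input.append(conv.get_prompt())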