Unverified Commit d18ed5cf authored by Nicolas Patry, committed by GitHub

Mllama flash version (#2585)

* Working loading state.

* Preprocessing.

* Working state? (Broke idefics1 temporarily).

* Cleaner condition.

* Fix idefics.

* Updating config, removing TODO

* Mllama

* Upgrade transformers 4.45

* Flashing mllama.

* Starting to get there.

* Working state.

* Integration tests for mllama (cutting to 10 tokens because there seems to be instability afterwards, meaning the size of the batch matters).

* Updating model link.

* Earlier assert.

* Fix vlm?

* remove log.

* Force ignore all images but last.

* Default dtype bfloat16.

* Update integration test after switch to bf16.

* Remove dead code.

* Removed dead code.

* Upgrade the flake to latest transformers/tokenizers

* Move to hf tgi-nix

* Upgrade to 0.5.0
parent 584b4d7a
import torch
import torch.distributed
from typing import Optional
from text_generation_server.models.custom_modeling.idefics_config import IdeficsConfig
from text_generation_server.models.custom_modeling.idefics_processing import (
IdeficsProcessor,
)
from transformers import LlamaTokenizerFast
from text_generation_server.models.custom_modeling.idefics_modeling import (
IdeficsForVisionText2Text,
)
from text_generation_server.models.idefics_causal_lm import IdeficsCausalLM
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
from text_generation_server.utils.quantization import get_loader
from text_generation_server.utils.import_utils import SYSTEM
class IDEFICSSharded(IdeficsCausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.quantize = quantize
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
# 9b seems to work correctly enough in float16, but 80b seems
# to be really saturating for f16.
dtype = torch.float16 if dtype is None else dtype
elif SYSTEM == "ipex":
if hasattr(torch, "xpu") and torch.xpu.is_available():
device = torch.device(f"xpu:{rank}")
dtype = torch.float16 if dtype is None else dtype
else:
device = torch.device("cpu")
# Float16 doesn't exist on target.
dtype = torch.bfloat16 if dtype is None else dtype
else:
device = torch.device("cpu")
dtype = torch.float32 if dtype is None else dtype
self.device, self.dtype = device, dtype
config = IdeficsConfig.from_pretrained(
model_id,
revision=revision,
trust_remote_code=trust_remote_code,
)
config.quantize = quantize
config.speculator = speculator
config.vision_config.quantize = quantize
tokenizer = LlamaTokenizerFast.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
self.processor = IdeficsProcessor.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
weights_loader = get_loader(
quantize=quantize, model_id=model_id, revision=revision
)
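# Synchronize all ranks before loading the sharded safetensors weights.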
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(
filenames,
device=device,
dtype=dtype,
process_group=self.process_group,
weights_loader=weights_loader,
)
model = IdeficsForVisionText2Text(config, weights)
torch.distributed.barrier(group=self.process_group)
super(IdeficsCausalLM, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
requires_padding=True,
dtype=dtype,
device=device,
rank=rank,
world_size=world_size,
)
@@ -6,6 +6,7 @@ import time
 from dataclasses import dataclass
 from opentelemetry import trace
 from transformers import (
+    AutoConfig,
     AutoProcessor,
     AutoTokenizer,
     PreTrainedTokenizerBase,
@@ -22,6 +23,18 @@ from text_generation_server.models.types import (
 )
 from text_generation_server.pb import generate_pb2
 from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
+import torch.distributed
+from text_generation_server.models.custom_modeling.idefics_modeling import (
+    IdeficsForVisionText2Text,
+)
+from text_generation_server.utils import (
+    initialize_torch_distributed,
+    weight_files,
+    Weights,
+)
+from text_generation_server.utils.quantization import get_loader
+from text_generation_server.utils.import_utils import SYSTEM
 tracer = trace.get_tracer(__name__)
@@ -577,23 +590,38 @@ class IdeficsCausalLM(Model):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        speculator: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
         self.quantize = quantize
-        from text_generation_server.models.custom_modeling.idefics_modeling import (
-            IdeficsForVisionText2Text,
-        )
+        self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
-            device = torch.device("cuda")
-            dtype = torch.bfloat16 if dtype is None else dtype
+            device = torch.device(f"cuda:{rank}")
+            # 9b seems to work correctly enough in float16, but 80b seems
+            # to be really saturating for f16.
+            dtype = torch.float16 if dtype is None else dtype
+        elif SYSTEM == "ipex":
+            if hasattr(torch, "xpu") and torch.xpu.is_available():
+                device = torch.device(f"xpu:{rank}")
+                dtype = torch.float16 if dtype is None else dtype
+            else:
+                device = torch.device("cpu")
+                # Float16 doesn't exist on target.
+                dtype = torch.bfloat16 if dtype is None else dtype
         else:
-            if quantize:
-                raise ValueError("quantization is not available on CPU")
             device = torch.device("cpu")
             dtype = torch.float32 if dtype is None else dtype
+        self.device, self.dtype = device, dtype
+        config = AutoConfig.from_pretrained(
+            model_id,
+            revision=revision,
+            trust_remote_code=trust_remote_code,
+        )
+        config.quantize = quantize
+        config.speculator = speculator
+        config.vision_config.quantize = quantize
         tokenizer = AutoTokenizer.from_pretrained(
             model_id,
@@ -609,38 +637,34 @@ class IdeficsCausalLM(Model):
             truncation_side="left",
             trust_remote_code=trust_remote_code,
         )
-        model = IdeficsForVisionText2Text.from_pretrained(
-            model_id,
-            revision=revision,
-            torch_dtype=dtype,
-            device_map=(
-                "auto"
-                if torch.cuda.is_available() and torch.cuda.device_count() > 1
-                else None
-            ),
-            load_in_8bit=quantize == "bitsandbytes",
-            trust_remote_code=trust_remote_code,
+        weights_loader = get_loader(
+            quantize=quantize, model_id=model_id, revision=revision
         )
-        if torch.cuda.is_available() and torch.cuda.device_count() == 1:
-            model = model.cuda()
-        if tokenizer.pad_token_id is None:
-            if model.config.pad_token_id is not None:
-                tokenizer.pad_token_id = model.config.pad_token_id
-            elif model.config.eos_token_id is not None:
-                tokenizer.pad_token_id = model.config.eos_token_id
-            elif tokenizer.eos_token_id is not None:
-                tokenizer.pad_token_id = tokenizer.eos_token_id
-            else:
-                tokenizer.add_special_tokens({"pad_token": "<unk>"})
-        super(IdeficsCausalLM, self).__init__(
+        torch.distributed.barrier(group=self.process_group)
+        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
+        weights = Weights(
+            filenames,
+            device=device,
+            dtype=dtype,
+            process_group=self.process_group,
+            weights_loader=weights_loader,
+        )
+        model = IdeficsForVisionText2Text(config, weights)
+        self.config = config
+        torch.distributed.barrier(group=self.process_group)
+        super().__init__(
             model_id=model_id,
             model=model,
             tokenizer=tokenizer,
             requires_padding=True,
             dtype=dtype,
             device=device,
+            rank=rank,
+            world_size=world_size,
         )
     @property
...
from io import BytesIO
from PIL import Image
import torch
from typing import Iterable, Optional, Tuple, List, Dict
from text_generation_server.pb.generate_pb2 import Request
from dataclasses import dataclass
from opentelemetry import trace
from transformers import (
PreTrainedTokenizerBase,
)
from text_generation_server.models.vlm_causal_lm import VlmCausalLMBatch, VlmCausalLM
from text_generation_server.pb import generate_pb2
from text_generation_server.models.flash_causal_lm import (
block_tables_to_ragged,
)
from text_generation_server.models.globals import PREFIX_CACHING, ATTENTION
from text_generation_server.layers.attention import Seqlen
tracer = trace.get_tracer(__name__)
@dataclass
class MllamaCausalLMBatch(VlmCausalLMBatch):
image_indices: List[int] = 42
aspect_ratio_ids: Optional[torch.Tensor] = None
aspect_ratio_mask: Optional[torch.Tensor] = None
cross_attention_states: Optional[torch.Tensor] = None
@classmethod
@tracer.start_as_current_span("concatenate")
def concatenate(cls, batches):
batch = super().concatenate(batches)
batch.pixel_values = None
batch.pixel_attention_mask = None
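# Merge cross-attention states from every sub-batch, shifting each sub-batch's
# image indices by the number of images accumulated so far.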
offset = 0
image_indices = []
attention_states = []
for b in batches:
if b.cross_attention_states is not None:
attention_states.append(b.cross_attention_states)
image_indices.extend([i + offset for i in b.image_indices])
offset += len(b.image_indices)
if len(attention_states) > 0:
assert len(image_indices) > 0
batch.cross_attention_states = torch.cat(attention_states, dim=0)
batch.image_indices = image_indices
else:
batch.cross_attention_states = None
batch.image_indices = []
return batch
@tracer.start_as_current_span("filter")
def filter(self, request_ids: List[int]):
assert self.image_indices is not None
batch = super().filter(request_ids)
assert self.image_indices is not None
indices = []
for i, request_id in enumerate(request_ids):
idx = self.requests_idx_mapping[request_id]
indices.append(idx)
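# Remap the image indices of surviving requests to compacted positions so
# they can index into the sliced cross_attention_states below.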
offset = 0
new_image_indices = []
prev_i = None
for i in self.image_indices:
if i in indices:
new_image_indices.append(offset)
if i != prev_i:
offset += 1
prev_i = i
batch.image_indices = new_image_indices
if len(new_image_indices) > 0:
assert max(new_image_indices) < self.cross_attention_states.shape[0]
assert offset <= self.cross_attention_states.shape[0]
batch.cross_attention_states = self.cross_attention_states[
new_image_indices
]
else:
batch.cross_attention_states = None
return batch
@classmethod
def batch_tokenized_inputs(
cls, requests: Iterable[Request], tokenizer, processor, config
):
image_inputs = []
texts = []
image_indices = []
batch_tokenized_inputs = []
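# Flatten each request's chunks into a single prompt string, inserting an
# <|image|> placeholder for every image chunk. Only the last image of a
# request is kept (curr_image is overwritten on each image chunk).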
for i, r in enumerate(requests):
# Each request is a list of chunks; each chunk is either text or an image.
curr_text = ""
curr_image = None
curr_i = None
for chunk in r.input_chunks.chunks:
chunk_type = chunk.WhichOneof("chunk")
if chunk_type == "text":
curr_text += chunk.text
elif chunk_type == "image":
image = Image.open(BytesIO(chunk.image.data))
# TODO unsure about BOS
curr_text += "<|image|>"
image_input = processor.image_processor(image, return_tensors="pt")
curr_image = image_input
curr_i = i
# image_inputs.append(image_input)
# image_indices.append(i)
else:
raise RuntimeError(f"Invalid chunk type {chunk_type}")
texts.append(curr_text)
if curr_image is not None:
image_inputs.append(curr_image)
image_indices.append(curr_i)
input_ids = tokenizer(
curr_text,
truncation=True,
max_length=r.truncate,
add_special_tokens=r.add_special_tokens,
)["input_ids"]
batch_tokenized_inputs.append(input_ids)
if image_inputs:
image_input = image_inputs[0]
new_image_inputs = {
"pixel_values": torch.cat(
[img["pixel_values"] for img in image_inputs], dim=0
),
}
if "aspect_ratio_ids" in image_input:
new_image_inputs["aspect_ratio_ids"] = torch.cat(
[img["aspect_ratio_ids"] for img in image_inputs], dim=0
)
if "aspect_ratio_mask" in image_input:
new_image_inputs["aspect_ratio_mask"] = torch.cat(
[img["aspect_ratio_mask"] for img in image_inputs], dim=0
)
image_inputs = new_image_inputs
image_inputs["image_indices"] = image_indices
else:
image_inputs = None
if image_inputs is not None:
assert len(image_indices) == image_inputs["pixel_values"].shape[0]
return batch_tokenized_inputs, image_inputs
@classmethod
def from_pb_processor(
cls,
pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase,
processor,
config,
dtype: torch.dtype,
device: torch.device,
) -> "VlmCausalLMBatch":
batch_tokenized_inputs, image_inputs = cls.batch_tokenized_inputs(
pb.requests, tokenizer, processor, config
)
batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
# XXX: <|image|> token is actually out of bounds and bugs out the logit processors.
batch.all_input_ids_tensor = batch.all_input_ids_tensor.clamp(
max=config.text_config.vocab_size - 1
)
batch.input_ids = batch.input_ids.clamp(max=config.text_config.vocab_size - 1)
if image_inputs is not None:
batch.pixel_values = image_inputs["pixel_values"].to(
device=device, dtype=dtype
)
batch.aspect_ratio_ids = image_inputs["aspect_ratio_ids"].to(device=device)
batch.aspect_ratio_mask = image_inputs["aspect_ratio_mask"].to(
device=device
)
batch.image_indices = image_inputs["image_indices"]
else:
batch.pixel_values = None
batch.aspect_ratio_ids = None
batch.aspect_ratio_mask = None
batch.image_indices = []
assert batch.image_indices is not None
return batch
class MllamaCausalLM(VlmCausalLM):
def forward(
self,
batch: VlmCausalLMBatch,
adapter_data: Optional[Dict[str, torch.Tensor]] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
# Model Forward
if batch.speculative_ids is not None:
input_ids = batch.input_ids
position_ids = batch.position_ids
cu_seqlen_prefill = batch.cu_seqlen_prefill
kv_cache = self.kv_cache
block_tables = batch.block_tables_tensor
slots = batch.slots[batch.slot_indices]
input_lengths = batch.input_lengths_tensor
max_s = batch.max_seqlen
lm_head_indices = batch.prefill_head_indices
speculative_ids = batch.speculative_ids
B, speculative_length = speculative_ids.shape
new_length = speculative_length + 1
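# Expand every sequence to new_length candidate tokens (the last verified
# token plus the speculative ones); positions, slots and input lengths are
# advanced with a per-token arange.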
new_input_ids = torch.cat(
[input_ids.unsqueeze(-1), speculative_ids], dim=1
).reshape(-1)
arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0)
arange_int = arange.to(dtype=torch.int32)
new_position_ids = (
position_ids.unsqueeze(-1).expand(B, new_length) + arange
).view(-1)
slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
input_lengths = (
input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
).view(-1)
prefix_lens_tensor = (
batch.prefix_lens_tensor.unsqueeze(-1).expand(B, new_length)
).reshape(-1)
# Copy the block tables for all members of the expanded batch
block_tables = (
block_tables.unsqueeze(1)
.expand(B, new_length, -1)
.reshape(B * new_length, -1)
.contiguous()
)
max_s = max_s + speculative_length
input_ids = new_input_ids
position_ids = new_position_ids
else:
input_ids = batch.input_ids
position_ids = batch.position_ids
cu_seqlen_prefill = batch.cu_seqlen_prefill
kv_cache = self.kv_cache
block_tables = batch.block_tables_tensor
slots = batch.slots[batch.slot_indices]
input_lengths = batch.input_lengths_tensor
prefix_lens_tensor = batch.prefix_lens_tensor
max_s = batch.max_seqlen
lm_head_indices = batch.prefill_head_indices
if cu_seqlen_prefill is None and self.max_past() is not None:
# In decode, not prefill, we're actually overwriting the KV-cache
# in a circular buffer mode.
# This makes sure the max_s for the decode pass is correct.
max_s = min(self.max_past(), max_s)
bs = input_ids.shape[0]
# Try to find an associated cuda graph
sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])
if sorted_padded_bs:
# Get associated cuda graph
cuda_graph = self.cuda_graphs[sorted_padded_bs[0]]
else:
cuda_graph = None
if (
cu_seqlen_prefill is not None
or cuda_graph is None
# Only run cuda graphs when there are no images.
or batch.cross_attention_states is not None
):
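# Account for the cached prefix so the attention kernels see the full
# key/value length of each sequence.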
input_lengths = input_lengths + prefix_lens_tensor
if PREFIX_CACHING:
block_tables = block_tables_to_ragged(
block_tables=block_tables,
input_lengths=batch.input_lengths,
prefix_lens=batch.prefix_lens,
)
with self._forward_context(
block_tables=block_tables,
cu_seqlen_prefill=cu_seqlen_prefill,
input_lengths_tensor=input_lengths,
prefix_lens_tensor=prefix_lens_tensor,
):
max_k = (input_lengths + prefix_lens_tensor).max().item()
seqlen = Seqlen(
input_lengths=input_lengths,
prefix_lengths=prefix_lens_tensor,
cu_seqlen_q=cu_seqlen_prefill,
max_q=max_s,
max_k=max_k,
)
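# Run the vision tower only when fresh pixel values are present; the
# resulting cross-attention states are cached on the batch so later decode
# steps reuse them instead of re-encoding the image.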
if batch.pixel_values is not None:
cross_attention_states = self.model.vision_forward(
pixel_values=batch.pixel_values,
aspect_ratio_ids=batch.aspect_ratio_ids,
aspect_ratio_mask=batch.aspect_ratio_mask,
)
batch.cross_attention_states = cross_attention_states
cross_attention_states = batch.cross_attention_states
logits, speculative_logits = self.model.forward(
input_ids=input_ids,
position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=kv_cache,
block_tables=block_tables,
slots=slots,
seqlen=seqlen,
max_s=max_s,
prefill_cache_indices=batch.prefill_cache_indices,
lm_head_indices=lm_head_indices,
cross_attention_states=cross_attention_states,
adapter_data=adapter_data,
image_indices=batch.image_indices[:],
)
if batch.prefill_cache_indices is not None:
batch.prefill_cache_indices = None
if batch.pixel_values is not None:
batch.pixel_values = None
return logits, speculative_logits
# Copy inputs to the static inputs of the cuda graph
# Static inputs are potentially padded
cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids
if ATTENTION == "flashinfer":
block_tables = block_tables_to_ragged(
block_tables=block_tables,
input_lengths=batch.input_lengths,
prefix_lens=batch.prefix_lens,
)
cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables
else:
cuda_graph["block_tables"][
: block_tables.shape[0], : block_tables.shape[1]
] = block_tables
cuda_graph["slots"].fill_(0)
cuda_graph["slots"][: slots.shape[0]] = slots
cuda_graph["input_lengths"].zero_()
cuda_graph["input_lengths"][: input_lengths.shape[0]] = (
input_lengths + prefix_lens_tensor
)
# Replay the graph
cuda_graph["graph"].replay()
# Slice output to the correct shape
speculative_logits = (
cuda_graph["speculative_logits"][:bs]
if cuda_graph["speculative_logits"] is not None
else None
)
logits = cuda_graph["logits"][:bs]
return logits, speculative_logits
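
For orientation, a minimal driver sketch for the classes above, not part of this commit: it assumes a protobuf batch, tokenizer, processor, config and a fully initialized MllamaCausalLM already exist; only from_pb_processor and the forward signature come from the file itself.

# Hypothetical usage sketch -- not part of this diff.
import torch
from text_generation_server.models.mllama_causal_lm import (
    MllamaCausalLM,
    MllamaCausalLMBatch,
)

def prefill_once(model: MllamaCausalLM, pb_batch, tokenizer, processor, config):
    # Build a flash batch from the protobuf request; pixel values and
    # aspect-ratio tensors are moved to the model device/dtype here.
    batch = MllamaCausalLMBatch.from_pb_processor(
        pb_batch,
        tokenizer,
        processor,
        config,
        dtype=torch.bfloat16,
        device=torch.device("cuda"),
    )
    # The first forward runs the vision tower (pixel_values is still set) and
    # caches the cross-attention states on the batch for later decode steps.
    logits, speculative_logits = model.forward(batch)
    return batch, logits, speculative_logits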
@@ -22,8 +22,14 @@ try:
         VlmCausalLMBatch,
     )
     from text_generation_server.models.idefics_causal_lm import IdeficsCausalLMBatch
+    from text_generation_server.models.mllama_causal_lm import MllamaCausalLMBatch
-    VLM_BATCH_TYPES = {PaliGemmaBatch, VlmCausalLMBatch, IdeficsCausalLMBatch}
+    VLM_BATCH_TYPES = {
+        PaliGemmaBatch,
+        VlmCausalLMBatch,
+        IdeficsCausalLMBatch,
+        MllamaCausalLMBatch,
+    }
 except (ImportError, NotImplementedError):
     # These imports can fail on CPU/Non flash.
     VLM_BATCH_TYPES = set()
...