Unverified commit 47954b81 authored by OlivierDehaene, committed by GitHub

feat: format code (#1070)

parent b32e9ce9
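The reflow below (trailing commas, one argument per line, 88-column wrapping) is consistent with an automated formatter pass in Black's default style. As an illustrative sketch only (the exact tool and invocation are not stated in this commit), the same reflow can be reproduced on one of the touched call sites with Black's Python API:

```python
# Illustrative sketch only: reproduces the kind of reflow seen in the diff
# below using Black's Python API. Assumes `pip install black`; this snippet
# is not taken from the repository itself.
import black

src = (
    "parameters = Parameters(typical_p=typical_p, watermark=watermark, "
    "decoder_input_details=decoder_input_details, top_n_tokens=top_n_tokens)\n"
)

# format_str applies the same rules as the CLI: lines longer than 88 columns
# are exploded into one argument per line with a trailing comma.
formatted = black.format_str(src, mode=black.FileMode())
print(formatted)
```

Running it prints the multi-line call with a trailing comma, mirroring the first hunk below.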
@@ -137,7 +137,7 @@ class Client:
             typical_p=typical_p,
             watermark=watermark,
             decoder_input_details=decoder_input_details,
-            top_n_tokens=top_n_tokens
+            top_n_tokens=top_n_tokens,
         )
         request = Request(inputs=prompt, stream=False, parameters=parameters)
...
@@ -133,7 +133,9 @@ class Request(BaseModel):
             and parameters.best_of > 1
             and field_value
         ):
-            raise ValidationError("`best_of` != 1 is not supported when `stream` == True")
+            raise ValidationError(
+                "`best_of` != 1 is not supported when `stream` == True"
+            )
         return field_value
...
@@ -3,7 +3,11 @@ import pytest
 @pytest.fixture(scope="module")
 def flash_llama_awq_handle(launcher):
-    with launcher("abhinavkulkarni/codellama-CodeLlama-7b-Python-hf-w4-g128-awq", num_shard=1, quantize="awq") as handle:
+    with launcher(
+        "abhinavkulkarni/codellama-CodeLlama-7b-Python-hf-w4-g128-awq",
+        num_shard=1,
+        quantize="awq",
+    ) as handle:
         yield handle
@@ -12,6 +16,7 @@ async def flash_llama_awq(flash_llama_awq_handle):
     await flash_llama_awq_handle.health(300)
     return flash_llama_awq_handle.client

+
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_llama_awq(flash_llama_awq, response_snapshot):
@@ -20,11 +25,13 @@ async def test_flash_llama_awq(flash_llama_awq, response_snapshot):
     )
     assert response.details.generated_tokens == 10
-    assert response.generated_text == "\nWhat is the difference between Deep Learning and Machine"
+    assert (
+        response.generated_text
+        == "\nWhat is the difference between Deep Learning and Machine"
+    )
     assert response == response_snapshot

 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_llama_awq_all_params(flash_llama_awq, response_snapshot):
@@ -49,16 +56,18 @@ async def test_flash_llama_awq_all_params(flash_llama_awq, response_snapshot):
 @pytest.mark.asyncio
 @pytest.mark.private
-async def test_flash_llama_awq_load(
-    flash_llama_awq, generate_load, response_snapshot
-):
+async def test_flash_llama_awq_load(flash_llama_awq, generate_load, response_snapshot):
     responses = await generate_load(
         flash_llama_awq, "What is Deep Learning?", max_new_tokens=10, n=4
     )

     assert len(responses) == 4
-    assert all([r.generated_text == "\nWhat is the difference between Deep Learning and Machine" for r in responses])
+    assert all(
+        [
+            r.generated_text
+            == "\nWhat is the difference between Deep Learning and Machine"
+            for r in responses
+        ]
+    )

     assert responses == response_snapshot

 import pytest


 @pytest.fixture(scope="module")
 def flash_llama_awq_handle_sharded(launcher):
-    with launcher("abhinavkulkarni/codellama-CodeLlama-7b-Python-hf-w4-g128-awq", num_shard=2, quantize="awq") as handle:
+    with launcher(
+        "abhinavkulkarni/codellama-CodeLlama-7b-Python-hf-w4-g128-awq",
+        num_shard=2,
+        quantize="awq",
+    ) as handle:
         yield handle


 @pytest.fixture(scope="module")
 async def flash_llama_awq_sharded(flash_llama_awq_handle_sharded):
     await flash_llama_awq_handle_sharded.health(300)
     return flash_llama_awq_handle_sharded.client


 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_llama_awq_sharded(flash_llama_awq_sharded, response_snapshot):
@@ -18,9 +25,13 @@ async def test_flash_llama_awq_sharded(flash_llama_awq_sharded, response_snapshot):
     )
     assert response.details.generated_tokens == 10
-    assert response.generated_text == "\nWhat is the difference between Deep Learning and Machine"
+    assert (
+        response.generated_text
+        == "\nWhat is the difference between Deep Learning and Machine"
+    )
     assert response == response_snapshot

 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_llama_awq_load_sharded(
@@ -31,6 +42,12 @@ async def test_flash_llama_awq_load_sharded(
     )
     assert len(responses) == 4
-    assert all([r.generated_text == "\nWhat is the difference between Deep Learning and Machine" for r in responses])
+    assert all(
+        [
+            r.generated_text
+            == "\nWhat is the difference between Deep Learning and Machine"
+            for r in responses
+        ]
+    )

     assert responses == response_snapshot
@@ -3,9 +3,7 @@ import pytest
 @pytest.fixture(scope="module")
 def idefics_handle(launcher):
-    with launcher(
-        "HuggingFaceM4/idefics-9b-instruct", num_shard=2
-    ) as handle:
+    with launcher("HuggingFaceM4/idefics-9b-instruct", num_shard=2) as handle:
         yield handle
...
@@ -45,12 +45,15 @@ def test_stopping_criteria_max():
     assert criteria(1, "") == (False, None)
     assert criteria(1, "") == (True, FinishReason.FINISH_REASON_LENGTH)


 def test_batch_top_tokens():
     top_n_tokens = [0, 2, 3, 4, 5]
     top_n_tokens_tensor = torch.tensor(top_n_tokens)
-    inp_logprobs = torch.tensor([[-1., -3., -4., -2., -3.]] * 5)
-    topn_tok_ids, topn_tok_logprobs = batch_top_tokens(top_n_tokens, top_n_tokens_tensor, inp_logprobs)
+    inp_logprobs = torch.tensor([[-1.0, -3.0, -4.0, -2.0, -3.0]] * 5)
+
+    topn_tok_ids, topn_tok_logprobs = batch_top_tokens(
+        top_n_tokens, top_n_tokens_tensor, inp_logprobs
+    )

     assert topn_tok_ids[0] == []
     assert topn_tok_ids[1] == [0, 3]
...
@@ -125,8 +125,12 @@ def download_weights(
     if not is_local_model:
         try:
-            adapter_config_filename = hf_hub_download(model_id, revision=revision, filename="adapter_config.json")
-            utils.download_and_unload_peft(model_id, revision, trust_remote_code=trust_remote_code)
+            adapter_config_filename = hf_hub_download(
+                model_id, revision=revision, filename="adapter_config.json"
+            )
+            utils.download_and_unload_peft(
+                model_id, revision, trust_remote_code=trust_remote_code
+            )
             is_local_model = True
             utils.weight_files(model_id, revision, extension)
             return
@@ -179,11 +183,12 @@ def download_weights(
             import transformers
             import json

             if is_local_model:
                 config_filename = os.path.join(model_id, "config.json")
             else:
-                config_filename = hf_hub_download(model_id, revision=revision, filename="config.json")
+                config_filename = hf_hub_download(
+                    model_id, revision=revision, filename="config.json"
+                )
             with open(config_filename, "r") as f:
                 config = json.load(f)
             architecture = config["architectures"][0]
...
@@ -153,7 +153,11 @@ def get_model(
         )
     elif model_type == "mpt":
         return MPTSharded(
-            model_id, revision, quantize=quantize, dtype=dtype, trust_remote_code=trust_remote_code
+            model_id,
+            revision,
+            quantize=quantize,
+            dtype=dtype,
+            trust_remote_code=trust_remote_code,
         )
     elif model_type == "gpt_neox":
@@ -252,13 +256,13 @@ def get_model(
         )
     elif model_type == "idefics":
         if FLASH_ATTENTION:
             return IDEFICSSharded(
                 model_id,
                 revision,
                 quantize=quantize,
                 dtype=dtype,
                 trust_remote_code=trust_remote_code,
             )
         else:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
@@ -269,13 +273,9 @@ def get_model(
             "gptq quantization is not supported for AutoModel, you can try to quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
         )
     if quantize == "awq":
-        raise ValueError(
-            "awq quantization is not supported for AutoModel"
-        )
+        raise ValueError("awq quantization is not supported for AutoModel")
     elif (quantize == "bitsandbytes-fp4") or (quantize == "bitsandbytes-nf4"):
-        raise ValueError(
-            "4bit quantization is not supported for AutoModel"
-        )
+        raise ValueError("4bit quantization is not supported for AutoModel")
     if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
         return CausalLM(
             model_id,
...
@@ -643,9 +643,12 @@ class CausalLM(Model):
                 # Decode generated tokens
                 output_text, _, _ = self.decode_token(
                     all_input_ids[:, 0],
-                    prefix_offset=len(all_input_ids) - stopping_criteria.current_tokens - 1,
-                    read_offset=len(all_input_ids) - stopping_criteria.current_tokens,
-                    skip_special_tokens=True
+                    prefix_offset=len(all_input_ids)
+                    - stopping_criteria.current_tokens
+                    - 1,
+                    read_offset=len(all_input_ids)
+                    - stopping_criteria.current_tokens,
+                    skip_special_tokens=True,
                 )
                 # Get seed
                 if isinstance(next_token_chooser.choice, Sampling):
...
@@ -40,7 +40,10 @@ from text_generation_server.utils.layers import (
 )

 CUSTOM_KERNELS_ENABLED = False
-if torch.cuda.is_available() and not os.environ.get("DISABLE_CUSTOM_KERNELS", "False") == "True":
+if (
+    torch.cuda.is_available()
+    and not os.environ.get("DISABLE_CUSTOM_KERNELS", "False") == "True"
+):
     try:
         from custom_kernels import fused_bloom_attention_cuda
...
@@ -169,6 +169,7 @@ def load_attention(config, prefix, weights):
             bias=False,
         )

+
 def _load_gqa(config, prefix: str, weights):
     assert config.hidden_size % config.num_attention_heads == 0
     assert config.num_attention_heads % weights.process_group.size() == 0
@@ -211,7 +212,10 @@ class FlashLlamaAttention(torch.nn.Module):
         # config=config, prefix=f"{prefix}.rotary_emb", weights=weights
         # )
         self.rotary_emb = PositionRotaryEmbedding.static(
-            config=config, dim=self.head_size, base=config.rope_theta, device=weights.device
+            config=config,
+            dim=self.head_size,
+            base=config.rope_theta,
+            device=weights.device,
         )
         self.softmax_scale = self.head_size**-0.5
...
@@ -20,7 +20,12 @@ import numpy as np
 from PIL import Image
 from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
-from transformers.image_transforms import resize, to_channel_dimension_format, rescale, normalize
+from transformers.image_transforms import (
+    resize,
+    to_channel_dimension_format,
+    rescale,
+    normalize,
+)
 from transformers.image_utils import (
     ChannelDimension,
     ImageInput,
@@ -121,7 +126,11 @@ class IdeficsImageProcessor(BaseImageProcessor):
             a PyTorch tensor of the processed images
         """
         image_size = image_size if image_size is not None else self.image_size
-        image_num_channels = image_num_channels if image_num_channels is not None else self.image_num_channels
+        image_num_channels = (
+            image_num_channels
+            if image_num_channels is not None
+            else self.image_num_channels
+        )
         image_mean = image_mean if image_mean is not None else self.image_mean
         image_std = image_std if image_std is not None else self.image_std
         size = (image_size, image_size)
@@ -160,9 +169,13 @@ class IdeficsImageProcessor(BaseImageProcessor):
         images = [resize(x, size, resample=PILImageResampling.BICUBIC) for x in images]
         images = [self.rescale(image=image, scale=1 / 255) for image in images]
         images = [self.normalize(x, mean=image_mean, std=image_std) for x in images]
-        images = [to_channel_dimension_format(x, ChannelDimension.FIRST) for x in images]
+        images = [
+            to_channel_dimension_format(x, ChannelDimension.FIRST) for x in images
+        ]
         # TODO: this converts to torch tensors - switch to convert_to_tensors once it becomes available
-        images = BatchFeature(data={"pixel_values": images}, tensor_type=TensorType.PYTORCH)["pixel_values"]
+        images = BatchFeature(
+            data={"pixel_values": images}, tensor_type=TensorType.PYTORCH
+        )["pixel_values"]
         return images
@@ -185,7 +198,9 @@ class IdeficsImageProcessor(BaseImageProcessor):
             response.raise_for_status()
             return Image.open(BytesIO(response.content))
         else:
-            raise ValueError(f"only a single or a list of entries is supported but got type={type(image_url_or_urls)}")
+            raise ValueError(
+                f"only a single or a list of entries is supported but got type={type(image_url_or_urls)}"
+            )

     def rescale(
         self,
@@ -255,10 +270,9 @@ class IdeficsImageProcessor(BaseImageProcessor):
             `np.ndarray`: The normalized image.
         """
         # TODO 4.32
-        return normalize(
-            image, mean=mean, std=std, data_format=data_format, **kwargs
-        )
+        return normalize(image, mean=mean, std=std, data_format=data_format, **kwargs)


 import transformers

 transformers.IdeficsImageProcessor = IdeficsImageProcessor
@@ -46,7 +46,8 @@ from text_generation_server.utils.layers import (
     TensorParallelRowLinear,
 )

-EPS=1e-5
+EPS = 1e-5


 class IdeficsPerceiverResampler(nn.Module):
     def __init__(
@@ -78,7 +79,12 @@ class IdeficsPerceiverResampler(nn.Module):
         """
         super().__init__()
-        self.embed_dim, self.n_heads, self.head_dim, self.n_latents = embed_dim, n_heads, head_dim, n_latents
+        self.embed_dim, self.n_heads, self.head_dim, self.n_latents = (
+            embed_dim,
+            n_heads,
+            head_dim,
+            n_latents,
+        )
         self.qk_layer_norms = config.perceiver_config.qk_layer_norms_perceiver

         # Create Latents for Perceiver
@@ -107,14 +113,16 @@ class IdeficsPerceiverResampler(nn.Module):
                            prefix=f"{prefix}.blocks.{layer_id}.1",
                            intermediate_size=self.intermediate_dim,
                            config=config,
-                            weights=weights
+                            weights=weights,
                        ),
                    ]
                )
                for layer_id in range(depth)
            ]
        )
-        self.layer_norm = nn.LayerNorm.load(prefix=f"{prefix}.layer_norm", weights=weights, eps=EPS)
+        self.layer_norm = nn.LayerNorm.load(
+            prefix=f"{prefix}.layer_norm", weights=weights, eps=EPS
+        )

     def forward(self, context: torch.Tensor) -> torch.Tensor:
         """Resample arbitrary length context & *compress* down to self.n_latents latent embeddings"""
@@ -130,25 +138,34 @@ class IdeficsPerceiverResampler(nn.Module):
 class IdeficsPerceiverAttention(nn.Module):
-    def __init__(self,
-        prefix,
-        config,
-        embed_dim: int,
-        n_heads: int,
-        head_dim: int,
-        qk_layer_norms: bool,
-        weights
-    ) -> None:
+    def __init__(
+        self,
+        prefix,
+        config,
+        embed_dim: int,
+        n_heads: int,
+        head_dim: int,
+        qk_layer_norms: bool,
+        weights,
+    ) -> None:
         """Perceiver Cross-Attention Module --> let long-form inputs be `context`, resampled embeddings be `latents`"""
         super().__init__()
         self.embed_dim, self.n_heads, self.head_dim = embed_dim, n_heads, head_dim
         self.qk_layer_norms = qk_layer_norms
         # Normalization & Scaling
-        self.context_layer_norm = nn.LayerNorm.load(prefix=f"{prefix}.context_layer_norm", weights=weights, eps=EPS)
-        self.latents_layer_norm = nn.LayerNorm.load(prefix=f"{prefix}.latents_layer_norm", weights=weights, eps=EPS)
+        self.context_layer_norm = nn.LayerNorm.load(
+            prefix=f"{prefix}.context_layer_norm", weights=weights, eps=EPS
+        )
+        self.latents_layer_norm = nn.LayerNorm.load(
+            prefix=f"{prefix}.latents_layer_norm", weights=weights, eps=EPS
+        )
         if self.qk_layer_norms:
-            self.q_layer_norm = nn.LayerNorm.load(prefix=f"{prefix}.q_layer_norm", weights=weights, eps=EPS)
-            self.k_layer_norm = nn.LayerNorm.load(prefix=f"{prefix}.k_layer_norm", weights=weights, eps=EPS)
+            self.q_layer_norm = nn.LayerNorm.load(
+                prefix=f"{prefix}.q_layer_norm", weights=weights, eps=EPS
+            )
+            self.k_layer_norm = nn.LayerNorm.load(
+                prefix=f"{prefix}.k_layer_norm", weights=weights, eps=EPS
+            )

         self.qk_scale = self.head_dim**-0.5
@@ -164,10 +181,10 @@ class IdeficsPerceiverAttention(nn.Module):
         self.q_proj = TensorParallelColumnLinear.load(
             config=config, prefix=f"{prefix}.q_proj", weights=weights, bias=False
         )
         self.k_proj = TensorParallelColumnLinear.load(
             config=config, prefix=f"{prefix}.k_proj", weights=weights, bias=False
         )
         self.v_proj = TensorParallelColumnLinear.load(
             config=config, prefix=f"{prefix}.v_proj", weights=weights, bias=False
         )
@@ -202,7 +219,12 @@ class IdeficsPerceiverAttention(nn.Module):
         # Multiheaded Self-Attention w/ stable softmax (subtract per-row max -- `amax` -- before softmax call)
         # =>> `attn` should be a 2D matrix of shape [n_latents x (context + n_latents)]
         # einsum.rearrange(x, "bsz seq (heads embed) -> bsz heads seq embed", heads=self.n_heads)
-        q, k, v = [x.reshape(batch_size, x.shape[1], self.n_heads, self.head_dim).transpose(1, 2) for x in (q, k, v)]
+        q, k, v = [
+            x.reshape(batch_size, x.shape[1], self.n_heads, self.head_dim).transpose(
+                1, 2
+            )
+            for x in (q, k, v)
+        ]

         if self.qk_layer_norms:
             q = self.q_layer_norm(q)
@@ -219,25 +241,34 @@ class IdeficsPerceiverAttention(nn.Module):
 class IdeficsMLP(nn.Module):
-    def __init__(self,
-        prefix,
-        intermediate_size,
-        config,
-        weights,
-    ):
+    def __init__(
+        self,
+        prefix,
+        intermediate_size,
+        config,
+        weights,
+    ):
         """Simple MLP block with intermediate_size and embedding size"""
         super().__init__()
         self.embed_dim = config.vision_config.embed_dim
         self.ln = nn.LayerNorm.load(prefix=f"{prefix}.ln", weights=weights, eps=EPS)
         self.fc = TensorParallelColumnLinear.load(
-            config=config, prefix=f"{prefix}.fc", weights=weights, bias=False,
+            config=config,
+            prefix=f"{prefix}.fc",
+            weights=weights,
+            bias=False,
         )
         self.act = nn.ReLU()
         self.c_proj = TensorParallelRowLinear.load(
-            config=config, prefix=f"{prefix}.c_proj", weights=weights, bias=False,
+            config=config,
+            prefix=f"{prefix}.c_proj",
+            weights=weights,
+            bias=False,
         )

-    def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
+    def forward(
+        self, hidden_states: Optional[Tuple[torch.FloatTensor]]
+    ) -> torch.FloatTensor:
         hidden_states = self.ln(hidden_states)
         hidden_states = self.fc(hidden_states)
         hidden_states = self.act(hidden_states)
...
@@ -21,9 +21,16 @@ from urllib.parse import urlparse
 from transformers.feature_extraction_utils import BatchFeature
 from transformers.processing_utils import ProcessorMixin
-from transformers.tokenization_utils_base import BatchEncoding, PaddingStrategy, TextInput, TruncationStrategy
+from transformers.tokenization_utils_base import (
+    BatchEncoding,
+    PaddingStrategy,
+    TextInput,
+    TruncationStrategy,
+)
 from transformers.utils import TensorType, is_torch_available
-from text_generation_server.models.custom_modeling.idefics_image_processing import IdeficsImageProcessor
+from text_generation_server.models.custom_modeling.idefics_image_processing import (
+    IdeficsImageProcessor,
+)

 if is_torch_available():
@@ -124,7 +131,14 @@ class IdeficsProcessor(ProcessorMixin):
     image_processor_class = "IdeficsImageProcessor"
     tokenizer_class = "LlamaTokenizerFast"

-    def __init__(self, image_processor, tokenizer=None, image_size=224, add_end_of_utterance_token=None, **kwargs):
+    def __init__(
+        self,
+        image_processor,
+        tokenizer=None,
+        image_size=224,
+        add_end_of_utterance_token=None,
+        **kwargs,
+    ):
         if image_processor is None:
             raise ValueError("You need to specify an `image_processor`.")
         if tokenizer is None:
@@ -142,7 +156,8 @@ class IdeficsProcessor(ProcessorMixin):
         self.tokenizer_was_trained_with_end_of_utterance_token = (
             True
-            if "<end_of_utterance>" in self.tokenizer.special_tokens_map.get("additional_special_tokens", [])
+            if "<end_of_utterance>"
+            in self.tokenizer.special_tokens_map.get("additional_special_tokens", [])
             else False
         )
@@ -265,7 +280,9 @@ class IdeficsProcessor(ProcessorMixin):
         # if the value isn't overriden by the user, check if the tokenizer was trained with this token and then use it
         if add_end_of_utterance_token is None:
-            add_end_of_utterance_token = self.tokenizer_was_trained_with_end_of_utterance_token
+            add_end_of_utterance_token = (
+                self.tokenizer_was_trained_with_end_of_utterance_token
+            )

         # turn non-batched prompts into batched
         if not any(isinstance(i, list) for i in prompts):
@@ -358,10 +375,14 @@ class IdeficsProcessor(ProcessorMixin):
             current_images = images[:local_max_num_images]

             if len(current_images) > 0:
-                padded_image_tensor = torch.zeros(max_num_images, *current_images.size()[1:])
+                padded_image_tensor = torch.zeros(
+                    max_num_images, *current_images.size()[1:]
+                )
                 padded_image_tensor[: current_images.size(0)] = current_images
             else:
-                padded_image_tensor = torch.zeros(max_num_images, *self.default_image_dims)
+                padded_image_tensor = torch.zeros(
+                    max_num_images, *self.default_image_dims
+                )

             output_images.append(padded_image_tensor)
             output_input_ids.append(torch.tensor(padded_input_ids))
@@ -373,14 +394,19 @@ class IdeficsProcessor(ProcessorMixin):
         output_attention_masks = torch.stack(output_attention_masks)

         if at_least_one_image:
-            image_attention_mask, _ = image_attention_mask_for_packed_input_ids(output_input_ids, self.tokenizer)
+            image_attention_mask, _ = image_attention_mask_for_packed_input_ids(
+                output_input_ids, self.tokenizer
+            )
             image_attention_mask = incremental_to_binary_attention_mask(
                 image_attention_mask, num_classes=max_num_images
             )
         else:
             # in full language mode we set the image mask to all-0s
             image_attention_mask = torch.zeros(
-                output_input_ids.shape[0], output_input_ids.shape[1], 1, dtype=torch.bool
+                output_input_ids.shape[0],
+                output_input_ids.shape[1],
+                1,
+                dtype=torch.bool,
             )

         return BatchFeature(
...
@@ -75,7 +75,9 @@ class IdeficsVisionEmbeddings(nn.Module):
         self.image_size = config.image_size
         self.patch_size = config.patch_size

-        self.class_embedding = nn.Parameter(weights.get_tensor(f"{prefix}.class_embedding"))
+        self.class_embedding = nn.Parameter(
+            weights.get_tensor(f"{prefix}.class_embedding")
+        )

         self.patch_embedding = nn.Conv2d.load_no_bias(
             prefix=f"{prefix}.patch_embedding",
@@ -91,12 +93,16 @@ class IdeficsVisionEmbeddings(nn.Module):
         self.position_embedding = TensorParallelEmbedding(
             prefix="model.vision_model.embeddings.position_embedding", weights=weights
         )
-        self.position_ids = torch.arange(self.num_positions).expand((1, -1)).to(device=weights.device)
+        self.position_ids = (
+            torch.arange(self.num_positions).expand((1, -1)).to(device=weights.device)
+        )

     def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
         batch_size = pixel_values.shape[0]
         target_dtype = self.patch_embedding.weight.dtype
-        patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))  # shape = [*, width, grid, grid]
+        patch_embeds = self.patch_embedding(
+            pixel_values.to(dtype=target_dtype)
+        )  # shape = [*, width, grid, grid]

         patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
         class_embeds = self.class_embedding.expand(batch_size, 1, -1)
@@ -132,7 +138,6 @@ class IdeficsVisionAttention(nn.Module):
         self.num_heads = self.num_heads // weights.process_group.size()
         self.embed_dim = self.embed_dim // weights.process_group.size()
-
         self.k_proj = TensorParallelColumnLinear.load(
             config, prefix=f"{prefix}.k_proj", weights=weights, bias=True
         )
@@ -147,7 +152,11 @@ class IdeficsVisionAttention(nn.Module):
         )

     def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
-        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
+        return (
+            tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
+            .transpose(1, 2)
+            .contiguous()
+        )

     def forward(
         self,
@@ -186,7 +195,10 @@ class IdeficsVisionAttention(nn.Module):
                     f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
                     f" {causal_attention_mask.size()}"
                 )
-            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask
+            attn_weights = (
+                attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+                + causal_attention_mask
+            )
             attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

         if attention_mask is not None:
@@ -194,7 +206,10 @@ class IdeficsVisionAttention(nn.Module):
                 raise ValueError(
                     f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
                 )
-            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
+            attn_weights = (
+                attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+                + attention_mask
+            )
             attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

         attn_weights = nn.functional.softmax(attn_weights, dim=-1)
@@ -204,12 +219,18 @@ class IdeficsVisionAttention(nn.Module):
             # make sure that attn_weights keeps its gradient.
             # In order to do so, attn_weights have to reshaped
             # twice and have to be reused in the following
-            attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
-            attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
+            attn_weights_reshaped = attn_weights.view(
+                bsz, self.num_heads, tgt_len, src_len
+            )
+            attn_weights = attn_weights_reshaped.view(
+                bsz * self.num_heads, tgt_len, src_len
+            )
         else:
             attn_weights_reshaped = None

-        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
+        attn_probs = nn.functional.dropout(
+            attn_weights, p=self.dropout, training=self.training
+        )

         attn_output = torch.bmm(attn_probs, value_states)
@@ -253,11 +274,15 @@ class IdeficsVisionEncoderLayer(nn.Module):
     def __init__(self, prefix, config, weights):
         super().__init__()
         self.embed_dim = config.hidden_size
-        self.self_attn = IdeficsVisionAttention(prefix=f"{prefix}.self_attn", config=config, weights=weights)
+        self.self_attn = IdeficsVisionAttention(
+            prefix=f"{prefix}.self_attn", config=config, weights=weights
+        )
         self.layer_norm1 = nn.LayerNorm.load(
             prefix=f"{prefix}.layer_norm1", weights=weights, eps=config.layer_norm_eps
         )
-        self.mlp = IdeficsVisionMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
+        self.mlp = IdeficsVisionMLP(
+            prefix=f"{prefix}.mlp", config=config, weights=weights
+        )
         self.layer_norm2 = nn.LayerNorm.load(
             prefix=f"{prefix}.layer_norm2", weights=weights, eps=config.layer_norm_eps
         )
@@ -318,7 +343,11 @@ class IdeficsVisionEncoder(nn.Module):
         self.config = config
         self.layers = nn.ModuleList(
             [
-                IdeficsVisionEncoderLayer(prefix=f"{prefix}.encoder.layers.{layer_id}", config=config, weights=weights)
+                IdeficsVisionEncoderLayer(
+                    prefix=f"{prefix}.encoder.layers.{layer_id}",
+                    config=config,
+                    weights=weights,
+                )
                 for layer_id in range(config.num_hidden_layers)
             ]
         )
@@ -362,11 +391,19 @@ class IdeficsVisionEncoder(nn.Module):
             return_dict (`bool`, *optional*):
                 Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
         """
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_attentions = (
+            output_attentions
+            if output_attentions is not None
+            else self.config.output_attentions
+        )
         output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+            output_hidden_states
+            if output_hidden_states is not None
+            else self.config.output_hidden_states
+        )
+        return_dict = (
+            return_dict if return_dict is not None else self.config.use_return_dict
         )
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

         encoder_states = () if output_hidden_states else None
         all_attentions = () if output_attentions else None
@@ -406,9 +443,15 @@ class IdeficsVisionEncoder(nn.Module):
             encoder_states = encoder_states + (hidden_states,)

         if not return_dict:
-            return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
+            return tuple(
+                v
+                for v in [hidden_states, encoder_states, all_attentions]
+                if v is not None
+            )
         return BaseModelOutput(
-            last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
+            last_hidden_state=hidden_states,
+            hidden_states=encoder_states,
+            attentions=all_attentions,
         )
@@ -419,13 +462,19 @@ class IdeficsVisionTransformer(nn.Module):
         self.config = config
         embed_dim = config.hidden_size

-        self.embeddings = IdeficsVisionEmbeddings(prefix=f"{prefix}.embeddings", config=config, weights=weights)
+        self.embeddings = IdeficsVisionEmbeddings(
+            prefix=f"{prefix}.embeddings", config=config, weights=weights
+        )
         self.pre_layrnorm = nn.LayerNorm.load(
             prefix=f"{prefix}.pre_layrnorm", weights=weights, eps=config.layer_norm_eps
         )
-        self.encoder = IdeficsVisionEncoder(prefix=prefix, config=config, weights=weights)
+        self.encoder = IdeficsVisionEncoder(
+            prefix=prefix, config=config, weights=weights
+        )
         self.post_layernorm = nn.LayerNorm.load(
-            prefix=f"{prefix}.post_layernorm", weights=weights, eps=config.layer_norm_eps
+            prefix=f"{prefix}.post_layernorm",
+            weights=weights,
+            eps=config.layer_norm_eps,
         )

     # copied from transformers.models.clip.modeling_clip.CLIPVisionTransformer.forward
@@ -440,11 +489,19 @@ class IdeficsVisionTransformer(nn.Module):
         Returns:
         """
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_attentions = (
+            output_attentions
+            if output_attentions is not None
+            else self.config.output_attentions
+        )
         output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+            output_hidden_states
+            if output_hidden_states is not None
+            else self.config.output_hidden_states
+        )
+        return_dict = (
+            return_dict if return_dict is not None else self.config.use_return_dict
         )
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

         if pixel_values is None:
             raise ValueError("You have to specify pixel_values")
...
@@ -49,7 +49,10 @@ from text_generation_server.utils.layers import (

 CUSTOM_KERNELS_ENABLED = False
-if torch.cuda.is_available() and not os.environ.get("DISABLE_CUSTOM_KERNELS", "False") == "True":
+if (
+    torch.cuda.is_available()
+    and not os.environ.get("DISABLE_CUSTOM_KERNELS", "False") == "True"
+):
     try:
         from custom_kernels import fused_attention_cuda
...
@@ -1005,9 +1005,12 @@ class FlashCausalLM(Model):
                 # Decode generated tokens
                 output_text, _, _ = self.decode_token(
                     all_input_ids,
-                    prefix_offset=len(all_input_ids) - stopping_criteria.current_tokens - 1,
-                    read_offset=len(all_input_ids) - stopping_criteria.current_tokens,
-                    skip_special_tokens=True
+                    prefix_offset=len(all_input_ids)
+                    - stopping_criteria.current_tokens
+                    - 1,
+                    read_offset=len(all_input_ids)
+                    - stopping_criteria.current_tokens,
+                    skip_special_tokens=True,
                 )
                 generated_text = GeneratedText(
                     output_text,
...
@@ -8,7 +8,13 @@ import re
 from dataclasses import dataclass
 from opentelemetry import trace
-from transformers import AutoProcessor, AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase, ProcessorMixin
+from transformers import (
+    AutoProcessor,
+    AutoTokenizer,
+    AutoModelForCausalLM,
+    PreTrainedTokenizerBase,
+    ProcessorMixin,
+)
 from typing import Optional, Tuple, List, Type, Dict

 from text_generation_server.models import Model
@@ -23,7 +29,8 @@ from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
 import re

-IMAGES = re.compile(r'!\[[^\]]*\]\((.*?)\s*(\"(?:.*[^\"])\")?\s*\)')
+IMAGES = re.compile(r"!\[[^\]]*\]\((.*?)\s*(\"(?:.*[^\"])\")?\s*\)")
+

 def split(string):
     parts = []
@@ -41,6 +48,7 @@ def split(string):
     return parts

+
 tracer = trace.get_tracer(__name__)
@@ -94,7 +102,7 @@ class IdeficsCausalLMBatch(Batch):
         cls,
         pb: generate_pb2.Batch,
         tokenizer: PreTrainedTokenizerBase,
-        processor: ProcessorMixin, # Hack
+        processor: ProcessorMixin,  # Hack
         dtype: torch.dtype,
         device: torch.device,
     ) -> "IdeficsCausalLMBatch":
@@ -137,12 +145,16 @@ class IdeficsCausalLMBatch(Batch):
             padding=True,
             truncation=True,
             max_length=max_truncation,
-            add_end_of_utterance_token=False, # Already taken care of inside the prompts, so bypassing the processor's handling of this token
+            add_end_of_utterance_token=False,  # Already taken care of inside the prompts, so bypassing the processor's handling of this token
         ).to(device)
         for _ in pb.requests:
             input_len = tokenized_inputs["input_ids"].shape[1]
-            prefix_offsets.append(input_len - 5) # To decode without potential fallbacks errors
-            read_offsets.append(input_len) # To decode without potential fallbacks errors
+            prefix_offsets.append(
+                input_len - 5
+            )  # To decode without potential fallbacks errors
+            read_offsets.append(
+                input_len
+            )  # To decode without potential fallbacks errors

         input_lengths = tokenized_inputs["attention_mask"].sum(1)
         max_input_length = input_lengths.max()
@@ -158,14 +170,21 @@ class IdeficsCausalLMBatch(Batch):
         attention_mask[:, :max_input_length] = tokenized_inputs["attention_mask"]
         # Do the same for image_attention_mask
         image_attention_mask = input_ids.new_zeros(
-            (pb.size, max_input_length + padding_right_offset, tokenized_inputs["pixel_values"].size(1))
+            (
+                pb.size,
+                max_input_length + padding_right_offset,
+                tokenized_inputs["pixel_values"].size(1),
+            )
         )
-        image_attention_mask[:, :max_input_length, :] = tokenized_inputs["image_attention_mask"]
+        image_attention_mask[:, :max_input_length, :] = tokenized_inputs[
+            "image_attention_mask"
+        ]

         position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
         position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
-        all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1) # It's input_ids but splitted into a tuple of tensors where each tensor is (seq_len, 1) size. It is then transformed into a list
+        all_input_ids = tokenized_inputs["input_ids"].T.split(
+            1, dim=1
+        )  # It's input_ids but splitted into a tuple of tensors where each tensor is (seq_len, 1) size. It is then transformed into a list

         max_tokens = len(inputs) * (max_input_length + max_decode_tokens)
@@ -259,7 +278,7 @@ class IdeficsCausalLMBatch(Batch):
                 self.image_attention_mask.shape[1] - self.padding_right_offset
             )
             + new_padding_right_offset,
-            :
+            :,
         ]
         if self.image_hidden_states is None:
             image_hidden_states = None
@@ -308,7 +327,9 @@ class IdeficsCausalLMBatch(Batch):
     @classmethod
     @tracer.start_as_current_span("concatenate")
-    def concatenate(cls, batches: List["IdeficsCausalLMBatch"]) -> "IdeficsCausalLMBatch":
+    def concatenate(
+        cls, batches: List["IdeficsCausalLMBatch"]
+    ) -> "IdeficsCausalLMBatch":
         # It adds new requests to the batch
         # Used for padding
         total_batch_size = 0
@@ -383,12 +404,20 @@ class IdeficsCausalLMBatch(Batch):
            curr_batch_max_num_images = batch.pixel_values.size(1)
            if pixel_values is None:
-                pixel_values = batch.pixel_values.new_zeros((total_batch_size, max_num_images, 3, 224, 224))
-            pixel_values[start_index:end_index, :curr_batch_max_num_images] = batch.pixel_values
+                pixel_values = batch.pixel_values.new_zeros(
+                    (total_batch_size, max_num_images, 3, 224, 224)
+                )
+            pixel_values[
+                start_index:end_index, :curr_batch_max_num_images
+            ] = batch.pixel_values

            if image_attention_mask is None:
                image_attention_mask = batch.image_attention_mask.new_zeros(
-                    (total_batch_size, max_input_length + padding_right_offset, max_num_images)
+                    (
+                        total_batch_size,
+                        max_input_length + padding_right_offset,
+                        max_num_images,
+                    )
                )

            # We need to slice the attention mask to remove padding from previous steps
@@ -409,11 +438,9 @@ class IdeficsCausalLMBatch(Batch):
            image_attention_mask[
                start_index:end_index,
                left_offset:-padding_right_offset,
-                :curr_batch_max_num_images
+                :curr_batch_max_num_images,
            ] = batch.image_attention_mask[
-                :,
-                batch_left_offset : - batch.padding_right_offset,
-                :
+                :, batch_left_offset : -batch.padding_right_offset, :
            ]

            # Create empty tensor
@@ -550,7 +577,9 @@ class IdeficsCausalLM(Model):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
-        from text_generation_server.models.custom_modeling.idefics_modeling import IdeficsForVisionText2Text
+        from text_generation_server.models.custom_modeling.idefics_modeling import (
+            IdeficsForVisionText2Text,
+        )

         if torch.cuda.is_available():
             device = torch.device("cuda")
@@ -650,9 +679,13 @@ class IdeficsCausalLM(Model):
             # this is due to the nature IDEFICS: it's an encoder decoder, and so when decoding, only the currently generated
             # token need to attend to the encoder hidden states (i.e. the vision encoder)
             # Also see seq2seq_lm.Seq2SeqLM.generate_token which has roughly the same logic
-            image_attention_mask = batch.image_attention_mask[:, -(batch.padding_right_offset+1)].unsqueeze(1)
+            image_attention_mask = batch.image_attention_mask[
+                :, -(batch.padding_right_offset + 1)
+            ].unsqueeze(1)
         else:
-            image_attention_mask = batch.image_attention_mask[:, : -batch.padding_right_offset]
+            image_attention_mask = batch.image_attention_mask[
+                :, : -batch.padding_right_offset
+            ]

         logits, past, image_hidden_states = self.forward(
             input_ids=batch.input_ids,
@@ -725,9 +758,12 @@ class IdeficsCausalLM(Model):
                 # Decode generated tokens
                 output_text, _, _ = self.decode_token(
                     all_input_ids[:, 0],
-                    prefix_offset=len(all_input_ids) - stopping_criteria.current_tokens - 1,
-                    read_offset=len(all_input_ids) - stopping_criteria.current_tokens,
-                    skip_special_tokens=True
+                    prefix_offset=len(all_input_ids)
+                    - stopping_criteria.current_tokens
+                    - 1,
+                    read_offset=len(all_input_ids)
+                    - stopping_criteria.current_tokens,
+                    skip_special_tokens=True,
                 )
                 # Get seed
                 if isinstance(next_token_chooser.choice, Sampling):
@@ -761,7 +797,7 @@ class IdeficsCausalLM(Model):
             else:
                 prefill_tokens = None

-            top_tokens=None
+            top_tokens = None

             generation = Generation(
                 request.id,
@@ -771,7 +807,7 @@ class IdeficsCausalLM(Model):
                 next_token_text,
                 next_token_id_squeezed.item() in self.all_special_ids,
                 generated_text,
-                top_tokens
+                top_tokens,
             )

             generations.append(generation)
@@ -793,7 +829,9 @@ class IdeficsCausalLM(Model):
         # Update attention_mask as we added a new token to input_ids
         batch.attention_mask[:, -batch.padding_right_offset] = 1
-        batch.image_attention_mask[:, -batch.padding_right_offset, :] = batch.image_attention_mask[:, -(batch.padding_right_offset+1), :]
+        batch.image_attention_mask[
+            :, -batch.padding_right_offset, :
+        ] = batch.image_attention_mask[:, -(batch.padding_right_offset + 1), :]
         # Decrease right offset
         batch.padding_right_offset -= 1
...
@@ -71,7 +71,8 @@ class Model(ABC):
         # The prefix text is necessary only to defeat cleanup algorithms in the decode
         # which decide to add a space or not depending on the surrounding ids.
         prefix_text = self.tokenizer.decode(
-            all_input_ids[prefix_offset:read_offset], skip_special_tokens=skip_special_tokens
+            all_input_ids[prefix_offset:read_offset],
+            skip_special_tokens=skip_special_tokens,
         )
         new_text = self.tokenizer.decode(
             all_input_ids[prefix_offset:], skip_special_tokens=skip_special_tokens
...