Unverified commit fb2f74e2, authored by Nicolas Patry, committed by GitHub

Refactor dead code - Removing all `flash_xxx.py` files. (#2166)

* Refactor dead code.

* First working step.

* Remove a lot of duplicated code.

* More dead code.

* More cleanup.

* Fix Santacoder test.

* Fixing the simple tests.

* Fixing sharding.

* Fixes for VLM.

* Fixing santacoder (num_kv_heads hardcoded).

* Removing more dead code.

* Fixing `config.n_head`.

* Stopping earlier because of `<end_of_utterance>` in idefics2.

* Addresses comments.

* Removing the dead code.

* Fuse back mistral into FlashCausalLM.

* Finish removal.

* Fixing docs + causal_lm `batch_class`.

* Fixing docs + causal.lm.

* Add default to Gemma Causality.

* Default value for gemma/gemma2.

* Wrong default.
parent c6bcadf8
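The bullets above summarize the shape of the refactor: delete the per-model `flash_xxx.py` wrappers, fuse their behaviour back into generic base classes (`FlashCausalLM`, `Seq2SeqLM`, `VlmCausalLM`), and pass the model-specific pieces in as constructor arguments instead of subclassing. Below is a minimal, self-contained sketch of that dependency-injection pattern; the class names are placeholders, not the real TGI classes.

# Placeholder sketch of the refactor's shape: instead of one hand-written wrapper
# per architecture (the old flash_xxx.py files), a single generic class receives
# the architecture-specific model class (plus config/tokenizer classes) as
# arguments. All names here are illustrative only.
from typing import Callable, Dict, Optional


class GenericLM:
    def __init__(
        self,
        model_id: str,
        model_class: Callable[[Dict], object],
        config_class: Callable[[str], Dict] = lambda model_id: {"model_id": model_id},
        aliases: Optional[Dict] = None,
    ):
        config = config_class(model_id)
        self.aliases = aliases or {}
        self.model = model_class(config)


class TinyArch:
    """Stands in for a custom_modeling class such as T5ForConditionalGeneration."""

    def __init__(self, config: Dict):
        self.config = config


# Before: class TinyArchSharded(GenericLM): ... duplicating the loading code.
# After: the generic class plus arguments.
served = GenericLM("some/model", model_class=TinyArch)
print(type(served.model).__name__)  # prints: TinyArch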
import torch
import torch.distributed
import time
from dataclasses import dataclass
from opentelemetry import trace
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    PreTrainedTokenizerBase,
    AutoConfig,
)
from typing import Optional, Tuple, List, Type, Dict

from text_generation_server.utils import (
    initialize_torch_distributed,
    weight_files,
    Weights,
)
from text_generation_server.utils.chunks import concat_text_chunks
# SYSTEM identifies the accelerator backend; it is used by the ipex branch below.
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.tokens import batch_top_tokens
from text_generation_server.models import Model
@@ -531,6 +542,80 @@ class Seq2SeqLM(Model):
    def __init__(
        self,
        model_id: str,
        model_class,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        default_dtype=torch.float16,
        trust_remote_code: bool = False,
        config_class=AutoConfig,
        tokenizer_class=AutoTokenizer,
        aliases=None,
    ):
        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
            dtype = default_dtype if dtype is None else dtype
        elif SYSTEM == "ipex":
            if hasattr(torch, "xpu") and torch.xpu.is_available():
                device = torch.device(f"xpu:{rank}")
                dtype = default_dtype if dtype is None else dtype
            else:
                device = torch.device("cpu")
                # Float16 doesn't exist on target.
                dtype = torch.bfloat16 if dtype is None else dtype
        else:
            device = torch.device("cpu")
            dtype = torch.float32 if dtype is None else dtype

        config = config_class.from_pretrained(
            model_id,
            revision=revision,
            trust_remote_code=trust_remote_code,
        )
        config.quantize = quantize
        config.speculator = speculator

        tokenizer = tokenizer_class.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )
        tokenizer.bos_token_id = config.decoder_start_token_id

        torch.distributed.barrier(group=self.process_group)
        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
        weights = Weights(
            filenames,
            device=device,
            dtype=dtype,
            process_group=self.process_group,
            aliases=aliases,
        )
        if config.quantize in ["awq", "exl2", "gptq", "marlin"]:
            weights._set_gptq_params(model_id, revision)

        model = model_class(config, weights)
        torch.distributed.barrier(group=self.process_group)
        super().__init__(
            model_id=model_id,
            model=model,
            tokenizer=tokenizer,
            requires_padding=True,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
        )
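
With the constructor above, a specific seq2seq architecture no longer needs its own wrapper class: the call site passes `model_class` (and, where needed, `aliases`) directly. The snippet below is a hedged illustration of such a call site for T5, assembled from the keyword names above and the alias mapping used by the removed `T5Sharded` class shown further below in this diff; it is not code taken verbatim from this PR.

# Illustrative call site: serving T5 through the generic Seq2SeqLM. The keyword
# names come from the constructor above and the aliases dict from the deleted
# t5.py below; the exact wiring in TGI's model dispatch may differ.
from text_generation_server.models.seq2seq_lm import Seq2SeqLM
from text_generation_server.models.custom_modeling.t5_modeling import (
    T5ForConditionalGeneration,
)

model = Seq2SeqLM(
    model_id="google/flan-t5-base",
    model_class=T5ForConditionalGeneration,
    revision=None,
    quantize=None,
    speculator=None,
    dtype=None,
    aliases={
        # T5 stores a single shared embedding that both encoder and decoder read.
        "shared.weight": [
            "encoder.embed_tokens.weight",
            "decoder.embed_tokens.weight",
        ]
    },
    trust_remote_code=False,
)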
    @classmethod
    def fallback(
        cls,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
@@ -574,7 +659,11 @@ class Seq2SeqLM(Model):
        )
        tokenizer.bos_token_id = model.config.decoder_start_token_id

        self = cls.__new__(
            cls,
        )
        super().__init__(
            self,
            model_id=model_id,
            model=model,
            tokenizer=tokenizer,
@@ -582,16 +671,12 @@ class Seq2SeqLM(Model):
            dtype=dtype,
            device=device,
        )
        return self

    @property
    def batch_type(self) -> Type[Seq2SeqLMBatch]:
        return Seq2SeqLMBatch

    def decode(self, decoder_ids: List[int]) -> str:
        return self.tokenizer.decode(
            decoder_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

    def forward(
        self,
        input_ids,
......
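The `fallback` classmethod above builds a `Seq2SeqLM` instance without running the sharded `__init__` (so no `initialize_torch_distributed`, no safetensors loading): it allocates the object with `cls.__new__` and then calls the base `Model.__init__` on it explicitly. Here is a standalone sketch of that construction pattern, with placeholder classes rather than the TGI ones.

# Standalone illustration of the cls.__new__ + explicit base __init__ pattern used
# by Seq2SeqLM.fallback. Base/Heavy are placeholders; the real code calls
# super().__init__(self, ...) from inside the classmethod, which is equivalent to
# the explicit Base.__init__ call here.
class Base:
    def __init__(self, model_id: str):
        self.model_id = model_id


class Heavy(Base):
    def __init__(self, model_id: str):
        # Imagine expensive setup here (torch.distributed, weight loading, ...).
        raise RuntimeError("expensive path the fallback wants to skip")

    @classmethod
    def fallback(cls, model_id: str) -> "Heavy":
        # Allocate the instance without calling Heavy.__init__ ...
        self = cls.__new__(cls)
        # ... then initialize only the lightweight base state.
        Base.__init__(self, model_id=model_id)
        return self


obj = Heavy.fallback("some/model")
print(obj.model_id)  # prints: some/model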
import torch
import torch.distributed

from typing import List, Optional, Tuple

from transformers import (
    AutoTokenizer,
    AutoConfig,
)
from text_generation_server.models import Seq2SeqLM
from text_generation_server.models.custom_modeling.t5_modeling import (
    T5ForConditionalGeneration,
)
from text_generation_server.utils import (
    initialize_torch_distributed,
    weight_files,
    Weights,
)


class T5Sharded(Seq2SeqLM):
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
            dtype = torch.float16 if dtype is None else dtype
        else:
            device = torch.device("cpu")
            dtype = torch.float32 if dtype is None else dtype

        config = AutoConfig.from_pretrained(
            model_id,
            revision=revision,
            trust_remote_code=trust_remote_code,
        )
        config.quantize = quantize
        config.speculator = speculator

        tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )
        tokenizer.bos_token_id = config.decoder_start_token_id

        torch.distributed.barrier(group=self.process_group)
        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
        weights = Weights(
            filenames,
            device=device,
            dtype=dtype,
            process_group=self.process_group,
            aliases={
                "shared.weight": [
                    "encoder.embed_tokens.weight",
                    "decoder.embed_tokens.weight",
                ]
            },
        )

        model = T5ForConditionalGeneration(config, weights)

        torch.distributed.barrier(group=self.process_group)
        super(Seq2SeqLM, self).__init__(
            model_id=model_id,
            model=model,
            tokenizer=tokenizer,
            requires_padding=True,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
        )

    def forward(
        self,
        input_ids,
        attention_mask,
        decoder_input_ids,
        decoder_attention_mask: Optional,
        encoder_last_hidden_state: Optional,
        past_key_values: Optional = None,
    ) -> Tuple[
        torch.Tensor,
        torch.Tensor,
        List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]],
    ]:
        # Model Forward
        outputs, speculative_logits = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            encoder_outputs=encoder_last_hidden_state,
            past_key_values=past_key_values,
            use_cache=True,
        )
        return (
            outputs.logits,
            speculative_logits,
            outputs.encoder_last_hidden_state,
            outputs.past_key_values,
        )
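The only genuinely model-specific piece of the deleted `T5Sharded` class above was the `aliases` mapping, which tells the weight loader that the stored tensor `shared.weight` should also answer lookups for `encoder.embed_tokens.weight` and `decoder.embed_tokens.weight`. Since the generic `Seq2SeqLM` constructor now accepts `aliases` directly, the subclass had nothing left to do. The toy snippet below illustrates what such alias resolution amounts to; it is not the actual `Weights` implementation, which also handles sharding, dtypes and devices.

# Toy alias resolution: build a reverse map from requested tensor name to the
# name it is stored under, and fall back to the requested name otherwise.
from typing import Dict, List

import torch


def build_reverse_aliases(aliases: Dict[str, List[str]]) -> Dict[str, str]:
    return {alias: stored for stored, names in aliases.items() for alias in names}


stored_tensors = {"shared.weight": torch.zeros(32128, 512)}
aliases = {
    "shared.weight": [
        "encoder.embed_tokens.weight",
        "decoder.embed_tokens.weight",
    ]
}
reverse = build_reverse_aliases(aliases)


def get_tensor(name: str) -> torch.Tensor:
    return stored_tensors[reverse.get(name, name)]


print(get_tensor("decoder.embed_tokens.weight").shape)  # torch.Size([32128, 512])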
@@ -9,10 +9,11 @@ from typing import Iterable, Optional, Tuple, List, Type, Dict
from transformers import PreTrainedTokenizerBase
from transformers.image_processing_utils import select_best_resolution

from text_generation_server.pb import generate_pb2
from text_generation_server.models.flash_causal_lm import (
    FlashCausalLMBatch,
    FlashCausalLM,
)
from transformers import AutoProcessor

tracer = trace.get_tracer(__name__)
@@ -239,10 +240,35 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
        return batch


class VlmCausalLM(FlashCausalLM):
    def __init__(
        self,
        model_id: str,
        *,
        processor_class=AutoProcessor,
        processor_kwargs=None,
        batch_class=VlmCausalLMBatch,
        revision,
        trust_remote_code: bool,
        **kwargs,
    ):
        if processor_kwargs is None:
            processor_kwargs = {}
        self.processor = processor_class.from_pretrained(
            model_id,
            revision=revision,
            trust_remote_code=trust_remote_code,
            **processor_kwargs,
        )
        self.batch_class = batch_class
        super().__init__(model_id=model_id, **kwargs)

    @property
    def batch_type(self) -> Type[VlmCausalLMBatch]:
        return self.batch_class

    def max_past(self) -> Optional[int]:
        return getattr(self.model.text_model, "max_past", None)

    def forward(
        self,
......
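`VlmCausalLM` now inherits from `FlashCausalLM` directly instead of the removed `BaseFlashMistral`, and everything model-specific arrives as an argument: the processor via `processor_class`/`processor_kwargs`, batching via `batch_class`, and the remaining `**kwargs` are forwarded to the generic `FlashCausalLM` constructor. The snippet below sketches what a call site could look like; the idefics2 model id, the custom batch subclass, and the kwargs forwarded through `**kwargs` are assumptions for illustration, not taken from this diff.

# Hypothetical registration of a vision-language model through the refactored
# VlmCausalLM. Only the keyword names of VlmCausalLM itself come from the
# constructor above; everything else is an illustrative assumption.
from transformers import AutoProcessor

from text_generation_server.models.vlm_causal_lm import VlmCausalLM, VlmCausalLMBatch


class Idefics2Batch(VlmCausalLMBatch):
    """Hypothetical batch subclass: image-token expansion would be overridden here."""


model = VlmCausalLM(
    model_id="HuggingFaceM4/idefics2-8b",
    revision=None,
    trust_remote_code=False,
    processor_class=AutoProcessor,
    processor_kwargs=None,
    batch_class=Idefics2Batch,
    # Whatever the generic FlashCausalLM constructor requires (e.g. the concrete
    # text model class) would travel through **kwargs and is omitted here.
)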