Unverified Commit fb2f74e2 authored by Nicolas Patry's avatar Nicolas Patry Committed by GitHub
Browse files

Refactor dead code - Removing all `flash_xxx.py` files. (#2166)

* Refactor dead code.

* First working step.

* Remove a lot of duplicated code.

* More dead code.

* More cleanup.

* Fix Santacoder test.

* Fixing the simple tests.

* Fixing sharding.

* Fixes for VLM.

* Fixing santacoder (num_kv_heads hardcoded).

* Removing more dead code.

* Fixing `config.n_head`.

* Stopping earlier because of `<end_of_utterance>` in idefics2.

* Addresses comments.

* Removing the dead code.

* Fuse back mistral into FlashCausalLM.

* Finish removal.

* Fixing docs + causal_lm `batch_class`.

* Fixing docs + causal.lm.

* Add default to Gemma Causality.

* Default value for gemma/gemma2.

* Wrong default.
parent c6bcadf8
import torch import torch
import torch.distributed
import time import time
from dataclasses import dataclass from dataclasses import dataclass
from opentelemetry import trace from opentelemetry import trace
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, PreTrainedTokenizerBase from transformers import (
AutoTokenizer,
AutoModelForSeq2SeqLM,
PreTrainedTokenizerBase,
AutoConfig,
)
from typing import Optional, Tuple, List, Type, Dict from typing import Optional, Tuple, List, Type, Dict
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
from text_generation_server.utils.chunks import concat_text_chunks from text_generation_server.utils.chunks import concat_text_chunks
from text_generation_server.utils.tokens import batch_top_tokens from text_generation_server.utils.tokens import batch_top_tokens
from text_generation_server.models import Model from text_generation_server.models import Model
...@@ -531,6 +542,80 @@ class Seq2SeqLM(Model): ...@@ -531,6 +542,80 @@ class Seq2SeqLM(Model):
def __init__( def __init__(
self, self,
model_id: str, model_id: str,
model_class,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
default_dtype=torch.float16,
trust_remote_code: bool = False,
config_class=AutoConfig,
tokenizer_class=AutoTokenizer,
aliases=None,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = default_dtype if dtype is None else dtype
elif SYSTEM == "ipex":
if hasattr(torch, "xpu") and torch.xpu.is_available():
device = torch.device(f"xpu:{rank}")
dtype = default_dtype if dtype is None else dtype
else:
device = torch.device("cpu")
# Float16 doesn't exist on target.
dtype = torch.bfloat16 if dtype is None else dtype
else:
device = torch.device("cpu")
dtype = torch.float32 if dtype is None else dtype
config = config_class.from_pretrained(
model_id,
revision=revision,
trust_remote_code=trust_remote_code,
)
config.quantize = quantize
config.speculator = speculator
tokenizer = tokenizer_class.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
tokenizer.bos_token_id = config.decoder_start_token_id
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(
filenames,
device=device,
dtype=dtype,
process_group=self.process_group,
aliases=aliases,
)
if config.quantize in ["awq", "exl2", "gptq", "marlin"]:
weights._set_gptq_params(model_id, revision)
model = model_class(config, weights)
torch.distributed.barrier(group=self.process_group)
super().__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
requires_padding=True,
dtype=dtype,
device=device,
rank=rank,
world_size=world_size,
)
@classmethod
def fallback(
cls,
model_id: str,
revision: Optional[str] = None, revision: Optional[str] = None,
quantize: Optional[str] = None, quantize: Optional[str] = None,
speculator: Optional[str] = None, speculator: Optional[str] = None,
...@@ -574,7 +659,11 @@ class Seq2SeqLM(Model): ...@@ -574,7 +659,11 @@ class Seq2SeqLM(Model):
) )
tokenizer.bos_token_id = model.config.decoder_start_token_id tokenizer.bos_token_id = model.config.decoder_start_token_id
super(Seq2SeqLM, self).__init__( self = cls.__new__(
cls,
)
super().__init__(
self,
model_id=model_id, model_id=model_id,
model=model, model=model,
tokenizer=tokenizer, tokenizer=tokenizer,
...@@ -582,16 +671,12 @@ class Seq2SeqLM(Model): ...@@ -582,16 +671,12 @@ class Seq2SeqLM(Model):
dtype=dtype, dtype=dtype,
device=device, device=device,
) )
return self
@property @property
def batch_type(self) -> Type[Seq2SeqLMBatch]: def batch_type(self) -> Type[Seq2SeqLMBatch]:
return Seq2SeqLMBatch return Seq2SeqLMBatch
def decode(self, decoder_ids: List[int]) -> str:
return self.tokenizer.decode(
decoder_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
def forward( def forward(
self, self,
input_ids, input_ids,
......
import torch
import torch.distributed
from typing import List, Optional, Tuple
from transformers import (
AutoTokenizer,
AutoConfig,
)
from text_generation_server.models import Seq2SeqLM
from text_generation_server.models.custom_modeling.t5_modeling import (
T5ForConditionalGeneration,
)
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
class T5Sharded(Seq2SeqLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype
else:
device = torch.device("cpu")
dtype = torch.float32 if dtype is None else dtype
config = AutoConfig.from_pretrained(
model_id,
revision=revision,
trust_remote_code=trust_remote_code,
)
config.quantize = quantize
config.speculator = speculator
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
tokenizer.bos_token_id = config.decoder_start_token_id
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(
filenames,
device=device,
dtype=dtype,
process_group=self.process_group,
aliases={
"shared.weight": [
"encoder.embed_tokens.weight",
"decoder.embed_tokens.weight",
]
},
)
model = T5ForConditionalGeneration(config, weights)
torch.distributed.barrier(group=self.process_group)
super(Seq2SeqLM, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
requires_padding=True,
dtype=dtype,
device=device,
rank=rank,
world_size=world_size,
)
def forward(
self,
input_ids,
attention_mask,
decoder_input_ids,
decoder_attention_mask: Optional,
encoder_last_hidden_state: Optional,
past_key_values: Optional = None,
) -> Tuple[
torch.Tensor,
torch.Tensor,
List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]],
]:
# Model Forward
outputs, speculative_logits = self.model.forward(
input_ids=input_ids,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
encoder_outputs=encoder_last_hidden_state,
past_key_values=past_key_values,
use_cache=True,
)
return (
outputs.logits,
speculative_logits,
outputs.encoder_last_hidden_state,
outputs.past_key_values,
)
...@@ -9,10 +9,11 @@ from typing import Iterable, Optional, Tuple, List, Type, Dict ...@@ -9,10 +9,11 @@ from typing import Iterable, Optional, Tuple, List, Type, Dict
from transformers import PreTrainedTokenizerBase from transformers import PreTrainedTokenizerBase
from transformers.image_processing_utils import select_best_resolution from transformers.image_processing_utils import select_best_resolution
from text_generation_server.pb import generate_pb2 from text_generation_server.pb import generate_pb2
from text_generation_server.models.flash_causal_lm import FlashCausalLMBatch from text_generation_server.models.flash_causal_lm import (
from text_generation_server.models.flash_mistral import ( FlashCausalLMBatch,
BaseFlashMistral, FlashCausalLM,
) )
from transformers import AutoProcessor
tracer = trace.get_tracer(__name__) tracer = trace.get_tracer(__name__)
...@@ -239,10 +240,35 @@ class VlmCausalLMBatch(FlashCausalLMBatch): ...@@ -239,10 +240,35 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
return batch return batch
class VlmCausalLM(BaseFlashMistral): class VlmCausalLM(FlashCausalLM):
def __init__(
self,
model_id: str,
*,
processor_class=AutoProcessor,
processor_kwargs=None,
batch_class=VlmCausalLMBatch,
revision,
trust_remote_code: bool,
**kwargs,
):
if processor_kwargs is None:
processor_kwargs = {}
self.processor = processor_class.from_pretrained(
model_id,
revision=revision,
trust_remote_code=trust_remote_code,
**processor_kwargs,
)
self.batch_class = batch_class
super().__init__(model_id=model_id, **kwargs)
@property @property
def batch_type(self) -> Type[VlmCausalLMBatch]: def batch_type(self) -> Type[VlmCausalLMBatch]:
return VlmCausalLMBatch return self.batch_class
def max_past(self) -> Optional[int]:
return getattr(self.model.text_model, "max_past", None)
def forward( def forward(
self, self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment