Unverified Commit 5fa8ae04 authored by OlivierDehaene, committed by GitHub

feat(server): optimize decode for sane tokenizers (#170)

parent 6f0f1d70
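
In short: Model gains a decode_buffer argument (default 3) that sets how many trailing token ids are re-decoded on each streaming step, and models whose tokenizers decode a single new id cleanly (BLOOM, SantaCoder, FlashSantacoder below) pin it to 1 and skip the extra batch_decode call. The buffer exists because byte-level tokenizers can emit ids that cover only part of a multi-byte character. A standalone sketch of that effect, not part of this commit, using "gpt2" purely as an example tokenizer:

# Standalone illustration; "gpt2" is an arbitrary byte-level BPE tokenizer
# chosen as an example, not something this commit touches.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
ids = tokenizer.encode("thumbs up 👍")

# The emoji's UTF-8 bytes may be split across several ids, so decoding the
# last id on its own can end in the replacement character "�"...
print(repr(tokenizer.decode(ids[-1:], skip_special_tokens=False)))

# ...while decoding a small trailing window (what decode_buffer=3 provides)
# yields text the new token's piece can be sliced out of.
print(repr(tokenizer.decode(ids[-3:], skip_special_tokens=False)))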
@@ -853,7 +853,7 @@ checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b"
 
 [[package]]
 name = "grpc-metadata"
-version = "0.4.1"
+version = "0.1.0"
 dependencies = [
  "opentelemetry",
  "tonic",
@@ -2140,7 +2140,7 @@ dependencies = [
 
 [[package]]
 name = "text-generation-client"
-version = "0.4.3"
+version = "0.5.0"
 dependencies = [
  "futures",
  "grpc-metadata",
@@ -49,6 +49,11 @@ class BloomCausalLMBatch(CausalLMBatch):
 
 
 class BLOOM(CausalLM):
+    def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False):
+        super(BLOOM, self).__init__(
+            model_id=model_id, revision=revision, quantize=quantize, decode_buffer=1
+        )
+
     @property
     def batch_type(self) -> Type[CausalLMBatch]:
         return BloomCausalLMBatch
@@ -94,8 +99,7 @@ class BLOOMSharded(BLOOM):
         self.model = model.eval().to(dtype)
         torch.distributed.barrier(group=self.process_group)
         super(CausalLM, self).__init__(
-            tokenizer=tokenizer,
-            device=device,
+            tokenizer=tokenizer, device=device, decode_buffer=1
         )
 
     @staticmethod
@@ -291,7 +291,13 @@ class CausalLMBatch(Batch):
 
 
 class CausalLM(Model):
-    def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False):
+    def __init__(
+        self,
+        model_id: str,
+        revision: Optional[str] = None,
+        quantize: bool = False,
+        decode_buffer: int = 3,
+    ):
         if torch.cuda.is_available():
             device = torch.device("cuda")
             dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
@@ -319,8 +325,7 @@ class CausalLM(Model):
         )
         super(CausalLM, self).__init__(
-            tokenizer=tokenizer,
-            device=device,
+            tokenizer=tokenizer, device=device, decode_buffer=decode_buffer
         )
 
     @property
@@ -212,7 +212,8 @@ class FlashCausalLM(Model):
         model_cls: Type[PreTrainedModel],
         model_id: str,
         revision: Optional[str] = None,
-        quantize=False,
+        quantize: bool = False,
+        decode_buffer: int = 3,
     ):
         if torch.cuda.is_available():
             device = torch.device("cuda")
@@ -237,8 +238,7 @@ class FlashCausalLM(Model):
         )
         super(FlashCausalLM, self).__init__(
-            tokenizer=tokenizer,
-            device=device,
+            tokenizer=tokenizer, device=device, decode_buffer=decode_buffer
        )
 
     @property
@@ -62,8 +62,7 @@ class FlashSantacoder(FlashCausalLM):
         self.model = model.eval().to(device).to(dtype)
         super(FlashCausalLM, self).__init__(
-            tokenizer=tokenizer,
-            device=device,
+            tokenizer=tokenizer, device=device, decode_buffer=1
         )
 
     @staticmethod
@@ -10,10 +10,19 @@ B = TypeVar("B", bound=Batch)
 
 
 class Model(ABC):
-    def __init__(self, tokenizer: PreTrainedTokenizerBase, device: torch.device):
+    def __init__(
+        self,
+        tokenizer: PreTrainedTokenizerBase,
+        device: torch.device,
+        decode_buffer: int = 3,
+    ):
+        if decode_buffer < 1:
+            raise ValueError("decode_buffer must be >= 1")
+
         self.tokenizer = tokenizer
         self.all_special_ids = set(tokenizer.all_special_ids)
         self.device = device
+        self.decode_buffer = decode_buffer
 
     @property
     @abstractmethod
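
The new decode_buffer field is just the size of that trailing window. A tiny standalone check of the offset arithmetic used in decode_token below (the token id values are made up):

# Standalone check of the window arithmetic; the id values are made up.
all_input_ids = [5, 17, 942, 3, 88]
decode_buffer = 3

# Same computation as decode_token when no token_offset is carried over:
token_offset = len(all_input_ids) - decode_buffer

assert all_input_ids[token_offset:] == [942, 3, 88]   # window that is re-decoded
assert all_input_ids[token_offset:-1] == [942, 3]     # same window minus the newest id
# The newest token's text is then sequence_text[offset:], where offset is the
# length of the decoded prefix window.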
@@ -39,23 +48,37 @@ class Model(ABC):
             )
 
         if token_offset is None:
-            token_offset = len(all_input_ids) - 3
-
-        # Decode token_offset token minus last one and token_offset tokens
-        results = self.tokenizer.batch_decode(
-            [all_input_ids[token_offset:-1], all_input_ids[token_offset:]],
-            skip_special_tokens=False,
-        )
-
-        # default offset is only the last token
-        if offset is None:
-            offset = len(results[0])
+            token_offset = len(all_input_ids) - self.decode_buffer
+            # left token buffer
+            if self.decode_buffer > 1:
+                # Decode token_offset token minus last one and token_offset tokens
+                raw_texts = self.tokenizer.batch_decode(
+                    [all_input_ids[token_offset:-1], all_input_ids[token_offset:]],
+                    skip_special_tokens=False,
+                )
+
+                # default offset is only the last token
+                offset = len(raw_texts[0])
+                sequence_text = raw_texts[1]
+            else:
+                # Only decode the last token without using a token buffer
+                sequence_text = self.tokenizer.decode(
+                    all_input_ids[-1], skip_special_tokens=False
+                )
+                # no offset in this case
+                offset = 0
+        else:
+            assert offset is not None
+            sequence_text = self.tokenizer.decode(
+                all_input_ids[token_offset:],
+                skip_special_tokens=False,
+            )
 
         # get text
-        text = results[1][offset:]
+        token_text = sequence_text[offset:]
 
         # if text is utf-8
-        if text and text[-1] != "�":
-            return text, None, None
+        if token_text and token_text[-1] != "�":
+            return token_text, None, None
         else:
             return "", offset, token_offset
@@ -54,8 +54,7 @@ class SantaCoder(CausalLM):
         )
         super(CausalLM, self).__init__(
-            tokenizer=tokenizer,
-            device=device,
+            tokenizer=tokenizer, device=device, decode_buffer=1
         )
 
     def decode(self, generated_ids: List[int]) -> str:
@@ -330,7 +330,13 @@ class Seq2SeqLMBatch(Batch):
 
 
 class Seq2SeqLM(Model):
-    def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False):
+    def __init__(
+        self,
+        model_id: str,
+        revision: Optional[str] = None,
+        quantize: bool = False,
+        decode_buffer: int = 3,
+    ):
         if torch.cuda.is_available():
             device = torch.device("cuda")
             dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
@@ -354,8 +360,7 @@ class Seq2SeqLM(Model):
         tokenizer.bos_token_id = self.model.config.decoder_start_token_id
         super(Seq2SeqLM, self).__init__(
-            tokenizer=tokenizer,
-            device=device,
+            tokenizer=tokenizer, device=device, decode_buffer=decode_buffer
        )
 
     @property
@@ -496,7 +501,7 @@ class Seq2SeqLM(Model):
             if stop:
                 # Slice with decoder_input_length to remove padding
                 # Decode all tokens
-                output_text = self.decode(decoder_input_ids[-new_decoder_input_length:])
+                output_text = self.decode(decoder_input_ids[-decoder_input_length:])
 
                 # Get seed
                 if isinstance(next_token_chooser.choice, Sampling):