Unverified Commit 5a582261 authored by OlivierDehaene, committed by GitHub

fix(server): fix decode token (#334)



Fixes #333

---------
Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
parent dbdc587d
@@ -13,23 +13,20 @@ B = TypeVar("B", bound=Batch)
 class Model(ABC):
     def __init__(
         self,
+        model: torch.nn.Module,
         tokenizer: PreTrainedTokenizerBase,
         requires_padding: bool,
         dtype: torch.dtype,
         device: torch.device,
-        decode_buffer: int = 3,
         rank: int = 0,
         world_size: int = 1,
     ):
-        if decode_buffer < 1:
-            raise ValueError("decode_buffer must be >= 1")
-
+        self.model = model.eval()
         self.tokenizer = tokenizer
         self.all_special_ids = set(tokenizer.all_special_ids)
         self.requires_padding = requires_padding
         self.dtype = dtype
         self.device = device
-        self.decode_buffer = decode_buffer
         self.rank = rank
         self.world_size = world_size

         self.check_initialized()
@@ -54,52 +51,29 @@ class Model(ABC):
     def decode_token(
         self,
         all_input_ids: List[int],
-        offset: Optional[int] = None,
-        token_offset: Optional[int] = None,
-    ) -> Tuple[str, Optional[int], Optional[int]]:
+        prefix_offset: int = 0,
+        read_offset: int = 0,
+    ) -> Tuple[str, int, int]:
         """Hack to hopefully support generate_stream for the maximum number of tokenizers"""
-        if all_input_ids[-1] in self.all_special_ids:
-            return (
-                self.tokenizer.decode(all_input_ids[-1], skip_special_tokens=False),
-                None,
-                None,
-            )
-
-        if token_offset is None:
-            token_offset = len(all_input_ids) - self.decode_buffer
-            # left token buffer
-            if self.decode_buffer > 1:
-                # Decode token_offset token minus last one and token_offset tokens
-                raw_texts = self.tokenizer.batch_decode(
-                    [all_input_ids[token_offset:-1], all_input_ids[token_offset:]],
-                    skip_special_tokens=False,
-                )
-
-                # default offset is only the last token
-                offset = len(raw_texts[0])
-                sequence_text = raw_texts[1]
-            else:
-                # Only decode the last token without using a token buffer
-                sequence_text = self.tokenizer.decode(
-                    all_input_ids[-1], skip_special_tokens=False
-                )
-                # no offset in this case
-                offset = 0
-        else:
-            assert offset is not None
-            sequence_text = self.tokenizer.decode(
-                all_input_ids[token_offset:],
-                skip_special_tokens=False,
-            )
-
-        # get text
-        token_text = sequence_text[offset:]
-
-        # if text is utf-8
-        if token_text and token_text[-1] != "�":
-            return token_text, None, None
+
+        # The prefix text is necessary only to defeat cleanup algorithms in the decode
+        # which decide to add a space or not depending on the surrounding ids.
+        prefix_text = self.tokenizer.decode(
+            all_input_ids[prefix_offset:read_offset], skip_special_tokens=False
+        )
+        new_text = self.tokenizer.decode(
+            all_input_ids[prefix_offset:], skip_special_tokens=False
+        )
+
+        if len(new_text) > len(prefix_text) and not new_text.endswith("�"):
+            # utf-8 char at the end means it's a potential unfinished byte sequence
+            # from byte fallback tokenization.
+            # If it's in the middle, it's probably a real invalid id generated
+            # by the model
+            new_text = new_text[len(prefix_text) :]
+            return new_text, read_offset, len(all_input_ids)
         else:
-            return "", offset, token_offset
+            return "", prefix_offset, read_offset

     def check_initialized(self):
         uninitialized_parameters = []
...
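The new decode_token above performs incremental detokenization: it keeps decoding from a fixed prefix_offset and only emits the suffix once the decoded text grows past the previously read window and no longer ends in a partial UTF-8 character. The standalone sketch below is not part of the commit; it illustrates the same idea with an off-the-shelf tokenizer, where the "gpt2" checkpoint and the driver loop are illustrative assumptions only.

from typing import List, Tuple

from transformers import AutoTokenizer

# Illustrative tokenizer choice; any byte-level / byte-fallback tokenizer behaves similarly.
tokenizer = AutoTokenizer.from_pretrained("gpt2")


def decode_token(
    all_input_ids: List[int], prefix_offset: int = 0, read_offset: int = 0
) -> Tuple[str, int, int]:
    # Decode the previously emitted window and the window extended by the new ids.
    # The prefix text only exists to defeat the tokenizer's cleanup heuristics
    # (e.g. whether to insert a space), so the emitted suffix stays stable.
    prefix_text = tokenizer.decode(
        all_input_ids[prefix_offset:read_offset], skip_special_tokens=False
    )
    new_text = tokenizer.decode(all_input_ids[prefix_offset:], skip_special_tokens=False)
    if len(new_text) > len(prefix_text) and not new_text.endswith("�"):
        # Complete characters were produced: emit the new suffix and advance both offsets.
        return new_text[len(prefix_text):], read_offset, len(all_input_ids)
    # A trailing "�" means a potentially unfinished byte sequence: hold the text back.
    return "", prefix_offset, read_offset


# Stand-in driver loop: feed ids one at a time, as a model would during streaming.
ids = tokenizer("Héllo, streaming décode!", add_special_tokens=False).input_ids
generated: List[int] = []
prefix_offset = read_offset = 0
streamed = ""
for token_id in ids:
    generated.append(token_id)
    text, prefix_offset, read_offset = decode_token(generated, prefix_offset, read_offset)
    streamed += text
print(streamed)  # should match tokenizer.decode(ids) once all byte sequences are complete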
@@ -86,9 +86,9 @@ class OPTSharded(OPT):
             rank=rank,
             world_size=world_size,
         )
-        self.model = model.eval()
         torch.distributed.barrier(group=self.process_group)
         super(CausalLM, self).__init__(
+            model=model,
             tokenizer=tokenizer,
             requires_padding=True,
             dtype=dtype,
...
@@ -46,24 +46,20 @@ class SantaCoder(CausalLM):
             }
         )

-        self.model = (
-            AutoModelForCausalLM.from_pretrained(
-                model_id,
-                revision=revision,
-                torch_dtype=dtype,
-                load_in_8bit=quantize == "bitsandbytes",
-                trust_remote_code=True,  # required
-            )
-            .to(device)
-            .eval()
-        )
+        model = AutoModelForCausalLM.from_pretrained(
+            model_id,
+            revision=revision,
+            torch_dtype=dtype,
+            load_in_8bit=quantize == "bitsandbytes",
+            trust_remote_code=True,  # required
+        ).to(device)

         super(CausalLM, self).__init__(
+            model=model,
             tokenizer=tokenizer,
             requires_padding=True,
             dtype=dtype,
             device=device,
-            decode_buffer=1,
         )

     def decode(self, generated_ids: List[int]) -> str:
...
@@ -42,8 +42,8 @@ class Seq2SeqLMBatch(Batch):
     # Lengths of all generations present in the batch
     input_lengths: List[int]
     decoder_input_lengths: List[int]
-    offsets: List[Optional[int]]
-    token_offsets: List[Optional[int]]
+    prefix_offsets: List[int]
+    read_offsets: List[int]

     # Generation helpers
     next_token_choosers: List[NextTokenChooser]
@@ -79,8 +79,8 @@ class Seq2SeqLMBatch(Batch):
         stopping_criterias = []

         decoder_input_lengths = []
-        offsets = []
-        token_offsets = []
+        prefix_offsets = []
+        read_offsets = []
         requests_idx_mapping = {}

         # Parse batch
@@ -91,8 +91,6 @@ class Seq2SeqLMBatch(Batch):
             inputs.append(r.inputs)
             requests_idx_mapping[r.id] = i
             decoder_input_lengths.append(1)
-            offsets.append(None)
-            token_offsets.append(None)
             next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
@@ -123,6 +121,9 @@ class Seq2SeqLMBatch(Batch):
             .repeat(len(pb.requests))
             .view(-1, 1)
         )
+        for _ in pb.requests:
+            prefix_offsets.append(0)
+            read_offsets.append(1)

         all_decoder_input_ids = decoder_input_ids.view(-1).split(1)
         max_tokens = len(inputs) * max_input_length + max_decode_tokens
@@ -140,8 +141,8 @@ class Seq2SeqLMBatch(Batch):
             past_key_values=None,
             input_lengths=input_lengths.tolist(),
             decoder_input_lengths=decoder_input_lengths,
-            offsets=offsets,
-            token_offsets=token_offsets,
+            prefix_offsets=prefix_offsets,
+            read_offsets=read_offsets,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
             max_input_length=max_input_length.item(),
@@ -165,8 +166,8 @@ class Seq2SeqLMBatch(Batch):
         requests_idx_mapping = {}
         input_lengths = []
         decoder_input_lengths = []
-        offsets = []
-        token_offsets = []
+        prefix_offsets = []
+        read_offsets = []

         all_decoder_input_ids = []
@@ -184,8 +185,8 @@ class Seq2SeqLMBatch(Batch):
             requests_idx_mapping[r.id] = i
             keep_indices.append(idx)

-            offsets.append(self.offsets[idx])
-            token_offsets.append(self.token_offsets[idx])
+            prefix_offsets.append(self.prefix_offsets[idx])
+            read_offsets.append(self.read_offsets[idx])

             all_decoder_input_ids.append(self.all_decoder_input_ids[idx])
@@ -248,8 +249,8 @@ class Seq2SeqLMBatch(Batch):
         self.all_decoder_input_ids = all_decoder_input_ids
         self.input_lengths = input_lengths
         self.decoder_input_lengths = decoder_input_lengths
-        self.offsets = offsets
-        self.token_offsets = token_offsets
+        self.prefix_offsets = prefix_offsets
+        self.read_offsets = read_offsets
         self.next_token_choosers = next_token_choosers
         self.stopping_criterias = stopping_criterias
         self.max_input_length = max_input_length
@@ -283,8 +284,8 @@ class Seq2SeqLMBatch(Batch):
         all_decoder_input_ids = []
         input_lengths = []
         decoder_input_lengths = []
-        offsets = []
-        token_offsets = []
+        prefix_offsets = []
+        read_offsets = []
         next_token_choosers = []
         stopping_criterias = []
         max_tokens = 0
@@ -306,8 +307,8 @@ class Seq2SeqLMBatch(Batch):
             all_decoder_input_ids.extend(batch.all_decoder_input_ids)
             input_lengths.extend(batch.input_lengths)
             decoder_input_lengths.extend(batch.decoder_input_lengths)
-            offsets.extend(batch.offsets)
-            token_offsets.extend(batch.token_offsets)
+            prefix_offsets.extend(batch.prefix_offsets)
+            read_offsets.extend(batch.read_offsets)
             next_token_choosers.extend(batch.next_token_choosers)
             stopping_criterias.extend(batch.stopping_criterias)
@@ -482,8 +483,8 @@ class Seq2SeqLMBatch(Batch):
             past_key_values=past_key_values,
             input_lengths=input_lengths,
             decoder_input_lengths=decoder_input_lengths,
-            offsets=offsets,
-            token_offsets=token_offsets,
+            prefix_offsets=prefix_offsets,
+            read_offsets=read_offsets,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
             max_input_length=max_input_length,
@@ -502,7 +503,6 @@ class Seq2SeqLM(Model):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
-        decode_buffer: int = 3,
     ):
         if torch.cuda.is_available():
             device = torch.device("cuda")
@@ -514,24 +514,24 @@ class Seq2SeqLM(Model):
             device = torch.device("cpu")
             dtype = torch.float32

-        self.model = AutoModelForSeq2SeqLM.from_pretrained(
+        model = AutoModelForSeq2SeqLM.from_pretrained(
             model_id,
             revision=revision,
             torch_dtype=dtype,
             device_map="auto" if torch.cuda.is_available() else None,
             load_in_8bit=quantize == "bitsandbytes",
-        ).eval()
+        )
         tokenizer = AutoTokenizer.from_pretrained(
             model_id, revision=revision, padding_side="left", truncation_side="left"
         )
-        tokenizer.bos_token_id = self.model.config.decoder_start_token_id
+        tokenizer.bos_token_id = model.config.decoder_start_token_id

         super(Seq2SeqLM, self).__init__(
+            model=model,
             tokenizer=tokenizer,
             requires_padding=True,
             dtype=dtype,
             device=device,
-            decode_buffer=decode_buffer,
         )

     @property
@@ -608,8 +608,8 @@ class Seq2SeqLM(Model):
         iterator = zip(
             batch.requests,
             batch.input_lengths,
-            batch.offsets,
-            batch.token_offsets,
+            batch.prefix_offsets,
+            batch.read_offsets,
             batch.decoder_input_lengths,
             logits,
             batch.next_token_choosers,
@@ -621,8 +621,8 @@ class Seq2SeqLM(Model):
         for i, (
             request,
             input_length,
-            offset,
-            token_offset,
+            prefix_offset,
+            read_offset,
             decoder_input_length,
             logits,
             next_token_chooser,
@@ -643,8 +643,8 @@ class Seq2SeqLM(Model):
             # Generated token
             next_token_logprob = logprobs[-1, next_token_id]
             next_token_id_squeezed = next_token_id.squeeze()
-            next_token_text, offset, token_offset = self.decode_token(
-                all_decoder_input_ids, offset, token_offset
+            next_token_text, prefix_offset, read_offset = self.decode_token(
+                all_decoder_input_ids, prefix_offset, read_offset
             )

             # Evaluate stopping criteria
@@ -702,8 +702,8 @@ class Seq2SeqLM(Model):
             batch.all_decoder_input_ids[i] = all_decoder_input_ids
             batch.input_lengths[i] = input_length
             batch.decoder_input_lengths[i] = new_decoder_input_length
-            batch.offsets[i] = offset
-            batch.token_offsets[i] = token_offset
+            batch.prefix_offsets[i] = prefix_offset
+            batch.read_offsets[i] = read_offset
             batch.max_input_length = max(batch.max_input_length, input_length)
             batch.max_decoder_input_length = max(
                 batch.max_decoder_input_length, new_decoder_input_length
...
@@ -16,9 +16,6 @@ from text_generation_server.utils import (
     initialize_torch_distributed,
     weight_files,
 )
-from text_generation_server.utils.layers import (
-    FastLinear,
-)
 from transformers.models.t5.parallel_layers import (
     TensorParallelRowLinear,
     TensorParallelColumnLinear,
@@ -73,9 +70,9 @@ class T5Sharded(Seq2SeqLM):
             rank=rank,
             world_size=world_size,
         )
-        self.model = model.eval()
         torch.distributed.barrier(group=self.process_group)
         super(Seq2SeqLM, self).__init__(
+            model=model,
             tokenizer=tokenizer,
             requires_padding=True,
             dtype=dtype,
...