Unverified Commit 9ecfa16b authored by Nicolas Patry, committed by GitHub

Speculative (#1308)

parent 3238c491
......@@ -32,6 +32,7 @@ def serve(
revision: Optional[str] = None,
sharded: bool = False,
quantize: Optional[Quantization] = None,
speculate: Optional[int] = None,
dtype: Optional[Dtype] = None,
trust_remote_code: bool = False,
uds_path: Path = "/tmp/text-generation-server",
......@@ -81,7 +82,7 @@ def serve(
"Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model."
)
server.serve(
model_id, revision, sharded, quantize, dtype, trust_remote_code, uds_path
model_id, revision, sharded, quantize, speculate, dtype, trust_remote_code, uds_path
)
......@@ -116,7 +117,7 @@ def download_weights(
logger.info("Files are already present on the host. " "Skipping download.")
return
# Local files not found
except (utils.LocalEntryNotFoundError, FileNotFoundError):
except (utils.LocalEntryNotFoundError, FileNotFoundError, utils.EntryNotFoundError):
pass
is_local_model = (Path(model_id).exists() and Path(model_id).is_dir()) or os.getenv(
......@@ -137,6 +138,29 @@ def download_weights(
except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
pass
try:
import json
medusa_head = hf_hub_download(model_id, revision=revision, filename="medusa_lm_head.pt")
if auto_convert:
medusa_sf = Path(medusa_head[:-len(".pt")] + ".safetensors")
if not medusa_sf.exists():
utils.convert_files([Path(medusa_head)], [medusa_sf], [])
medusa_config = hf_hub_download(model_id, revision=revision, filename="config.json")
with open(medusa_config, "r") as f:
config = json.load(f)
model_id = config["base_model_name_or_path"]
revision = "main"
try:
utils.weight_files(model_id, revision, extension)
logger.info(f"Files for parent {model_id} are already present on the host. " "Skipping download.")
return
# Local files not found
except (utils.LocalEntryNotFoundError, FileNotFoundError, utils.EntryNotFoundError):
pass
except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
pass
# Try to download weights from the hub
try:
filenames = utils.weight_hub_files(model_id, revision, extension)
......
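The hunk above teaches `download_weights` to recognize Medusa repositories: it fetches `medusa_lm_head.pt`, converts it to safetensors, then reads `config.json` to find the base model whose weights actually need downloading. A minimal standalone sketch of that flow, assuming a hypothetical repo id and using `safetensors.torch.save_file` as a simplified stand-in for `utils.convert_files`:

import json
from pathlib import Path

import torch
from huggingface_hub import hf_hub_download
from safetensors.torch import save_file

repo_id = "some-org/vicuna-7b-medusa"  # hypothetical Medusa head repository
medusa_head = hf_hub_download(repo_id, filename="medusa_lm_head.pt")
medusa_sf = Path(medusa_head).with_suffix(".safetensors")
if not medusa_sf.exists():
    # Simplified stand-in for utils.convert_files: load the .pt head and re-save as safetensors.
    state_dict = torch.load(medusa_head, map_location="cpu")
    save_file(state_dict, str(medusa_sf))

config_path = hf_hub_download(repo_id, filename="config.json")
with open(config_path, "r") as f:
    base_model_id = json.load(f)["base_model_name_or_path"]
print(f"Medusa head ready; actual weights come from parent model {base_model_id}")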
......@@ -6,6 +6,7 @@ from transformers.configuration_utils import PretrainedConfig
from transformers.models.auto import modeling_auto
from typing import Optional
from text_generation_server.utils.speculate import get_speculate, set_speculate
from text_generation_server.models.model import Model
from text_generation_server.models.causal_lm import CausalLM
from text_generation_server.models.flash_causal_lm import FlashCausalLM
......@@ -77,12 +78,12 @@ except ImportError as e:
if MISTRAL:
__all__.append(FlashMistral)
def get_model(
model_id: str,
revision: Optional[str],
sharded: bool,
quantize: Optional[str],
speculate: Optional[int],
dtype: Optional[str],
trust_remote_code: bool,
) -> Model:
......@@ -97,6 +98,11 @@ def get_model(
else:
raise RuntimeError(f"Unknown dtype {dtype}")
if speculate is not None:
set_speculate(speculate)
else:
set_speculate(0)
if "facebook/galactica" in model_id:
return GalacticaSharded(
model_id,
......@@ -131,6 +137,33 @@ def get_model(
config_dict, _ = PretrainedConfig.get_config_dict(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
use_medusa = None
if "medusa_num_heads" in config_dict:
use_medusa = model_id
medusa_config = config_dict
model_id = config_dict["base_model_name_or_path"]
revision = "main"
speculate_medusa = config_dict["medusa_num_heads"]
if speculate is not None:
if speculate > speculate_medusa:
raise RuntimeError("Speculate is set to `{speculate}` but this medusa models only has `{speculate_medusa}` heads, please make them match")
else:
set_speculate(speculate)
else:
set_speculate(speculate_medusa)
config_dict, _ = PretrainedConfig.get_config_dict(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
method = "medusa"
else:
method = "n-gram"
speculate = get_speculate()
if speculate > 0:
logger.info(f"Using speculation {method} with {speculate} input ids.")
model_type = config_dict["model_type"]
if model_type == "gpt_bigcode":
......@@ -206,6 +239,7 @@ def get_model(
quantize=quantize,
dtype=dtype,
trust_remote_code=trust_remote_code,
use_medusa=use_medusa
)
elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
......
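In `get_model`, the speculation length is resolved from two places: the explicit `speculate` argument and, for Medusa repositories, `medusa_num_heads` from `config.json`, with the explicit value winning as long as it does not exceed the number of heads. A small sketch of that resolution order on toy config dicts (the helper name is ours, not part of the diff):

from typing import Optional

def resolve_speculate(config_dict: dict, speculate: Optional[int]) -> int:
    # No Medusa head: use the explicit value, or 0 (speculation stays off unless requested).
    if "medusa_num_heads" not in config_dict:
        return speculate if speculate is not None else 0
    medusa_heads = config_dict["medusa_num_heads"]
    if speculate is None:
        return medusa_heads
    if speculate > medusa_heads:
        raise RuntimeError(
            f"Speculate is set to `{speculate}` but this Medusa model only has "
            f"`{medusa_heads}` heads; please make them match"
        )
    return speculate

assert resolve_speculate({"model_type": "llama"}, None) == 0   # n-gram only kicks in when speculate > 0
assert resolve_speculate({"medusa_num_heads": 4}, None) == 4   # default to the number of heads
assert resolve_speculate({"medusa_num_heads": 4}, 2) == 2      # explicit value wins when it fits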
......@@ -10,10 +10,9 @@ from typing import Optional, Tuple, List, Type, Dict
from text_generation_server.models import Model
from text_generation_server.models.types import (
Batch,
PrefillTokens,
Tokens,
Generation,
GeneratedText,
TopTokens,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
......@@ -676,8 +675,8 @@ class CausalLM(Model):
clean_up_tokenization_spaces=False,
skip_special_tokens=False,
)
prefill_tokens = PrefillTokens(
prefill_token_ids, prefill_logprobs, prefill_texts
prefill_tokens = Tokens(
prefill_token_ids, prefill_logprobs, prefill_texts, is_special=[]
)
else:
prefill_tokens = None
......@@ -691,7 +690,7 @@ class CausalLM(Model):
special_toptokens = [
token_id in self.all_special_ids for token_id in top_token_ids
]
top_tokens = TopTokens(
top_tokens = Tokens(
top_token_ids,
top_token_logprobs,
toptoken_texts,
......@@ -703,10 +702,12 @@ class CausalLM(Model):
generation = Generation(
request.id,
prefill_tokens,
next_token_id_squeezed,
next_token_logprob,
next_token_text,
next_token_id_squeezed.item() in self.all_special_ids,
Tokens(
[next_token_id_squeezed],
[next_token_logprob],
[next_token_text],
[next_token_id_squeezed.item() in self.all_special_ids],
),
generated_text,
top_tokens,
)
......
......@@ -11,13 +11,13 @@ from opentelemetry import trace
from transformers import PreTrainedTokenizerBase
from typing import Optional, Tuple, List, Type, Union, Dict
from text_generation_server.models import Model
from text_generation_server.utils.speculate import get_speculate
from text_generation_server.models.types import (
Batch,
PrefillTokens,
Tokens,
Generation,
GeneratedText,
TopTokens,
)
from text_generation_server.models.cache_manager import (
get_cache_manager,
......@@ -41,6 +41,7 @@ class FlashCausalLMBatch(Batch):
# Decoder values
input_ids: torch.Tensor
position_ids: torch.Tensor
speculative_ids: torch.Tensor
# Flash Attention values
......@@ -120,6 +121,7 @@ class FlashCausalLMBatch(Batch):
)["input_ids"]
position_ids = []
speculative_ids = []
cu_seqlen_prefill = [0]
needed_blocks_slots = []
start_slots = []
......@@ -163,6 +165,8 @@ class FlashCausalLMBatch(Batch):
input_length = len(tokenized_input)
input_lengths.append(input_length)
prefix_offsets.append(input_length - 5)
read_offsets.append(input_length)
......@@ -186,7 +190,8 @@ class FlashCausalLMBatch(Batch):
# Paged attention
# Remove one as the first token does not have a past
total_tokens = input_length + max_new_tokens - 1
speculative_length = get_speculate()
total_tokens = input_length + max_new_tokens - 1 + speculative_length
needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
blocks += needed_blocks
needed_blocks_slots.append((needed_blocks, total_tokens))
......@@ -224,7 +229,7 @@ class FlashCausalLMBatch(Batch):
cumulative_max_length += total_tokens
max_seqlen = max(max_seqlen, input_length)
max_blocks = max(max_blocks, needed_blocks)
max_length = max(max_length, input_length + max_new_tokens)
max_length = max(max_length, input_length + max_new_tokens + speculative_length)
next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
next_token_chooser_parameters, dtype, device
......@@ -255,7 +260,6 @@ class FlashCausalLMBatch(Batch):
cu_seqlen_prefill = torch.tensor(
cu_seqlen_prefill, device=device, dtype=torch.int32
)
position_ids = position_ids.to(device)
slot_indices = slot_indices.to(device)
input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
......@@ -309,6 +313,7 @@ class FlashCausalLMBatch(Batch):
top_n_tokens_tensor=top_n_tokens_tensor,
blocks=blocks,
max_blocks=max_blocks,
speculative_ids=None,
)
@tracer.start_as_current_span("filter")
......@@ -419,6 +424,7 @@ class FlashCausalLMBatch(Batch):
slots = self.slots[slot_filtering_indices]
next_token_chooser = self.next_token_chooser.filter(indices)
top_n_tokens_tensor = self.top_n_tokens_tensor[indices]
speculative_ids = self.speculative_ids[indices] if self.speculative_ids is not None else None
start_slots = torch.tensor(start_slots, dtype=torch.int64)
......@@ -454,6 +460,7 @@ class FlashCausalLMBatch(Batch):
top_n_tokens_tensor=top_n_tokens_tensor,
blocks=blocks,
max_blocks=max_blocks,
speculative_ids=speculative_ids,
)
@classmethod
......@@ -473,6 +480,7 @@ class FlashCausalLMBatch(Batch):
total_batch_size += len(b)
total_slots += len(b.slots)
blocks += b.blocks
speculative_length = b.speculative_ids.shape[1] if b.speculative_ids is not None else 0
max_blocks = max(max_blocks, b.max_blocks)
max_seqlen = max(max_seqlen, b.max_seqlen)
max_length = max(
......@@ -480,6 +488,7 @@ class FlashCausalLMBatch(Batch):
max(
input_length
+ stopping_criteria.max_new_tokens
+ speculative_length
- stopping_criteria.current_tokens
for input_length, stopping_criteria in zip(
b.input_lengths, b.stopping_criterias
......@@ -577,6 +586,8 @@ class FlashCausalLMBatch(Batch):
device=batches[0].next_token_chooser.device,
)
speculative_ids = torch.cat([b.speculative_ids for b in batches], dim=0) if batches[0].speculative_ids is not None else None
# Needed to avoid dropping blocks when the batches will go out of scope
for b in batches:
b.block_tables = None
......@@ -611,6 +622,7 @@ class FlashCausalLMBatch(Batch):
top_n_tokens_tensor=top_n_tokens_tensor,
blocks=blocks,
max_blocks=max_blocks,
speculative_ids=speculative_ids
)
def __del__(self):
......@@ -714,16 +726,55 @@ class FlashCausalLM(Model):
def forward(self, batch: FlashCausalLMBatch) -> Tuple[torch.Tensor, torch.Tensor]:
# Model Forward
if batch.speculative_ids is not None:
input_ids = batch.input_ids
position_ids = batch.position_ids
cu_seqlen_prefill = batch.cu_seqlen_prefill
kv_cache = get_cache_manager().kv_cache
block_tables = batch.block_tables_tensor
slots = batch.slots[batch.slot_indices]
input_lengths = batch.input_lengths_tensor
max_s = batch.max_seqlen
lm_head_indices = batch.prefill_head_indices
speculative_ids = batch.speculative_ids
B, speculative_length = speculative_ids.shape
new_length = speculative_length + 1
new_input_ids = torch.cat([input_ids.unsqueeze(-1), speculative_ids], dim=1).reshape(-1)
arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0)
arange_int = arange.to(dtype=torch.int32)
new_position_ids = (position_ids.unsqueeze(-1).expand(B, new_length) + arange).view(-1)
slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
input_lengths = (input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
# Copy the block tables for all members
block_tables = block_tables.unsqueeze(1).expand(B, new_length, -1).reshape(B * new_length, -1).contiguous()
max_s = max_s + speculative_length
input_ids = new_input_ids
position_ids = new_position_ids
else:
input_ids = batch.input_ids
position_ids = batch.position_ids
cu_seqlen_prefill = batch.cu_seqlen_prefill
kv_cache = get_cache_manager().kv_cache
block_tables = batch.block_tables_tensor
slots = batch.slots[batch.slot_indices]
input_lengths = batch.input_lengths_tensor
max_s = batch.max_seqlen
lm_head_indices = batch.prefill_head_indices
return self.model.forward(
input_ids=batch.input_ids,
position_ids=batch.position_ids,
cu_seqlen_prefill=batch.cu_seqlen_prefill,
kv_cache=get_cache_manager().kv_cache,
block_tables=batch.block_tables_tensor,
slots=batch.slots[batch.slot_indices],
input_lengths=batch.input_lengths_tensor,
max_s=batch.max_seqlen,
lm_head_indices=batch.prefill_head_indices,
input_ids=input_ids,
position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=kv_cache,
block_tables=block_tables,
slots=slots,
input_lengths=input_lengths,
max_s=max_s,
lm_head_indices=lm_head_indices,
)
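When `speculative_ids` is set, every sequence is expanded into `speculative_length + 1` rows so one forward pass scores the current token and all speculated tokens. A toy illustration of the tensor bookkeeping in the branch above, with made-up values and no model involved:

import torch

B, speculative_length = 2, 3          # two sequences, three speculated tokens each
new_length = speculative_length + 1   # current token plus the speculated ones

input_ids = torch.tensor([10, 20])                            # (B,)
speculative_ids = torch.tensor([[11, 12, 13], [21, 22, 23]])  # (B, speculative_length)
position_ids = torch.tensor([5, 7])                           # (B,)
slots = torch.tensor([100, 200], dtype=torch.int32)           # (B,)

arange = torch.arange(new_length).unsqueeze(0)
new_input_ids = torch.cat([input_ids.unsqueeze(-1), speculative_ids], dim=1).reshape(-1)
new_position_ids = (position_ids.unsqueeze(-1).expand(B, new_length) + arange).view(-1)
new_slots = (slots.unsqueeze(-1).expand(B, new_length) + arange.to(torch.int32)).view(-1)

print(new_input_ids.tolist())     # [10, 11, 12, 13, 20, 21, 22, 23] -> shape (B * new_length,)
print(new_position_ids.tolist())  # [5, 6, 7, 8, 7, 8, 9, 10]
print(new_slots.tolist())         # [100, 101, 102, 103, 200, 201, 202, 203]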
@tracer.start_as_current_span("generate_token")
......@@ -752,21 +803,32 @@ class FlashCausalLM(Model):
del batch
raise e
if isinstance(out, tuple):
out, speculative_logits = out
else:
speculative_logits = None
if prefill:
next_token_logits = (
out[batch.prefill_next_token_indices] if prefill_logprobs else out
)
if speculative_logits is not None:
speculative_logits = (
speculative_logits[batch.prefill_next_token_indices] if prefill_logprobs else speculative_logits
)
else:
next_token_logits = out
next_input_ids, next_token_logprobs, logprobs = batch.next_token_chooser(
batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits
next_input_ids, next_token_logprobs, logprobs, accepted_ids, speculative_ids = batch.next_token_chooser(
batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits, get_speculate(), batch.speculative_ids, speculative_logits
)
batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs
)
speculative_length = 0 if speculative_ids is None else speculative_ids.shape[1]
if prefill:
if len(batch) > 1 and prefill_logprobs:
# We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
......@@ -792,6 +854,7 @@ class FlashCausalLM(Model):
iterator = zip(
batch.input_lengths,
batch.all_input_ids,
accepted_ids
)
# We do two for loops as the first one can run completely asynchronously from the GPU while for the second
......@@ -799,9 +862,11 @@ class FlashCausalLM(Model):
# It is faster if we delay this sync for the maximum amount of time
# For each member of the batch
index = 0
for i, (
input_length,
all_input_ids,
n_accepted_ids
) in enumerate(iterator):
# Indexing metadata
start_index = cumulative_length
......@@ -830,15 +895,18 @@ class FlashCausalLM(Model):
start_index + 1 : start_index + out_length
]
batch.all_input_ids_tensor[i, input_length] = next_input_ids[i]
for j in range(n_accepted_ids):
batch.all_input_ids_tensor[i, input_length + j] = next_input_ids[index]
index += 1
cumulative_length += input_length
# Set values in batch
batch.input_ids = next_input_ids
batch.position_ids = next_position_ids + 1
batch.input_lengths_tensor += 1
batch.slot_indices += 1
batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1]
batch.speculative_ids = speculative_ids
batch.position_ids = next_position_ids + accepted_ids
batch.input_lengths_tensor += accepted_ids
batch.slot_indices += accepted_ids
if prefill and prefill_logprobs:
# Get prefill logprobs
......@@ -851,7 +919,7 @@ class FlashCausalLM(Model):
# GPU <-> CPU sync
next_token_logprobs = next_token_logprobs.tolist()
next_token_ids = batch.input_ids.tolist()
next_token_ids = next_input_ids.tolist()
# Zipped iterator
iterator = zip(
......@@ -864,13 +932,13 @@ class FlashCausalLM(Model):
batch.next_token_chooser.do_sample,
batch.next_token_chooser.seeds,
batch.top_n_tokens,
next_token_ids,
next_token_logprobs,
accepted_ids,
batch_top_token_ids,
batch_top_token_logprobs,
)
# For each member of the batch
index = 0
for i, (
request,
input_length,
......@@ -881,29 +949,43 @@ class FlashCausalLM(Model):
do_sample,
seed,
top_n_tokens,
next_token_id,
next_token_logprob,
n_accepted_ids,
top_token_ids,
top_token_logprobs,
) in enumerate(iterator):
# Append next token to all tokens
all_input_ids.append(next_token_id)
next_token_texts = []
left = 0
before = stopping_criteria.current_tokens
current_stopped = False
for j in range(index, index + n_accepted_ids):
# Generated token
next_token_id = next_token_ids[j]
all_input_ids.append(next_token_id)
next_token_text, prefix_offset, read_offset = self.decode_token(
all_input_ids,
prefix_offset,
read_offset,
)
next_token_texts.append(next_token_text)
# Generated token
next_token_text, prefix_offset, read_offset = self.decode_token(
all_input_ids,
prefix_offset,
read_offset,
)
stop, reason = stopping_criteria(
next_token_id,
next_token_text,
)
# Evaluate stopping criteria
stop, reason = stopping_criteria(
next_token_id,
next_token_text,
)
if stop:
left = index + n_accepted_ids - j - 1
current_stopped = True
break
else:
current_stopped = False
stopped = stopped and current_stopped
if not stop:
stopped = False
_next_token_ids = next_token_ids[index: index + n_accepted_ids - left]
_next_token_logprobs = next_token_logprobs[index: index + n_accepted_ids - left]
index += n_accepted_ids
# Shard generations
# All generations will be appended in the rust sharded client
......@@ -943,8 +1025,9 @@ class FlashCausalLM(Model):
clean_up_tokenization_spaces=False,
skip_special_tokens=False,
)
prefill_tokens = PrefillTokens(
prefill_token_ids, request_prefill_logprobs, prefill_texts
prefill_tokens = Tokens(
prefill_token_ids, request_prefill_logprobs, prefill_texts, is_special=[]
)
else:
prefill_tokens = None
......@@ -958,7 +1041,7 @@ class FlashCausalLM(Model):
special_toptokens = [
token_id in self.all_special_ids for token_id in top_token_ids
]
top_tokens = TopTokens(
top_tokens = Tokens(
top_token_ids,
top_token_logprobs,
toptoken_texts,
......@@ -970,10 +1053,12 @@ class FlashCausalLM(Model):
generation = Generation(
request.id,
prefill_tokens,
next_token_id,
next_token_logprob,
next_token_text,
next_token_id in self.all_special_ids,
Tokens(
_next_token_ids,
_next_token_logprobs,
next_token_texts,
[nid in self.all_special_ids for nid in _next_token_ids],
),
generated_text,
top_tokens,
)
......@@ -981,7 +1066,9 @@ class FlashCausalLM(Model):
generations.append(generation)
# Update values
batch.input_lengths[i] = input_length + 1
batch.input_lengths[i] = input_length + n_accepted_ids.item()
if batch.input_lengths[i] > batch.max_seqlen:
batch.max_seqlen = batch.input_lengths[i]
batch.prefix_offsets[i] = prefix_offset
batch.read_offsets[i] = read_offset
batch.all_input_ids[i] = all_input_ids
......@@ -994,6 +1081,5 @@ class FlashCausalLM(Model):
batch.prefill_cu_outlens = None
batch.prefill_head_indices = None
batch.prefill_next_token_indices = None
batch.max_seqlen = batch.max_seqlen + 1
return generations, batch
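Since each request may accept a different number of speculated tokens, `generate_token` walks the flat `next_token_ids` list with a running `index`, and `accepted_ids.cumsum(dim=-1) - 1` picks out the last accepted token of every request as its next input id. A toy illustration of that indexing:

import torch

# Three requests accepted 1, 3 and 2 tokens respectively in this step.
accepted_ids = torch.tensor([1, 3, 2])
next_input_ids = torch.tensor([10, 20, 21, 22, 30, 31])  # flat, one block per request

# Per-request slices, mirroring the `index` bookkeeping in the loop above.
index = 0
for n_accepted_ids in accepted_ids.tolist():
    print(next_input_ids[index: index + n_accepted_ids].tolist())
    index += n_accepted_ids
# [10]
# [20, 21, 22]
# [30, 31]

# The last accepted token of each request becomes its next input id.
last_accepted = next_input_ids[accepted_ids.cumsum(dim=-1) - 1]
print(last_accepted.tolist())  # [10, 22, 31]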
......@@ -28,6 +28,7 @@ class FlashLlama(FlashCausalLM):
quantize: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
use_medusa: Optional[str] = None,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
......@@ -66,6 +67,18 @@ class FlashLlama(FlashCausalLM):
weights._set_gptq_params(model_id)
model = FlashLlamaForCausalLM(config, weights)
if use_medusa:
from text_generation_server.utils.medusa import MedusaModel
from huggingface_hub import hf_hub_download
import json
medusa_config = hf_hub_download(use_medusa, revision=revision, filename="config.json")
with open(medusa_config, "r") as f:
config = json.load(f)
medusa_head = hf_hub_download(use_medusa, revision=revision, filename="medusa_lm_head.pt")
medusa_sf = medusa_head[:-len(".pt")] + ".safetensors"
weights = Weights([medusa_sf], device, dtype, process_group=self.process_group)
lm_head = model.lm_head
model.lm_head = MedusaModel(config, weights, lm_head)
torch.distributed.barrier(group=self.process_group)
super(FlashLlama, self).__init__(
......
......@@ -21,6 +21,7 @@ from text_generation_server.models.custom_modeling.flash_mistral_modeling import
FlashMistralForCausalLM,
MistralConfig,
)
from text_generation_server.utils.speculate import get_speculate
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
......@@ -132,7 +133,8 @@ class FlashMistralBatch(FlashCausalLMBatch):
# Paged attention
# Remove one as the first token does not have a past
total_tokens = input_length + max_new_tokens - 1
speculative_length = get_speculate()
total_tokens = input_length + max_new_tokens - 1 + speculative_length
# Needed blocks can not go over SLIDING_WINDOW_BLOCKS
needed_blocks = min(
......@@ -183,7 +185,7 @@ class FlashMistralBatch(FlashCausalLMBatch):
cumulative_max_length += total_tokens
max_seqlen = max(max_seqlen, input_length)
max_blocks = max(max_blocks, needed_blocks)
max_length = max(max_length, input_length + max_new_tokens)
max_length = max(max_length, input_length + max_new_tokens + speculative_length)
next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
next_token_chooser_parameters, dtype, device
......@@ -272,6 +274,7 @@ class FlashMistralBatch(FlashCausalLMBatch):
blocks=blocks,
max_blocks=max_blocks,
prefill_cache_indices=prefill_cache_indices,
speculative_ids=None
)
......@@ -340,17 +343,55 @@ class FlashMistral(FlashCausalLM):
def forward(self, batch: FlashMistralBatch) -> Tuple[torch.Tensor, torch.Tensor]:
# Model Forward
if batch.speculative_ids is not None:
input_ids = batch.input_ids
position_ids = batch.position_ids
cu_seqlen_prefill = batch.cu_seqlen_prefill
kv_cache = get_cache_manager().kv_cache
block_tables = batch.block_tables_tensor
slots = batch.slots[batch.slot_indices]
input_lengths = batch.input_lengths_tensor
max_s = batch.max_seqlen
lm_head_indices = batch.prefill_head_indices
speculative_ids = batch.speculative_ids
B, speculative_length = speculative_ids.shape
new_length = speculative_length + 1
new_input_ids = torch.cat([input_ids.unsqueeze(-1), speculative_ids], dim=1).reshape(-1)
arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0)
arange_int = arange.to(dtype=torch.int32)
new_position_ids = (position_ids.unsqueeze(-1).expand(B, new_length) + arange).view(-1)
slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
input_lengths = (input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
# Copy the block tables for all members
block_tables = block_tables.unsqueeze(1).expand(B, new_length, -1).reshape(B * new_length, -1).contiguous()
max_s = max_s + speculative_length
input_ids = new_input_ids
position_ids = new_position_ids
else:
input_ids = batch.input_ids
position_ids = batch.position_ids
cu_seqlen_prefill = batch.cu_seqlen_prefill
kv_cache = get_cache_manager().kv_cache
block_tables = batch.block_tables_tensor
slots = batch.slots[batch.slot_indices]
input_lengths = batch.input_lengths_tensor
max_s = batch.max_seqlen
lm_head_indices = batch.prefill_head_indices
logits = self.model.forward(
input_ids=batch.input_ids,
position_ids=batch.position_ids,
cu_seqlen_prefill=batch.cu_seqlen_prefill,
kv_cache=get_cache_manager().kv_cache,
block_tables=batch.block_tables_tensor,
slots=batch.slots[batch.slot_indices],
input_lengths=batch.input_lengths_tensor,
max_s=batch.max_seqlen,
input_ids=input_ids,
position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=kv_cache,
block_tables=block_tables,
slots=slots,
input_lengths=input_lengths,
max_s=max_s,
prefill_cache_indices=batch.prefill_cache_indices,
lm_head_indices=batch.prefill_head_indices,
lm_head_indices=lm_head_indices,
)
if batch.prefill_cache_indices is not None:
batch.prefill_cache_indices = None
......
......@@ -20,7 +20,7 @@ from typing import Optional, Tuple, List, Type, Dict
from text_generation_server.models import Model
from text_generation_server.models.types import (
Batch,
PrefillTokens,
Tokens,
Generation,
GeneratedText,
)
......@@ -791,8 +791,8 @@ class IdeficsCausalLM(Model):
clean_up_tokenization_spaces=False,
skip_special_tokens=False,
)
prefill_tokens = PrefillTokens(
prefill_token_ids, prefill_logprobs, prefill_texts
prefill_tokens = Tokens(
prefill_token_ids, prefill_logprobs, prefill_texts, is_special=[]
)
else:
prefill_tokens = None
......@@ -802,10 +802,12 @@ class IdeficsCausalLM(Model):
generation = Generation(
request.id,
prefill_tokens,
next_token_id_squeezed,
next_token_logprob,
next_token_text,
next_token_id_squeezed.item() in self.all_special_ids,
Tokens(
[next_token_id_squeezed],
[next_token_logprob],
[next_token_text],
[next_token_id_squeezed.item() in self.all_special_ids],
),
generated_text,
top_tokens,
)
......
......@@ -6,6 +6,7 @@ from typing import List, Tuple, Optional, TypeVar, Type
from transformers import PreTrainedTokenizerBase, PretrainedConfig
from text_generation_server.models.types import Batch, Generation
from text_generation_server.utils.speculate import get_speculate
from text_generation_server.pb.generate_pb2 import InfoResponse
B = TypeVar("B", bound=Batch)
......@@ -22,6 +23,7 @@ class Model(ABC):
rank: int = 0,
world_size: int = 1,
sliding_window: Optional[int] = None,
speculate: Optional[int] = None,
):
self.model = model.eval()
self.tokenizer = tokenizer
......@@ -33,6 +35,10 @@ class Model(ABC):
self.world_size = world_size
self.sliding_window = sliding_window
if speculate is None:
speculate = get_speculate()
self.speculate = speculate
self.has_position_ids = (
inspect.signature(model.forward).parameters.get("position_ids", None)
is not None
......@@ -50,6 +56,7 @@ class Model(ABC):
dtype=str(self.dtype),
device_type=self.device.type,
window_size=self.sliding_window,
speculate=self.speculate
)
@property
......
......@@ -11,8 +11,7 @@ from text_generation_server.models.types import (
GeneratedText,
Batch,
Generation,
PrefillTokens,
TopTokens,
Tokens,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
......@@ -733,10 +732,11 @@ class Seq2SeqLM(Model):
# Prefill
if stopping_criteria.current_tokens == 1 and request.prefill_logprobs:
prefill_tokens = PrefillTokens(
prefill_tokens = Tokens(
[self.tokenizer.bos_token_id],
[float("nan")],
[self.tokenizer.bos_token],
[False]
)
else:
prefill_tokens = None
......@@ -750,7 +750,7 @@ class Seq2SeqLM(Model):
special_toptokens = [
token_id in self.all_special_ids for token_id in top_token_ids
]
top_tokens = TopTokens(
top_tokens = Tokens(
top_token_ids,
top_token_logprobs,
toptoken_texts,
......@@ -762,10 +762,12 @@ class Seq2SeqLM(Model):
generation = Generation(
request.id,
prefill_tokens,
next_token_id_squeezed,
next_token_logprob,
next_token_text,
next_token_id_squeezed.item() in self.all_special_ids,
Tokens(
[next_token_id_squeezed],
[next_token_logprob],
[next_token_text],
[next_token_id_squeezed.item() in self.all_special_ids],
),
generated_text,
top_tokens,
)
......
......@@ -58,33 +58,15 @@ class GeneratedText:
@dataclass
class PrefillTokens:
token_ids: List[int]
logprobs: List[float]
texts: List[str]
def to_pb(self) -> generate_pb2.PrefillTokens:
return generate_pb2.PrefillTokens(
ids=self.token_ids, logprobs=self.logprobs, texts=self.texts
)
def __len__(self):
return len(self.token_ids)
@dataclass
class TopTokens:
class Tokens:
token_ids: List[int]
logprobs: List[float]
texts: List[str]
is_special: List[bool]
def to_pb(self) -> generate_pb2.TopTokens:
return generate_pb2.TopTokens(
ids=self.token_ids,
logprobs=self.logprobs,
texts=self.texts,
is_special=self.is_special,
def to_pb(self) -> generate_pb2.Tokens:
return generate_pb2.Tokens(
ids=self.token_ids, logprobs=self.logprobs, texts=self.texts, is_special=self.is_special
)
def __len__(self):
......@@ -94,14 +76,11 @@ class TopTokens:
@dataclass
class Generation:
request_id: int
prefill_tokens: Optional[PrefillTokens]
token_id: int
token_logprob: float
token_text: str
token_is_special: bool
prefill_tokens: Optional[Tokens]
tokens: Tokens
generated_text: Optional[GeneratedText]
# Optional for now, since it's not yet supported for every model.
top_tokens: Optional[TopTokens]
top_tokens: Optional[List[Tokens]]
def to_pb(self) -> generate_pb2.Generation:
return generate_pb2.Generation(
......@@ -109,10 +88,7 @@ class Generation:
prefill_tokens=self.prefill_tokens.to_pb()
if self.prefill_tokens is not None
else None,
token_id=self.token_id,
token_logprob=self.token_logprob,
token_text=self.token_text,
token_is_special=self.token_is_special,
tokens=self.tokens.to_pb(),
generated_text=self.generated_text.to_pb()
if self.generated_text is not None
else None,
......
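With this change `PrefillTokens` and `TopTokens` collapse into one `Tokens` record, and `Generation` carries a `Tokens` object instead of scalar per-token fields, which is what lets a single decoding step return several tokens. A local sketch of the new shape of the data, using a standalone copy of the dataclass rather than the protobuf-backed one:

from dataclasses import dataclass
from typing import List

@dataclass
class Tokens:
    token_ids: List[int]
    logprobs: List[float]
    texts: List[str]
    is_special: List[bool]

    def __len__(self):
        return len(self.token_ids)

# One generation step that produced two tokens (e.g. one regular plus one accepted speculative token).
step = Tokens(
    token_ids=[318, 257],
    logprobs=[-0.12, -0.80],
    texts=[" is", " a"],
    is_special=[False, False],
)
assert len(step) == 2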
......@@ -132,6 +132,7 @@ def serve(
revision: Optional[str],
sharded: bool,
quantize: Optional[str],
speculate: Optional[int],
dtype: Optional[str],
trust_remote_code: bool,
uds_path: Path,
......@@ -141,6 +142,7 @@ def serve(
revision: Optional[str],
sharded: bool = False,
quantize: Optional[str] = None,
speculate: Optional[int] = None,
dtype: Optional[str] = None,
trust_remote_code: bool = False,
):
......@@ -157,7 +159,7 @@ def serve(
try:
model = get_model(
model_id, revision, sharded, quantize, dtype, trust_remote_code
model_id, revision, sharded, quantize, speculate, dtype, trust_remote_code
)
except Exception:
logger.exception("Error when initializing model")
......@@ -205,5 +207,5 @@ def serve(
await server.stop(0)
asyncio.run(
serve_inner(model_id, revision, sharded, quantize, dtype, trust_remote_code)
serve_inner(model_id, revision, sharded, quantize, speculate, dtype, trust_remote_code)
)
import torch
from dataclasses import dataclass
from text_generation_server.utils.layers import TensorParallelHead, FastLinear
@dataclass
class Output:
logits: torch.FloatTensor = None
speculative_logits: torch.FloatTensor = None
class ResBlock(torch.nn.Module):
def __init__(self, config, prefix, weights):
super().__init__()
self.linear = FastLinear.load(config, prefix=f"{prefix}.linear", weights=weights, bias=True)
self.act = torch.nn.SiLU()
def forward(self, x):
return x + self.act(self.linear(x))
class MedusaModel(torch.nn.Module):
def __init__(
self,
config,
weights,
lm_head
):
super().__init__()
self.heads = torch.nn.ModuleList(
[MedusaHead(config, prefix=f"{i}", weights=weights) for i in range(config["medusa_num_heads"])]
)
self.lm_head = lm_head
def forward(self, x):
logits = self.lm_head(x)
speculative_logits = torch.stack([head(x) for head in self.heads], dim=1)
return logits, speculative_logits
class MedusaHead(torch.nn.Module):
def __init__(self, config, prefix, weights):
super().__init__()
self.blocks = torch.nn.ModuleList([ResBlock(config, prefix=f"{prefix}.{i}", weights=weights) for i in range(config["medusa_num_layers"])])
n = len(self.blocks)
self.out = FastLinear.load(config, prefix=f"{prefix}.{n}", weights=weights, bias=False)
def forward(self, x):
for block in self.blocks:
x = block(x)
x = self.out(x)
return x
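MedusaModel keeps the original `lm_head` and adds `medusa_num_heads` residual MLP heads, so one hidden state produces the regular logits plus one extra logit set per head. A toy stand-in with plain `nn.Linear` layers (not the real weight loading) just to show the output shapes:

import torch

hidden_size, vocab_size, num_heads = 16, 32, 4
batch = 3

lm_head = torch.nn.Linear(hidden_size, vocab_size, bias=False)
# Stand-ins for MedusaHead: here a single projection replaces the ResBlock stack.
heads = torch.nn.ModuleList(
    [torch.nn.Linear(hidden_size, vocab_size, bias=False) for _ in range(num_heads)]
)

x = torch.randn(batch, hidden_size)
logits = lm_head(x)                                             # (batch, vocab_size)
speculative_logits = torch.stack([h(x) for h in heads], dim=1)  # (batch, num_heads, vocab_size)
print(logits.shape, speculative_logits.shape)
# torch.Size([3, 32]) torch.Size([3, 4, 32])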
SPECULATE = None
def get_speculate() -> int:
global SPECULATE
return SPECULATE
def set_speculate(speculate: int):
global SPECULATE
SPECULATE = speculate
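The speculation length is held as process-wide module state so batching code can read it without threading a value through every call; usage is just the getter/setter pair (a tiny sketch, assuming the package is importable):

from text_generation_server.utils.speculate import get_speculate, set_speculate

set_speculate(3)           # e.g. from the CLI flag or a Medusa config
assert get_speculate() == 3
set_speculate(0)           # turns speculation off again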
......@@ -16,7 +16,6 @@ from text_generation_server.utils.logits_process import (
from text_generation_server.utils.watermark import WatermarkLogitsProcessor
from transformers import PreTrainedTokenizerBase, RepetitionPenaltyLogitsProcessor
class NextTokenChooser:
def __init__(
self,
......@@ -146,6 +145,20 @@ class StoppingCriteria:
pb.ignore_eos_token,
)
def create_n_gram_speculation(input_ids: torch.Tensor, next_ids: torch.Tensor, accepted_ids: torch.Tensor, speculate: int, verbose: bool):
# Very trivial approach, find first match in the string.
# This is much less refined than actual n-gram but seems to work
# relatively OK in grounded mode and is by far much faster with
# much less worst case complexity as everything happens on device.
B = accepted_ids.shape[0]
device = input_ids.device
seeds = next_ids[accepted_ids.cumsum(dim=-1) - 1]
indices = (input_ids == seeds.unsqueeze(-1)).max(dim=1).indices + 1
all_indices = indices.unsqueeze(-1).expand(B, speculate) + torch.arange(speculate, device=device)
all_indices = torch.clamp(all_indices, max=input_ids.shape[1] - 1)
speculative_ids = input_ids.gather(dim=-1, index=all_indices)
return speculative_ids
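A worked toy run of the n-gram heuristic above: the last accepted token of each sequence is used as a seed, the first place it appears in the tokens so far is located, and the `speculate` tokens that followed it there become the guesses. The snippet inlines the same tensor ops on made-up inputs (`input_ids` stands in for a slice of `all_input_ids_tensor`):

import torch

speculate = 3
# One sequence whose tokens so far are [5, 9, 7, 5, 2, 0, 0] (0 = padding).
input_ids = torch.tensor([[5, 9, 7, 5, 2, 0, 0]])
next_ids = torch.tensor([5])          # token just produced for this sequence
accepted_ids = torch.tensor([1])

B = accepted_ids.shape[0]
seeds = next_ids[accepted_ids.cumsum(dim=-1) - 1]                    # last accepted token per sequence -> [5]
indices = (input_ids == seeds.unsqueeze(-1)).max(dim=1).indices + 1  # position right after the first match -> [1]
all_indices = indices.unsqueeze(-1).expand(B, speculate) + torch.arange(speculate)
all_indices = torch.clamp(all_indices, max=input_ids.shape[1] - 1)
speculative_ids = input_ids.gather(dim=-1, index=all_indices)
print(speculative_ids.tolist())  # [[9, 7, 5]] -- the tokens that followed the seed earlier in the sequence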
class HeterogeneousNextTokenChooser:
def __init__(
......@@ -215,20 +228,79 @@ class HeterogeneousNextTokenChooser:
self.dtype = dtype
self.device = device
def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor):
if self.watermark_processor is not None:
scores = self.watermark_processor(input_ids, scores)
if self.repetition_processor is not None:
scores = self.repetition_processor(input_ids, scores)
for warper in self.warpers:
scores = warper(input_ids, scores)
def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor, speculate: int, speculated_ids: Optional[torch.Tensor] = None, speculative_scores: Optional[torch.Tensor] = None, verbose=False):
if speculated_ids is not None:
B = scores.shape[0] // (speculated_ids.shape[1] + 1)
S = speculated_ids.shape[1] + 1
scores = scores.view(B, S, -1)
else:
B = scores.shape[0]
S = 1
scores = scores.view(B, S, -1)
next_ids = torch.zeros((B, S), device=scores.device, dtype=torch.long)
for j in range(S):
_scores = scores[:, j]
if self.watermark_processor is not None:
_scores = self.watermark_processor(input_ids, _scores)
if self.repetition_processor is not None:
_scores = self.repetition_processor(input_ids, _scores)
for warper in self.warpers:
_scores = warper(input_ids, _scores)
_next_ids = self.choice(_scores)
scores[:, j] = _scores
next_ids[:, j] = _next_ids
next_ids = next_ids.view(B * S)
scores = scores.view(B * S, -1)
if speculated_ids is not None:
accepted_ids = []
B = next_ids.shape[0] // (speculated_ids.shape[1] + 1)
S = speculated_ids.shape[1] + 1
indices = []
for i in range(B):
_next_ids = next_ids[i*S: (i + 1)*S]
_speculated_ids = speculated_ids[i]
validate_speculative = _next_ids[:-1] == _speculated_ids
index = i * S
accepted = 1
# First is always valid
indices.append(index)
for valid in validate_speculative.tolist():
if valid:
index += 1
accepted += 1
indices.append(index)
else:
break
accepted_ids.append(accepted)
accepted_ids = torch.tensor(accepted_ids, device=input_ids.device, dtype=input_ids.dtype)
next_ids = next_ids[indices]
scores = scores[indices]
indices = torch.arange(B, device=input_ids.device) * S
if speculative_scores is not None:
speculative_scores = speculative_scores[indices + accepted_ids - 1]
else:
accepted_ids = torch.ones_like(next_ids)
next_ids = self.choice(scores)
logprobs = torch.log_softmax(scores, -1)
next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)
return next_ids, next_logprobs, logprobs
if speculate > 0:
if speculative_scores is not None:
# Medusa provided some scores
speculative_ids = Greedy()(speculative_scores)
else:
# n-gram
speculative_ids = create_n_gram_speculation(input_ids, next_ids, accepted_ids, speculate, verbose)
else:
speculative_ids = None
return next_ids, next_logprobs, logprobs, accepted_ids, speculative_ids
def filter(self, indices):
if self.watermark_processor is not None:
......
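The acceptance rule in `__call__` above is a plain prefix match: a speculated token is kept only if the model, fed the speculation, reproduces it, and the model's own output at the first mismatch is kept as the correction. A small sketch of that counting for a single sequence:

import torch

speculated_ids = torch.tensor([11, 12, 13])      # guesses fed to the model last step
next_ids = torch.tensor([11, 12, 99, 14])        # what the model actually produced (S = 3 + 1)

validate_speculative = next_ids[:-1] == speculated_ids   # tensor([True, True, False])

accepted = 1                    # the first produced token is always kept
for valid in validate_speculative.tolist():
    if valid:
        accepted += 1
    else:
        break
print(accepted)                       # 3
print(next_ids[:accepted].tolist())   # [11, 12, 99] -- 99 is the correction replacing the bad guess 13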