"...resnet50_tensorflow.git" did not exist on "e2618acac3496f4881bbf2820921f569eaaf6e23"
Unverified commit 9ecfa16b, authored by Nicolas Patry and committed by GitHub

Speculative (#1308)

parent 3238c491
@@ -32,6 +32,7 @@ def serve(
     revision: Optional[str] = None,
     sharded: bool = False,
     quantize: Optional[Quantization] = None,
+    speculate: Optional[int] = None,
     dtype: Optional[Dtype] = None,
     trust_remote_code: bool = False,
     uds_path: Path = "/tmp/text-generation-server",
@@ -81,7 +82,7 @@ def serve(
            "Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model."
         )
     server.serve(
-        model_id, revision, sharded, quantize, dtype, trust_remote_code, uds_path
+        model_id, revision, sharded, quantize, speculate, dtype, trust_remote_code, uds_path
     )
@@ -116,7 +117,7 @@ def download_weights(
         logger.info("Files are already present on the host. " "Skipping download.")
         return
     # Local files not found
-    except (utils.LocalEntryNotFoundError, FileNotFoundError):
+    except (utils.LocalEntryNotFoundError, FileNotFoundError, utils.EntryNotFoundError):
         pass

     is_local_model = (Path(model_id).exists() and Path(model_id).is_dir()) or os.getenv(
@@ -137,6 +138,29 @@ def download_weights(
     except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
         pass

+    try:
+        import json
+        medusa_head = hf_hub_download(model_id, revision=revision, filename="medusa_lm_head.pt")
+        if auto_convert:
+            medusa_sf = Path(medusa_head[: -len(".pt")] + ".safetensors")
+            if not medusa_sf.exists():
+                utils.convert_files([Path(medusa_head)], [medusa_sf], [])
+        medusa_config = hf_hub_download(model_id, revision=revision, filename="config.json")
+        with open(medusa_config, "r") as f:
+            config = json.load(f)
+
+        model_id = config["base_model_name_or_path"]
+        revision = "main"
+        try:
+            utils.weight_files(model_id, revision, extension)
+            logger.info(f"Files for parent {model_id} are already present on the host. " "Skipping download.")
+            return
+        # Local files not found
+        except (utils.LocalEntryNotFoundError, FileNotFoundError, utils.EntryNotFoundError):
+            pass
+    except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
+        pass
+
     # Try to download weights from the hub
     try:
         filenames = utils.weight_hub_files(model_id, revision, extension)
...
@@ -6,6 +6,7 @@ from transformers.configuration_utils import PretrainedConfig
 from transformers.models.auto import modeling_auto
 from typing import Optional

+from text_generation_server.utils.speculate import get_speculate, set_speculate
 from text_generation_server.models.model import Model
 from text_generation_server.models.causal_lm import CausalLM
 from text_generation_server.models.flash_causal_lm import FlashCausalLM
@@ -77,12 +78,12 @@ except ImportError as e:
 if MISTRAL:
     __all__.append(FlashMistral)

 def get_model(
     model_id: str,
     revision: Optional[str],
     sharded: bool,
     quantize: Optional[str],
+    speculate: Optional[int],
     dtype: Optional[str],
     trust_remote_code: bool,
 ) -> Model:
@@ -97,6 +98,11 @@ def get_model(
     else:
         raise RuntimeError(f"Unknown dtype {dtype}")

+    if speculate is not None:
+        set_speculate(speculate)
+    else:
+        set_speculate(0)
+
     if "facebook/galactica" in model_id:
         return GalacticaSharded(
             model_id,
@@ -131,6 +137,33 @@ def get_model(
     config_dict, _ = PretrainedConfig.get_config_dict(
         model_id, revision=revision, trust_remote_code=trust_remote_code
     )
+    use_medusa = None
+    if "medusa_num_heads" in config_dict:
+        use_medusa = model_id
+        medusa_config = config_dict
+        model_id = config_dict["base_model_name_or_path"]
+        revision = "main"
+        speculate_medusa = config_dict["medusa_num_heads"]
+        if speculate is not None:
+            if speculate > speculate_medusa:
+                raise RuntimeError(f"Speculate is set to `{speculate}` but this Medusa model only has `{speculate_medusa}` heads, please make them match")
+            else:
+                set_speculate(speculate)
+        else:
+            set_speculate(speculate_medusa)
+
+        config_dict, _ = PretrainedConfig.get_config_dict(
+            model_id, revision=revision, trust_remote_code=trust_remote_code
+        )
+        method = "medusa"
+    else:
+        method = "n-gram"
+
+    speculate = get_speculate()
+    if speculate > 0:
+        logger.info(f"Using speculation {method} with {speculate} input ids.")
+
     model_type = config_dict["model_type"]

     if model_type == "gpt_bigcode":
@@ -206,6 +239,7 @@ def get_model(
                 quantize=quantize,
                 dtype=dtype,
                 trust_remote_code=trust_remote_code,
+                use_medusa=use_medusa
             )
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
...
@@ -10,10 +10,9 @@ from typing import Optional, Tuple, List, Type, Dict
 from text_generation_server.models import Model
 from text_generation_server.models.types import (
     Batch,
-    PrefillTokens,
+    Tokens,
     Generation,
     GeneratedText,
-    TopTokens,
 )
 from text_generation_server.pb import generate_pb2
 from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
@@ -676,8 +675,8 @@ class CausalLM(Model):
                     clean_up_tokenization_spaces=False,
                     skip_special_tokens=False,
                 )
-                prefill_tokens = PrefillTokens(
-                    prefill_token_ids, prefill_logprobs, prefill_texts
+                prefill_tokens = Tokens(
+                    prefill_token_ids, prefill_logprobs, prefill_texts, is_special=[]
                 )
             else:
                 prefill_tokens = None
@@ -691,7 +690,7 @@ class CausalLM(Model):
                 special_toptokens = [
                     token_id in self.all_special_ids for token_id in top_token_ids
                 ]
-                top_tokens = TopTokens(
+                top_tokens = Tokens(
                     top_token_ids,
                     top_token_logprobs,
                     toptoken_texts,
@@ -703,10 +702,12 @@ class CausalLM(Model):
             generation = Generation(
                 request.id,
                 prefill_tokens,
-                next_token_id_squeezed,
-                next_token_logprob,
-                next_token_text,
-                next_token_id_squeezed.item() in self.all_special_ids,
+                Tokens(
+                    [next_token_id_squeezed],
+                    [next_token_logprob],
+                    [next_token_text],
+                    [next_token_id_squeezed.item() in self.all_special_ids],
+                ),
                 generated_text,
                 top_tokens,
             )
...
@@ -11,13 +11,13 @@ from opentelemetry import trace
 from transformers import PreTrainedTokenizerBase
 from typing import Optional, Tuple, List, Type, Union, Dict

 from text_generation_server.models import Model
+from text_generation_server.utils.speculate import get_speculate
 from text_generation_server.models.types import (
     Batch,
-    PrefillTokens,
+    Tokens,
     Generation,
     GeneratedText,
-    TopTokens,
 )
 from text_generation_server.models.cache_manager import (
     get_cache_manager,
@@ -41,6 +41,7 @@ class FlashCausalLMBatch(Batch):
     # Decoder values
     input_ids: torch.Tensor
     position_ids: torch.Tensor
+    speculative_ids: torch.Tensor

     # Flash Attention values
@@ -120,6 +121,7 @@ class FlashCausalLMBatch(Batch):
         )["input_ids"]

         position_ids = []
+        speculative_ids = []
         cu_seqlen_prefill = [0]
         needed_blocks_slots = []
         start_slots = []
@@ -163,6 +165,8 @@ class FlashCausalLMBatch(Batch):
             input_length = len(tokenized_input)
             input_lengths.append(input_length)

             prefix_offsets.append(input_length - 5)
             read_offsets.append(input_length)
@@ -186,7 +190,8 @@ class FlashCausalLMBatch(Batch):
             # Paged attention
             # Remove one as the first token does not have a past
-            total_tokens = input_length + max_new_tokens - 1
+            speculative_length = get_speculate()
+            total_tokens = input_length + max_new_tokens - 1 + speculative_length
             needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
             blocks += needed_blocks
             needed_blocks_slots.append((needed_blocks, total_tokens))
@@ -224,7 +229,7 @@ class FlashCausalLMBatch(Batch):
             cumulative_max_length += total_tokens
             max_seqlen = max(max_seqlen, input_length)
             max_blocks = max(max_blocks, needed_blocks)
-            max_length = max(max_length, input_length + max_new_tokens)
+            max_length = max(max_length, input_length + max_new_tokens + speculative_length)

         next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
             next_token_chooser_parameters, dtype, device
@@ -255,7 +260,6 @@ class FlashCausalLMBatch(Batch):
         cu_seqlen_prefill = torch.tensor(
             cu_seqlen_prefill, device=device, dtype=torch.int32
         )
         position_ids = position_ids.to(device)
         slot_indices = slot_indices.to(device)
         input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
@@ -309,6 +313,7 @@ class FlashCausalLMBatch(Batch):
             top_n_tokens_tensor=top_n_tokens_tensor,
             blocks=blocks,
             max_blocks=max_blocks,
+            speculative_ids=None,
         )

     @tracer.start_as_current_span("filter")
@@ -419,6 +424,7 @@ class FlashCausalLMBatch(Batch):
         slots = self.slots[slot_filtering_indices]
         next_token_chooser = self.next_token_chooser.filter(indices)
         top_n_tokens_tensor = self.top_n_tokens_tensor[indices]
+        speculative_ids = self.speculative_ids[indices] if self.speculative_ids is not None else None

         start_slots = torch.tensor(start_slots, dtype=torch.int64)
@@ -454,6 +460,7 @@ class FlashCausalLMBatch(Batch):
             top_n_tokens_tensor=top_n_tokens_tensor,
             blocks=blocks,
             max_blocks=max_blocks,
+            speculative_ids=speculative_ids,
         )

     @classmethod
@@ -473,6 +480,7 @@ class FlashCausalLMBatch(Batch):
             total_batch_size += len(b)
             total_slots += len(b.slots)
             blocks += b.blocks
+            speculative_length = b.speculative_ids.shape[1] if b.speculative_ids is not None else 0
             max_blocks = max(max_blocks, b.max_blocks)
             max_seqlen = max(max_seqlen, b.max_seqlen)
             max_length = max(
@@ -480,6 +488,7 @@ class FlashCausalLMBatch(Batch):
                 max(
                     input_length
                     + stopping_criteria.max_new_tokens
+                    + speculative_length
                     - stopping_criteria.current_tokens
                     for input_length, stopping_criteria in zip(
                         b.input_lengths, b.stopping_criterias
@@ -577,6 +586,8 @@ class FlashCausalLMBatch(Batch):
             device=batches[0].next_token_chooser.device,
         )

+        speculative_ids = torch.cat([b.speculative_ids for b in batches], dim=0) if batches[0].speculative_ids is not None else None
+
         # Needed to avoid dropping blocks when the batches will go out of scope
         for b in batches:
             b.block_tables = None
@@ -611,6 +622,7 @@ class FlashCausalLMBatch(Batch):
             top_n_tokens_tensor=top_n_tokens_tensor,
             blocks=blocks,
             max_blocks=max_blocks,
+            speculative_ids=speculative_ids
         )

     def __del__(self):
@@ -714,16 +726,55 @@ class FlashCausalLM(Model):

     def forward(self, batch: FlashCausalLMBatch) -> Tuple[torch.Tensor, torch.Tensor]:
         # Model Forward
+        if batch.speculative_ids is not None:
+            input_ids = batch.input_ids
+            position_ids = batch.position_ids
+            cu_seqlen_prefill = batch.cu_seqlen_prefill
+            kv_cache = get_cache_manager().kv_cache
+            block_tables = batch.block_tables_tensor
+            slots = batch.slots[batch.slot_indices]
+            input_lengths = batch.input_lengths_tensor
+            max_s = batch.max_seqlen
+            lm_head_indices = batch.prefill_head_indices
+
+            speculative_ids = batch.speculative_ids
+
+            B, speculative_length = speculative_ids.shape
+            new_length = speculative_length + 1
+            new_input_ids = torch.cat([input_ids.unsqueeze(-1), speculative_ids], dim=1).reshape(-1)
+            arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0)
+            arange_int = arange.to(dtype=torch.int32)
+            new_position_ids = (position_ids.unsqueeze(-1).expand(B, new_length) + arange).view(-1)
+            slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
+            input_lengths = (input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
+
+            # Copy the block tables for all members
+            block_tables = block_tables.unsqueeze(1).expand(B, new_length, -1).reshape(B * new_length, -1).contiguous()
+            max_s = max_s + speculative_length
+
+            input_ids = new_input_ids
+            position_ids = new_position_ids
+        else:
+            input_ids = batch.input_ids
+            position_ids = batch.position_ids
+            cu_seqlen_prefill = batch.cu_seqlen_prefill
+            kv_cache = get_cache_manager().kv_cache
+            block_tables = batch.block_tables_tensor
+            slots = batch.slots[batch.slot_indices]
+            input_lengths = batch.input_lengths_tensor
+            max_s = batch.max_seqlen
+            lm_head_indices = batch.prefill_head_indices
+
         return self.model.forward(
-            input_ids=batch.input_ids,
-            position_ids=batch.position_ids,
-            cu_seqlen_prefill=batch.cu_seqlen_prefill,
-            kv_cache=get_cache_manager().kv_cache,
-            block_tables=batch.block_tables_tensor,
-            slots=batch.slots[batch.slot_indices],
-            input_lengths=batch.input_lengths_tensor,
-            max_s=batch.max_seqlen,
-            lm_head_indices=batch.prefill_head_indices,
+            input_ids=input_ids,
+            position_ids=position_ids,
+            cu_seqlen_prefill=cu_seqlen_prefill,
+            kv_cache=kv_cache,
+            block_tables=block_tables,
+            slots=slots,
+            input_lengths=input_lengths,
+            max_s=max_s,
+            lm_head_indices=lm_head_indices,
         )

     @tracer.start_as_current_span("generate_token")
@@ -752,21 +803,32 @@ class FlashCausalLM(Model):
             del batch
             raise e

+        if isinstance(out, tuple):
+            out, speculative_logits = out
+        else:
+            speculative_logits = None
+
         if prefill:
             next_token_logits = (
                 out[batch.prefill_next_token_indices] if prefill_logprobs else out
             )
+            if speculative_logits is not None:
+                speculative_logits = (
+                    speculative_logits[batch.prefill_next_token_indices] if prefill_logprobs else speculative_logits
+                )
         else:
             next_token_logits = out

-        next_input_ids, next_token_logprobs, logprobs = batch.next_token_chooser(
-            batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits
+        next_input_ids, next_token_logprobs, logprobs, accepted_ids, speculative_ids = batch.next_token_chooser(
+            batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits, get_speculate(), batch.speculative_ids, speculative_logits
         )

         batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
             batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs
         )

+        speculative_length = 0 if speculative_ids is None else speculative_ids.shape[1]
+
         if prefill:
             if len(batch) > 1 and prefill_logprobs:
                 # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
@@ -792,6 +854,7 @@ class FlashCausalLM(Model):
         iterator = zip(
             batch.input_lengths,
             batch.all_input_ids,
+            accepted_ids
         )

         # We do two for loops as the first one can run completely asynchronously from the GPU while for the second
@@ -799,9 +862,11 @@ class FlashCausalLM(Model):
         # It is faster if we delay this sync for the maximum amount of time
         # For each member of the batch
+        index = 0
         for i, (
             input_length,
             all_input_ids,
+            n_accepted_ids
         ) in enumerate(iterator):
             # Indexing metadata
             start_index = cumulative_length
@@ -830,15 +895,18 @@ class FlashCausalLM(Model):
                     start_index + 1 : start_index + out_length
                 ]

-            batch.all_input_ids_tensor[i, input_length] = next_input_ids[i]
+            for j in range(n_accepted_ids):
+                batch.all_input_ids_tensor[i, input_length + j] = next_input_ids[index]
+                index += 1

             cumulative_length += input_length

-        # Set values in batch
-        batch.input_ids = next_input_ids
-        batch.position_ids = next_position_ids + 1
-        batch.input_lengths_tensor += 1
-        batch.slot_indices += 1
+        batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1]
+        batch.speculative_ids = speculative_ids
+        batch.position_ids = next_position_ids + accepted_ids
+        batch.input_lengths_tensor += accepted_ids
+        batch.slot_indices += accepted_ids

         if prefill and prefill_logprobs:
             # Get prefill logprobs
@@ -851,7 +919,7 @@ class FlashCausalLM(Model):

         # GPU <-> CPU sync
         next_token_logprobs = next_token_logprobs.tolist()
-        next_token_ids = batch.input_ids.tolist()
+        next_token_ids = next_input_ids.tolist()

         # Zipped iterator
         iterator = zip(
@@ -864,13 +932,13 @@ class FlashCausalLM(Model):
             batch.next_token_chooser.do_sample,
             batch.next_token_chooser.seeds,
             batch.top_n_tokens,
-            next_token_ids,
-            next_token_logprobs,
+            accepted_ids,
             batch_top_token_ids,
             batch_top_token_logprobs,
         )

         # For each member of the batch
+        index = 0
         for i, (
             request,
             input_length,
@@ -881,29 +949,43 @@ class FlashCausalLM(Model):
             do_sample,
             seed,
             top_n_tokens,
-            next_token_id,
-            next_token_logprob,
+            n_accepted_ids,
             top_token_ids,
             top_token_logprobs,
         ) in enumerate(iterator):
             # Append next token to all tokens
-            all_input_ids.append(next_token_id)
-            # Generated token
-            next_token_text, prefix_offset, read_offset = self.decode_token(
-                all_input_ids,
-                prefix_offset,
-                read_offset,
-            )
-            # Evaluate stopping criteria
-            stop, reason = stopping_criteria(
-                next_token_id,
-                next_token_text,
-            )
-            if not stop:
-                stopped = False
+            next_token_texts = []
+            left = 0
+            before = stopping_criteria.current_tokens
+            current_stopped = False
+            for j in range(index, index + n_accepted_ids):
+                # Generated token
+                next_token_id = next_token_ids[j]
+                all_input_ids.append(next_token_id)
+                next_token_text, prefix_offset, read_offset = self.decode_token(
+                    all_input_ids,
+                    prefix_offset,
+                    read_offset,
+                )
+                next_token_texts.append(next_token_text)
+
+                stop, reason = stopping_criteria(
+                    next_token_id,
+                    next_token_text,
+                )
+
+                if stop:
+                    left = index + n_accepted_ids - j - 1
+                    current_stopped = True
+                    break
+                else:
+                    current_stopped = False
+            stopped = stopped and current_stopped
+
+            _next_token_ids = next_token_ids[index : index + n_accepted_ids - left]
+            _next_token_logprobs = next_token_logprobs[index : index + n_accepted_ids - left]
+            index += n_accepted_ids

             # Shard generations
             # All generations will be appended in the rust sharded client
@@ -943,8 +1025,9 @@ class FlashCausalLM(Model):
                     clean_up_tokenization_spaces=False,
                     skip_special_tokens=False,
                 )
-                prefill_tokens = PrefillTokens(
-                    prefill_token_ids, request_prefill_logprobs, prefill_texts
+
+                prefill_tokens = Tokens(
+                    prefill_token_ids, request_prefill_logprobs, prefill_texts, is_special=[]
                 )
             else:
                 prefill_tokens = None
@@ -958,7 +1041,7 @@ class FlashCausalLM(Model):
                 special_toptokens = [
                     token_id in self.all_special_ids for token_id in top_token_ids
                 ]
-                top_tokens = TopTokens(
+                top_tokens = Tokens(
                     top_token_ids,
                     top_token_logprobs,
                     toptoken_texts,
@@ -970,10 +1053,12 @@ class FlashCausalLM(Model):
             generation = Generation(
                 request.id,
                 prefill_tokens,
-                next_token_id,
-                next_token_logprob,
-                next_token_text,
-                next_token_id in self.all_special_ids,
+                Tokens(
+                    _next_token_ids,
+                    _next_token_logprobs,
+                    next_token_texts,
+                    [nid in self.all_special_ids for nid in _next_token_ids],
+                ),
                 generated_text,
                 top_tokens,
             )
@@ -981,7 +1066,9 @@ class FlashCausalLM(Model):
             generations.append(generation)

             # Update values
-            batch.input_lengths[i] = input_length + 1
+            batch.input_lengths[i] = input_length + n_accepted_ids.item()
+            if batch.input_lengths[i] > batch.max_seqlen:
+                batch.max_seqlen = batch.input_lengths[i]
             batch.prefix_offsets[i] = prefix_offset
             batch.read_offsets[i] = read_offset
             batch.all_input_ids[i] = all_input_ids
@@ -994,6 +1081,5 @@ class FlashCausalLM(Model):
             batch.prefill_cu_outlens = None
             batch.prefill_head_indices = None
             batch.prefill_next_token_indices = None
-        batch.max_seqlen = batch.max_seqlen + 1

         return generations, batch
@@ -28,6 +28,7 @@ class FlashLlama(FlashCausalLM):
         quantize: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
+        use_medusa: Optional[str] = None,
     ):
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
@@ -66,6 +67,18 @@ class FlashLlama(FlashCausalLM):
             weights._set_gptq_params(model_id)

         model = FlashLlamaForCausalLM(config, weights)
+        if use_medusa:
+            from text_generation_server.utils.medusa import MedusaModel
+            from huggingface_hub import hf_hub_download
+            import json
+            medusa_config = hf_hub_download(use_medusa, revision=revision, filename="config.json")
+            with open(medusa_config, "r") as f:
+                config = json.load(f)
+            medusa_head = hf_hub_download(use_medusa, revision=revision, filename="medusa_lm_head.pt")
+            medusa_sf = medusa_head[: -len(".pt")] + ".safetensors"
+            weights = Weights([medusa_sf], device, dtype, process_group=self.process_group)
+            lm_head = model.lm_head
+            model.lm_head = MedusaModel(config, weights, lm_head)
+
         torch.distributed.barrier(group=self.process_group)
         super(FlashLlama, self).__init__(
...
@@ -21,6 +21,7 @@ from text_generation_server.models.custom_modeling.flash_mistral_modeling import
     FlashMistralForCausalLM,
     MistralConfig,
 )
+from text_generation_server.utils.speculate import get_speculate
 from text_generation_server.utils import (
     initialize_torch_distributed,
     weight_files,
@@ -132,7 +133,8 @@ class FlashMistralBatch(FlashCausalLMBatch):
             # Paged attention
             # Remove one as the first token does not have a past
-            total_tokens = input_length + max_new_tokens - 1
+            speculative_length = get_speculate()
+            total_tokens = input_length + max_new_tokens - 1 + speculative_length

             # Needed blocks can not go over SLIDING_WINDOW_BLOCKS
             needed_blocks = min(
@@ -183,7 +185,7 @@ class FlashMistralBatch(FlashCausalLMBatch):
             cumulative_max_length += total_tokens
             max_seqlen = max(max_seqlen, input_length)
             max_blocks = max(max_blocks, needed_blocks)
-            max_length = max(max_length, input_length + max_new_tokens)
+            max_length = max(max_length, input_length + max_new_tokens + speculative_length)

         next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
             next_token_chooser_parameters, dtype, device
@@ -272,6 +274,7 @@ class FlashMistralBatch(FlashCausalLMBatch):
             blocks=blocks,
             max_blocks=max_blocks,
             prefill_cache_indices=prefill_cache_indices,
+            speculative_ids=None
         )
@@ -340,17 +343,55 @@ class FlashMistral(FlashCausalLM):

     def forward(self, batch: FlashMistralBatch) -> Tuple[torch.Tensor, torch.Tensor]:
         # Model Forward
+        if batch.speculative_ids is not None:
+            input_ids = batch.input_ids
+            position_ids = batch.position_ids
+            cu_seqlen_prefill = batch.cu_seqlen_prefill
+            kv_cache = get_cache_manager().kv_cache
+            block_tables = batch.block_tables_tensor
+            slots = batch.slots[batch.slot_indices]
+            input_lengths = batch.input_lengths_tensor
+            max_s = batch.max_seqlen
+            lm_head_indices = batch.prefill_head_indices
+
+            speculative_ids = batch.speculative_ids
+
+            B, speculative_length = speculative_ids.shape
+            new_length = speculative_length + 1
+            new_input_ids = torch.cat([input_ids.unsqueeze(-1), speculative_ids], dim=1).reshape(-1)
+            arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0)
+            arange_int = arange.to(dtype=torch.int32)
+            new_position_ids = (position_ids.unsqueeze(-1).expand(B, new_length) + arange).view(-1)
+            slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
+            input_lengths = (input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
+
+            # Copy the block tables for all members
+            block_tables = block_tables.unsqueeze(1).expand(B, new_length, -1).reshape(B * new_length, -1).contiguous()
+            max_s = max_s + speculative_length
+
+            input_ids = new_input_ids
+            position_ids = new_position_ids
+        else:
+            input_ids = batch.input_ids
+            position_ids = batch.position_ids
+            cu_seqlen_prefill = batch.cu_seqlen_prefill
+            kv_cache = get_cache_manager().kv_cache
+            block_tables = batch.block_tables_tensor
+            slots = batch.slots[batch.slot_indices]
+            input_lengths = batch.input_lengths_tensor
+            max_s = batch.max_seqlen
+            lm_head_indices = batch.prefill_head_indices
+
         logits = self.model.forward(
-            input_ids=batch.input_ids,
-            position_ids=batch.position_ids,
-            cu_seqlen_prefill=batch.cu_seqlen_prefill,
-            kv_cache=get_cache_manager().kv_cache,
-            block_tables=batch.block_tables_tensor,
-            slots=batch.slots[batch.slot_indices],
-            input_lengths=batch.input_lengths_tensor,
-            max_s=batch.max_seqlen,
+            input_ids=input_ids,
+            position_ids=position_ids,
+            cu_seqlen_prefill=cu_seqlen_prefill,
+            kv_cache=kv_cache,
+            block_tables=block_tables,
+            slots=slots,
+            input_lengths=input_lengths,
+            max_s=max_s,
             prefill_cache_indices=batch.prefill_cache_indices,
-            lm_head_indices=batch.prefill_head_indices,
+            lm_head_indices=lm_head_indices,
         )
         if batch.prefill_cache_indices is not None:
             batch.prefill_cache_indices = None
...
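Note: the speculative branch of forward() above widens each sequence from one decode position to 1 + speculative_length positions before calling the model. The following standalone sketch mirrors only that tensor manipulation with invented ids and positions (it bypasses the cache manager, slots, and attention entirely), so the shapes are easy to check by hand.

import torch

B, speculative_length = 2, 2
new_length = speculative_length + 1

input_ids = torch.tensor([10, 20])                    # one current token per sequence
speculative_ids = torch.tensor([[11, 12], [21, 22]])  # (B, speculative_length) proposed tokens
position_ids = torch.tensor([7, 3])                   # current decode position per sequence

# Interleave the real token with its speculative continuations
new_input_ids = torch.cat([input_ids.unsqueeze(-1), speculative_ids], dim=1).reshape(-1)
# Give each widened slot a consecutive position
arange = torch.arange(new_length).unsqueeze(0)
new_position_ids = (position_ids.unsqueeze(-1).expand(B, new_length) + arange).view(-1)

print(new_input_ids)     # tensor([10, 11, 12, 20, 21, 22])
print(new_position_ids)  # tensor([7, 8, 9, 3, 4, 5])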
@@ -20,7 +20,7 @@ from typing import Optional, Tuple, List, Type, Dict
 from text_generation_server.models import Model
 from text_generation_server.models.types import (
     Batch,
-    PrefillTokens,
+    Tokens,
     Generation,
     GeneratedText,
 )
@@ -791,8 +791,8 @@ class IdeficsCausalLM(Model):
                     clean_up_tokenization_spaces=False,
                     skip_special_tokens=False,
                 )
-                prefill_tokens = PrefillTokens(
-                    prefill_token_ids, prefill_logprobs, prefill_texts
+                prefill_tokens = Tokens(
+                    prefill_token_ids, prefill_logprobs, prefill_texts, is_special=[]
                 )
             else:
                 prefill_tokens = None
@@ -802,10 +802,12 @@ class IdeficsCausalLM(Model):
             generation = Generation(
                 request.id,
                 prefill_tokens,
-                next_token_id_squeezed,
-                next_token_logprob,
-                next_token_text,
-                next_token_id_squeezed.item() in self.all_special_ids,
+                Tokens(
+                    [next_token_id_squeezed],
+                    [next_token_logprob],
+                    [next_token_text],
+                    [next_token_id_squeezed.item() in self.all_special_ids],
+                ),
                 generated_text,
                 top_tokens,
             )
...
@@ -6,6 +6,7 @@ from typing import List, Tuple, Optional, TypeVar, Type
 from transformers import PreTrainedTokenizerBase, PretrainedConfig

 from text_generation_server.models.types import Batch, Generation
+from text_generation_server.utils.speculate import get_speculate
 from text_generation_server.pb.generate_pb2 import InfoResponse

 B = TypeVar("B", bound=Batch)
@@ -22,6 +23,7 @@ class Model(ABC):
         rank: int = 0,
         world_size: int = 1,
         sliding_window: Optional[int] = None,
+        speculate: Optional[int] = None,
     ):
         self.model = model.eval()
         self.tokenizer = tokenizer
@@ -33,6 +35,10 @@ class Model(ABC):
         self.world_size = world_size
         self.sliding_window = sliding_window

+        if speculate is None:
+            speculate = get_speculate()
+        self.speculate = speculate
+
         self.has_position_ids = (
             inspect.signature(model.forward).parameters.get("position_ids", None)
             is not None
@@ -50,6 +56,7 @@ class Model(ABC):
             dtype=str(self.dtype),
             device_type=self.device.type,
             window_size=self.sliding_window,
+            speculate=self.speculate
         )

     @property
...
@@ -11,8 +11,7 @@ from text_generation_server.models.types import (
     GeneratedText,
     Batch,
     Generation,
-    PrefillTokens,
-    TopTokens,
+    Tokens,
 )
 from text_generation_server.pb import generate_pb2
 from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
@@ -733,10 +732,11 @@ class Seq2SeqLM(Model):
             # Prefill
             if stopping_criteria.current_tokens == 1 and request.prefill_logprobs:
-                prefill_tokens = PrefillTokens(
+                prefill_tokens = Tokens(
                     [self.tokenizer.bos_token_id],
                     [float("nan")],
                     [self.tokenizer.bos_token],
+                    [False]
                 )
             else:
                 prefill_tokens = None
@@ -750,7 +750,7 @@ class Seq2SeqLM(Model):
                 special_toptokens = [
                     token_id in self.all_special_ids for token_id in top_token_ids
                 ]
-                top_tokens = TopTokens(
+                top_tokens = Tokens(
                     top_token_ids,
                     top_token_logprobs,
                     toptoken_texts,
@@ -762,10 +762,12 @@ class Seq2SeqLM(Model):
             generation = Generation(
                 request.id,
                 prefill_tokens,
-                next_token_id_squeezed,
-                next_token_logprob,
-                next_token_text,
-                next_token_id_squeezed.item() in self.all_special_ids,
+                Tokens(
+                    [next_token_id_squeezed],
+                    [next_token_logprob],
+                    [next_token_text],
+                    [next_token_id_squeezed.item() in self.all_special_ids],
+                ),
                 generated_text,
                 top_tokens,
             )
...
@@ -58,33 +58,15 @@ class GeneratedText:

 @dataclass
-class PrefillTokens:
-    token_ids: List[int]
-    logprobs: List[float]
-    texts: List[str]
-
-    def to_pb(self) -> generate_pb2.PrefillTokens:
-        return generate_pb2.PrefillTokens(
-            ids=self.token_ids, logprobs=self.logprobs, texts=self.texts
-        )
-
-    def __len__(self):
-        return len(self.token_ids)
-
-
-@dataclass
-class TopTokens:
+class Tokens:
     token_ids: List[int]
     logprobs: List[float]
     texts: List[str]
     is_special: List[bool]

-    def to_pb(self) -> generate_pb2.TopTokens:
-        return generate_pb2.TopTokens(
-            ids=self.token_ids,
-            logprobs=self.logprobs,
-            texts=self.texts,
-            is_special=self.is_special,
-        )
+    def to_pb(self) -> generate_pb2.Tokens:
+        return generate_pb2.Tokens(
+            ids=self.token_ids, logprobs=self.logprobs, texts=self.texts, is_special=self.is_special
+        )

     def __len__(self):
@@ -94,14 +76,11 @@ class TopTokens:

 @dataclass
 class Generation:
     request_id: int
-    prefill_tokens: Optional[PrefillTokens]
-    token_id: int
-    token_logprob: float
-    token_text: str
-    token_is_special: bool
+    prefill_tokens: Optional[Tokens]
+    tokens: Tokens
     generated_text: Optional[GeneratedText]
     # Optional for now, since it's not yet supported for every model.
-    top_tokens: Optional[TopTokens]
+    top_tokens: Optional[List[Tokens]]

     def to_pb(self) -> generate_pb2.Generation:
         return generate_pb2.Generation(
@@ -109,10 +88,7 @@ class Generation:
             prefill_tokens=self.prefill_tokens.to_pb()
             if self.prefill_tokens is not None
             else None,
-            token_id=self.token_id,
-            token_logprob=self.token_logprob,
-            token_text=self.token_text,
-            token_is_special=self.token_is_special,
+            tokens=self.tokens.to_pb(),
            generated_text=self.generated_text.to_pb()
            if self.generated_text is not None
            else None,
...
@@ -132,6 +132,7 @@ def serve(
     revision: Optional[str],
     sharded: bool,
     quantize: Optional[str],
+    speculate: Optional[int],
     dtype: Optional[str],
     trust_remote_code: bool,
     uds_path: Path,
@@ -141,6 +142,7 @@ def serve(
         revision: Optional[str],
         sharded: bool = False,
         quantize: Optional[str] = None,
+        speculate: Optional[int] = None,
        dtype: Optional[str] = None,
        trust_remote_code: bool = False,
    ):
@@ -157,7 +159,7 @@ def serve(
        try:
            model = get_model(
-                model_id, revision, sharded, quantize, dtype, trust_remote_code
+                model_id, revision, sharded, quantize, speculate, dtype, trust_remote_code
            )
        except Exception:
            logger.exception("Error when initializing model")
@@ -205,5 +207,5 @@ def serve(
        await server.stop(0)

    asyncio.run(
-        serve_inner(model_id, revision, sharded, quantize, dtype, trust_remote_code)
+        serve_inner(model_id, revision, sharded, quantize, speculate, dtype, trust_remote_code)
    )
# New module (imported elsewhere in this change as text_generation_server.utils.medusa)
import torch
from dataclasses import dataclass
from text_generation_server.utils.layers import TensorParallelHead, FastLinear


@dataclass
class Output:
    logits: torch.FloatTensor = None
    speculative_logits: torch.FloatTensor = None


class ResBlock(torch.nn.Module):
    def __init__(self, config, prefix, weights):
        super().__init__()
        self.linear = FastLinear.load(config, prefix=f"{prefix}.linear", weights=weights, bias=True)
        self.act = torch.nn.SiLU()

    def forward(self, x):
        # Residual SiLU block applied before each Medusa head's final projection
        return x + self.act(self.linear(x))


class MedusaModel(torch.nn.Module):
    def __init__(self, config, weights, lm_head):
        super().__init__()
        self.heads = torch.nn.ModuleList(
            [MedusaHead(config, prefix=f"{i}", weights=weights) for i in range(config["medusa_num_heads"])]
        )
        self.lm_head = lm_head

    def forward(self, x):
        # Regular next-token logits plus one extra set of logits per Medusa head
        logits = self.lm_head(x)
        speculative_logits = torch.stack([head(x) for head in self.heads], dim=1)
        return logits, speculative_logits


class MedusaHead(torch.nn.Module):
    def __init__(self, config, prefix, weights):
        super().__init__()
        self.blocks = torch.nn.ModuleList(
            [ResBlock(config, prefix=f"{prefix}.{i}", weights=weights) for i in range(config["medusa_num_layers"])]
        )
        n = len(self.blocks)
        self.out = FastLinear.load(config, prefix=f"{prefix}.{n}", weights=weights, bias=False)

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        x = self.out(x)
        return x
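For intuition only, here is a minimal stand-in for MedusaModel.forward that uses plain torch.nn.Linear layers and invented sizes instead of the FastLinear modules loaded from a medusa_lm_head checkpoint; it only demonstrates the output shapes the server expects.

import torch

hidden_size, vocab_size, num_heads = 16, 32, 3          # made-up dimensions
lm_head = torch.nn.Linear(hidden_size, vocab_size, bias=False)
heads = torch.nn.ModuleList(
    [torch.nn.Linear(hidden_size, vocab_size, bias=False) for _ in range(num_heads)]
)

x = torch.randn(5, hidden_size)                          # 5 token positions
logits = lm_head(x)                                      # (5, vocab_size)
speculative_logits = torch.stack([h(x) for h in heads], dim=1)  # (5, num_heads, vocab_size)
print(logits.shape, speculative_logits.shape)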
# New module (imported elsewhere in this change as text_generation_server.utils.speculate)
SPECULATE = None


def get_speculate() -> int:
    global SPECULATE
    return SPECULATE


def set_speculate(speculate: int):
    global SPECULATE
    SPECULATE = speculate
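A trivial usage sketch, assuming the server package is importable; in practice get_model() above is what calls set_speculate from the --speculate value or the Medusa config.

from text_generation_server.utils.speculate import set_speculate, get_speculate

set_speculate(2)             # e.g. speculate two extra tokens per decode step
assert get_speculate() == 2  # later read back by the batch and chooser code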
@@ -16,7 +16,6 @@ from text_generation_server.utils.logits_process import (
 from text_generation_server.utils.watermark import WatermarkLogitsProcessor
 from transformers import PreTrainedTokenizerBase, RepetitionPenaltyLogitsProcessor


 class NextTokenChooser:
     def __init__(
         self,
@@ -146,6 +145,20 @@ class StoppingCriteria:
             pb.ignore_eos_token,
         )

+
+def create_n_gram_speculation(input_ids: torch.Tensor, next_ids: torch.Tensor, accepted_ids: torch.Tensor, speculate: int, verbose: bool):
+    # Very trivial approach, find first match in the string.
+    # This is much less refined than actual n-gram but seems to work
+    # relatively OK in grounded mode and is by far much faster with
+    # much less worst case complexity as everything happens on device.
+    B = accepted_ids.shape[0]
+    device = input_ids.device
+    seeds = next_ids[accepted_ids.cumsum(dim=-1) - 1]
+    indices = (input_ids == seeds.unsqueeze(-1)).max(dim=1).indices + 1
+    all_indices = indices.unsqueeze(-1).expand(B, speculate) + torch.arange(speculate, device=device)
+    all_indices = torch.clamp(all_indices, max=input_ids.shape[1] - 1)
+
+    speculative_ids = input_ids.gather(dim=-1, index=all_indices)
+    return speculative_ids
+
+
 class HeterogeneousNextTokenChooser:
     def __init__(
@@ -215,20 +228,79 @@ class HeterogeneousNextTokenChooser:
         self.dtype = dtype
         self.device = device

-    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor):
-        if self.watermark_processor is not None:
-            scores = self.watermark_processor(input_ids, scores)
-        if self.repetition_processor is not None:
-            scores = self.repetition_processor(input_ids, scores)
-
-        for warper in self.warpers:
-            scores = warper(input_ids, scores)
-
-        next_ids = self.choice(scores)
+    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor, speculate: int, speculated_ids: Optional[torch.Tensor] = None, speculative_scores: Optional[torch.Tensor] = None, verbose=False):
+        if speculated_ids is not None:
+            B = scores.shape[0] // (speculated_ids.shape[1] + 1)
+            S = speculated_ids.shape[1] + 1
+            scores = scores.view(B, S, -1)
+        else:
+            B = scores.shape[0]
+            S = 1
+            scores = scores.view(B, S, -1)
+
+        next_ids = torch.zeros((B, S), device=scores.device, dtype=torch.long)
+        for j in range(S):
+            _scores = scores[:, j]
+            if self.watermark_processor is not None:
+                _scores = self.watermark_processor(input_ids, _scores)
+            if self.repetition_processor is not None:
+                _scores = self.repetition_processor(input_ids, _scores)
+            for warper in self.warpers:
+                _scores = warper(input_ids, _scores)
+            _next_ids = self.choice(_scores)
+            scores[:, j] = _scores
+            next_ids[:, j] = _next_ids
+        next_ids = next_ids.view(B * S)
+        scores = scores.view(B * S, -1)
+
+        if speculated_ids is not None:
+            accepted_ids = []
+            B = next_ids.shape[0] // (speculated_ids.shape[1] + 1)
+            S = speculated_ids.shape[1] + 1
+            indices = []
+            for i in range(B):
+                _next_ids = next_ids[i * S : (i + 1) * S]
+                _speculated_ids = speculated_ids[i]
+                validate_speculative = _next_ids[:-1] == _speculated_ids
+                index = i * S
+                accepted = 1
+                # First is always valid
+                indices.append(index)
+                for valid in validate_speculative.tolist():
+                    if valid:
+                        index += 1
+                        accepted += 1
+                        indices.append(index)
+                    else:
+                        break
+                accepted_ids.append(accepted)
+
+            accepted_ids = torch.tensor(accepted_ids, device=input_ids.device, dtype=input_ids.dtype)
+            next_ids = next_ids[indices]
+            scores = scores[indices]
+            indices = torch.arange(B, device=input_ids.device) * S
+            if speculative_scores is not None:
+                speculative_scores = speculative_scores[indices + accepted_ids - 1]
+        else:
+            accepted_ids = torch.ones_like(next_ids)

         logprobs = torch.log_softmax(scores, -1)
         next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)

-        return next_ids, next_logprobs, logprobs
+        if speculate > 0:
+            if speculative_scores is not None:
+                # Medusa provided some scores
+                speculative_ids = Greedy()(speculative_scores)
+            else:
+                # n-gram
+                speculative_ids = create_n_gram_speculation(input_ids, next_ids, accepted_ids, speculate, verbose)
+        else:
+            speculative_ids = None
+
+        return next_ids, next_logprobs, logprobs, accepted_ids, speculative_ids

     def filter(self, indices):
         if self.watermark_processor is not None:
...
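A standalone walk-through of create_n_gram_speculation above, with invented token ids: the last accepted token is looked up in the prompt history, and the tokens that followed its first occurrence become the speculation.

import torch

input_ids = torch.tensor([[5, 6, 7, 5, 9]])  # (B=1, L=5) token history for one sequence
next_ids = torch.tensor([5])                 # freshly sampled token(s), flattened
accepted_ids = torch.tensor([1])             # one accepted token for this sequence
speculate = 2

B = accepted_ids.shape[0]
seeds = next_ids[accepted_ids.cumsum(dim=-1) - 1]                    # last accepted id -> tensor([5])
indices = (input_ids == seeds.unsqueeze(-1)).max(dim=1).indices + 1  # position after first match -> tensor([1])
all_indices = indices.unsqueeze(-1).expand(B, speculate) + torch.arange(speculate)
all_indices = torch.clamp(all_indices, max=input_ids.shape[1] - 1)

speculative_ids = input_ids.gather(dim=-1, index=all_indices)
print(speculative_ids)  # tensor([[6, 7]]) -- the tokens that followed 5 earlier in the prompt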