Unverified Commit 9ecfa16b authored by Nicolas Patry, committed by GitHub

Speculative (#1308)

parent 3238c491
@@ -32,6 +32,7 @@ def serve(
     revision: Optional[str] = None,
     sharded: bool = False,
     quantize: Optional[Quantization] = None,
+    speculate: Optional[int] = None,
     dtype: Optional[Dtype] = None,
     trust_remote_code: bool = False,
     uds_path: Path = "/tmp/text-generation-server",
@@ -81,7 +82,7 @@ def serve(
             "Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model."
         )
     server.serve(
-        model_id, revision, sharded, quantize, dtype, trust_remote_code, uds_path
+        model_id, revision, sharded, quantize, speculate, dtype, trust_remote_code, uds_path
     )
@@ -116,7 +117,7 @@ def download_weights(
         logger.info("Files are already present on the host. " "Skipping download.")
         return
     # Local files not found
-    except (utils.LocalEntryNotFoundError, FileNotFoundError):
+    except (utils.LocalEntryNotFoundError, FileNotFoundError, utils.EntryNotFoundError):
        pass

    is_local_model = (Path(model_id).exists() and Path(model_id).is_dir()) or os.getenv(
@@ -137,6 +138,29 @@ def download_weights(
        except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
            pass

+        try:
+            import json
+
+            medusa_head = hf_hub_download(model_id, revision=revision, filename="medusa_lm_head.pt")
+            if auto_convert:
+                medusa_sf = Path(medusa_head[:-len(".pt")] + ".safetensors")
+                if not medusa_sf.exists():
+                    utils.convert_files([Path(medusa_head)], [medusa_sf], [])
+            medusa_config = hf_hub_download(model_id, revision=revision, filename="config.json")
+            with open(medusa_config, "r") as f:
+                config = json.load(f)
+
+            model_id = config["base_model_name_or_path"]
+            revision = "main"
+            try:
+                utils.weight_files(model_id, revision, extension)
+                logger.info(f"Files for parent {model_id} are already present on the host. " "Skipping download.")
+                return
+            # Local files not found
+            except (utils.LocalEntryNotFoundError, FileNotFoundError, utils.EntryNotFoundError):
+                pass
+        except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
+            pass
+
    # Try to download weights from the hub
    try:
        filenames = utils.weight_hub_files(model_id, revision, extension)
...
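The block added to `download_weights` treats the given repository as a Medusa speculation head when it contains `medusa_lm_head.pt`: the head checkpoint is converted to safetensors and the base model named in the head's `config.json` is fetched instead. A minimal sketch of that redirection, assuming a hypothetical head repository whose config carries `base_model_name_or_path` (all values below are invented for illustration):

    # Hypothetical contents of a Medusa head repository's config.json (illustrative only).
    medusa_config = {
        "base_model_name_or_path": "some-org/some-base-model",
        "medusa_num_heads": 4,
        "medusa_num_layers": 1,
    }

    # The diff re-targets model_id at the parent model and retries the usual weight lookup.
    model_id = medusa_config["base_model_name_or_path"]
    revision = "main"
    print(f"would now resolve weights for parent model {model_id}@{revision}")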
@@ -6,6 +6,7 @@ from transformers.configuration_utils import PretrainedConfig
 from transformers.models.auto import modeling_auto
 from typing import Optional

+from text_generation_server.utils.speculate import get_speculate, set_speculate
 from text_generation_server.models.model import Model
 from text_generation_server.models.causal_lm import CausalLM
 from text_generation_server.models.flash_causal_lm import FlashCausalLM
@@ -77,12 +78,12 @@ except ImportError as e:
 if MISTRAL:
     __all__.append(FlashMistral)

 def get_model(
     model_id: str,
     revision: Optional[str],
     sharded: bool,
     quantize: Optional[str],
+    speculate: Optional[int],
     dtype: Optional[str],
     trust_remote_code: bool,
 ) -> Model:
@@ -97,6 +98,11 @@ def get_model(
     else:
         raise RuntimeError(f"Unknown dtype {dtype}")

+    if speculate is not None:
+        set_speculate(speculate)
+    else:
+        set_speculate(0)
+
     if "facebook/galactica" in model_id:
         return GalacticaSharded(
             model_id,
@@ -131,6 +137,33 @@ def get_model(
     config_dict, _ = PretrainedConfig.get_config_dict(
         model_id, revision=revision, trust_remote_code=trust_remote_code
     )

+    use_medusa = None
+    if "medusa_num_heads" in config_dict:
+        use_medusa = model_id
+        medusa_config = config_dict
+        model_id = config_dict["base_model_name_or_path"]
+        revision = "main"
+        speculate_medusa = config_dict["medusa_num_heads"]
+        if speculate is not None:
+            if speculate > speculate_medusa:
raise RuntimeError("Speculate is set to `{speculate}` but this medusa models only has `{speculate_medusa}` heads, please make them match")
+            else:
+                set_speculate(speculate)
+        else:
+            set_speculate(speculate_medusa)
+
+        config_dict, _ = PretrainedConfig.get_config_dict(
+            model_id, revision=revision, trust_remote_code=trust_remote_code
+        )
+        method = "medusa"
+    else:
+        method = "n-gram"
+
+    speculate = get_speculate()
+    if speculate > 0:
+        logger.info(f"Using speculation {method} with {speculate} input ids.")
+
     model_type = config_dict["model_type"]

     if model_type == "gpt_bigcode":
@@ -206,6 +239,7 @@ def get_model(
             quantize=quantize,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
+            use_medusa=use_medusa
         )
     elif sharded:
         raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
...
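The `get_model` changes establish a simple precedence for the number of speculative ids: an explicit `speculate` argument wins (and must not exceed the Medusa head count), otherwise `medusa_num_heads` from the config is used, otherwise speculation is disabled. A small standalone sketch of that precedence, not code taken from the diff:

    from typing import Optional

    def resolve_speculate(cli_speculate: Optional[int], medusa_num_heads: Optional[int]) -> int:
        # Mirrors the precedence in get_model: explicit value first, then Medusa config, then 0.
        if medusa_num_heads is not None:
            if cli_speculate is not None:
                if cli_speculate > medusa_num_heads:
                    raise RuntimeError("speculate exceeds the number of Medusa heads")
                return cli_speculate
            return medusa_num_heads
        return cli_speculate if cli_speculate is not None else 0

    assert resolve_speculate(None, None) == 0   # speculation disabled
    assert resolve_speculate(3, None) == 3      # n-gram speculation with 3 ids
    assert resolve_speculate(None, 4) == 4      # Medusa config decides
    assert resolve_speculate(2, 4) == 2         # explicit value caps the Medusa default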
@@ -10,10 +10,9 @@ from typing import Optional, Tuple, List, Type, Dict
 from text_generation_server.models import Model
 from text_generation_server.models.types import (
     Batch,
-    PrefillTokens,
+    Tokens,
     Generation,
     GeneratedText,
-    TopTokens,
 )
 from text_generation_server.pb import generate_pb2
 from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
@@ -676,8 +675,8 @@ class CausalLM(Model):
                    clean_up_tokenization_spaces=False,
                    skip_special_tokens=False,
                )
-                prefill_tokens = PrefillTokens(
-                    prefill_token_ids, prefill_logprobs, prefill_texts
+                prefill_tokens = Tokens(
+                    prefill_token_ids, prefill_logprobs, prefill_texts, is_special=[]
                )
            else:
                prefill_tokens = None
@@ -691,7 +690,7 @@ class CausalLM(Model):
                special_toptokens = [
                    token_id in self.all_special_ids for token_id in top_token_ids
                ]
-                top_tokens = TopTokens(
+                top_tokens = Tokens(
                    top_token_ids,
                    top_token_logprobs,
                    toptoken_texts,
@@ -703,10 +702,12 @@ class CausalLM(Model):
            generation = Generation(
                request.id,
                prefill_tokens,
-                next_token_id_squeezed,
-                next_token_logprob,
-                next_token_text,
-                next_token_id_squeezed.item() in self.all_special_ids,
+                Tokens(
+                    [next_token_id_squeezed],
+                    [next_token_logprob],
+                    [next_token_text],
+                    [next_token_id_squeezed.item() in self.all_special_ids],
+                ),
                generated_text,
                top_tokens,
            )
...
@@ -12,12 +12,12 @@ from transformers import PreTrainedTokenizerBase
 from typing import Optional, Tuple, List, Type, Union, Dict

 from text_generation_server.models import Model
+from text_generation_server.utils.speculate import get_speculate
 from text_generation_server.models.types import (
     Batch,
-    PrefillTokens,
+    Tokens,
     Generation,
     GeneratedText,
-    TopTokens,
 )
 from text_generation_server.models.cache_manager import (
     get_cache_manager,
@@ -41,6 +41,7 @@ class FlashCausalLMBatch(Batch):
     # Decoder values
     input_ids: torch.Tensor
     position_ids: torch.Tensor
+    speculative_ids: torch.Tensor

     # Flash Attention values
@@ -120,6 +121,7 @@
        )["input_ids"]

        position_ids = []
+        speculative_ids = []
        cu_seqlen_prefill = [0]
        needed_blocks_slots = []
        start_slots = []
@@ -163,6 +165,8 @@
            input_length = len(tokenized_input)
            input_lengths.append(input_length)

            prefix_offsets.append(input_length - 5)
            read_offsets.append(input_length)
@@ -186,7 +190,8 @@
            # Paged attention
            # Remove one as the first token does not have a past
-            total_tokens = input_length + max_new_tokens - 1
+            speculative_length = get_speculate()
+            total_tokens = input_length + max_new_tokens - 1 + speculative_length
            needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
            blocks += needed_blocks
            needed_blocks_slots.append((needed_blocks, total_tokens))
@@ -224,7 +229,7 @@
            cumulative_max_length += total_tokens
            max_seqlen = max(max_seqlen, input_length)
            max_blocks = max(max_blocks, needed_blocks)
-            max_length = max(max_length, input_length + max_new_tokens)
+            max_length = max(max_length, input_length + max_new_tokens + speculative_length)

        next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
            next_token_chooser_parameters, dtype, device
@@ -255,7 +260,6 @@
        cu_seqlen_prefill = torch.tensor(
            cu_seqlen_prefill, device=device, dtype=torch.int32
        )

        position_ids = position_ids.to(device)
        slot_indices = slot_indices.to(device)
        input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
@@ -309,6 +313,7 @@
            top_n_tokens_tensor=top_n_tokens_tensor,
            blocks=blocks,
            max_blocks=max_blocks,
+            speculative_ids=None,
        )

    @tracer.start_as_current_span("filter")
@@ -419,6 +424,7 @@
        slots = self.slots[slot_filtering_indices]
        next_token_chooser = self.next_token_chooser.filter(indices)
        top_n_tokens_tensor = self.top_n_tokens_tensor[indices]
+        speculative_ids = self.speculative_ids[indices] if self.speculative_ids is not None else None

        start_slots = torch.tensor(start_slots, dtype=torch.int64)
@@ -454,6 +460,7 @@
            top_n_tokens_tensor=top_n_tokens_tensor,
            blocks=blocks,
            max_blocks=max_blocks,
+            speculative_ids=speculative_ids,
        )

    @classmethod
@@ -473,6 +480,7 @@
            total_batch_size += len(b)
            total_slots += len(b.slots)
            blocks += b.blocks
+            speculative_length = b.speculative_ids.shape[1] if b.speculative_ids is not None else 0
            max_blocks = max(max_blocks, b.max_blocks)
            max_seqlen = max(max_seqlen, b.max_seqlen)
            max_length = max(
@@ -480,6 +488,7 @@
                max(
                    input_length
                    + stopping_criteria.max_new_tokens
+                    + speculative_length
                    - stopping_criteria.current_tokens
                    for input_length, stopping_criteria in zip(
                        b.input_lengths, b.stopping_criterias
@@ -577,6 +586,8 @@
            device=batches[0].next_token_chooser.device,
        )

+        speculative_ids = torch.cat([b.speculative_ids for b in batches], dim=0) if batches[0].speculative_ids is not None else None
+
        # Needed to avoid dropping blocks when the batches will go out of scope
        for b in batches:
            b.block_tables = None
@@ -611,6 +622,7 @@
            top_n_tokens_tensor=top_n_tokens_tensor,
            blocks=blocks,
            max_blocks=max_blocks,
+            speculative_ids=speculative_ids
        )

    def __del__(self):
@@ -714,16 +726,55 @@
    def forward(self, batch: FlashCausalLMBatch) -> Tuple[torch.Tensor, torch.Tensor]:
        # Model Forward
+        if batch.speculative_ids is not None:
+            input_ids=batch.input_ids
+            position_ids=batch.position_ids
+            cu_seqlen_prefill=batch.cu_seqlen_prefill
+            kv_cache=get_cache_manager().kv_cache
+            block_tables=batch.block_tables_tensor
+            slots=batch.slots[batch.slot_indices]
+            input_lengths=batch.input_lengths_tensor
+            max_s=batch.max_seqlen
+            lm_head_indices=batch.prefill_head_indices
+
+            speculative_ids = batch.speculative_ids
+
+            B, speculative_length = speculative_ids.shape
+            new_length = speculative_length + 1
+            new_input_ids = torch.cat([input_ids.unsqueeze(-1), speculative_ids], dim=1).reshape(-1)
+            arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0)
+            arange_int = arange.to(dtype=torch.int32)
+            new_position_ids = (position_ids.unsqueeze(-1).expand(B, new_length) + arange).view(-1)
+            slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
+            input_lengths = (input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
+            # Copy the block tables for all members
+            block_tables = block_tables.unsqueeze(1).expand(B, new_length, -1).reshape(B * new_length, -1).contiguous()
+            max_s = max_s + speculative_length
+
+            input_ids = new_input_ids
+            position_ids = new_position_ids
+        else:
+            input_ids=batch.input_ids
+            position_ids=batch.position_ids
+            cu_seqlen_prefill=batch.cu_seqlen_prefill
+            kv_cache=get_cache_manager().kv_cache
+            block_tables=batch.block_tables_tensor
+            slots=batch.slots[batch.slot_indices]
+            input_lengths=batch.input_lengths_tensor
+            max_s=batch.max_seqlen
+            lm_head_indices=batch.prefill_head_indices
+
        return self.model.forward(
-            input_ids=batch.input_ids,
-            position_ids=batch.position_ids,
-            cu_seqlen_prefill=batch.cu_seqlen_prefill,
-            kv_cache=get_cache_manager().kv_cache,
-            block_tables=batch.block_tables_tensor,
-            slots=batch.slots[batch.slot_indices],
-            input_lengths=batch.input_lengths_tensor,
-            max_s=batch.max_seqlen,
-            lm_head_indices=batch.prefill_head_indices,
+            input_ids=input_ids,
+            position_ids=position_ids,
+            cu_seqlen_prefill=cu_seqlen_prefill,
+            kv_cache=kv_cache,
+            block_tables=block_tables,
+            slots=slots,
+            input_lengths=input_lengths,
+            max_s=max_s,
+            lm_head_indices=lm_head_indices,
        )
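When `speculative_ids` is set, the forward pass above flattens each sequence into `speculative_length + 1` query rows: the last verified token followed by the speculated candidates, with consecutive positions and slots. A standalone illustration of that expansion with tiny, made-up tensors:

    import torch

    B, speculative_length = 2, 2
    new_length = speculative_length + 1

    input_ids = torch.tensor([10, 20])                    # last verified token of each sequence
    speculative_ids = torch.tensor([[11, 12], [21, 22]])  # speculated continuations
    position_ids = torch.tensor([5, 7])                   # next position of each sequence

    new_input_ids = torch.cat([input_ids.unsqueeze(-1), speculative_ids], dim=1).reshape(-1)
    arange = torch.arange(new_length).unsqueeze(0)
    new_position_ids = (position_ids.unsqueeze(-1).expand(B, new_length) + arange).view(-1)

    print(new_input_ids)     # tensor([10, 11, 12, 20, 21, 22])
    print(new_position_ids)  # tensor([5, 6, 7, 7, 8, 9])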
@tracer.start_as_current_span("generate_token") @tracer.start_as_current_span("generate_token")
...@@ -752,21 +803,32 @@ class FlashCausalLM(Model): ...@@ -752,21 +803,32 @@ class FlashCausalLM(Model):
del batch del batch
raise e raise e
if isinstance(out, tuple):
out, speculative_logits = out
else:
speculative_logits = None
if prefill: if prefill:
next_token_logits = ( next_token_logits = (
out[batch.prefill_next_token_indices] if prefill_logprobs else out out[batch.prefill_next_token_indices] if prefill_logprobs else out
) )
if speculative_logits is not None:
speculative_logits = (
speculative_logits[batch.prefill_next_token_indices] if prefill_logprobs else speculative_logits
)
else: else:
next_token_logits = out next_token_logits = out
next_input_ids, next_token_logprobs, logprobs = batch.next_token_chooser( next_input_ids, next_token_logprobs, logprobs, accepted_ids, speculative_ids = batch.next_token_chooser(
batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits, get_speculate(), batch.speculative_ids, speculative_logits
) )
batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens( batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs
) )
speculative_length = 0 if speculative_ids is None else speculative_ids.shape[1]
if prefill: if prefill:
if len(batch) > 1 and prefill_logprobs: if len(batch) > 1 and prefill_logprobs:
# We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
...@@ -792,6 +854,7 @@ class FlashCausalLM(Model): ...@@ -792,6 +854,7 @@ class FlashCausalLM(Model):
iterator = zip( iterator = zip(
batch.input_lengths, batch.input_lengths,
batch.all_input_ids, batch.all_input_ids,
accepted_ids
) )
# We do two for loops as the first one can run completely asynchronously from the GPU while for the second # We do two for loops as the first one can run completely asynchronously from the GPU while for the second
...@@ -799,9 +862,11 @@ class FlashCausalLM(Model): ...@@ -799,9 +862,11 @@ class FlashCausalLM(Model):
# It is faster if we delay this sync for the maximum amount of time # It is faster if we delay this sync for the maximum amount of time
# For each member of the batch # For each member of the batch
index = 0
for i, ( for i, (
input_length, input_length,
all_input_ids, all_input_ids,
n_accepted_ids
) in enumerate(iterator): ) in enumerate(iterator):
# Indexing metadata # Indexing metadata
start_index = cumulative_length start_index = cumulative_length
...@@ -830,15 +895,18 @@ class FlashCausalLM(Model): ...@@ -830,15 +895,18 @@ class FlashCausalLM(Model):
start_index + 1 : start_index + out_length start_index + 1 : start_index + out_length
] ]
batch.all_input_ids_tensor[i, input_length] = next_input_ids[i] for j in range(n_accepted_ids):
batch.all_input_ids_tensor[i, input_length + j] = next_input_ids[index]
index += 1
cumulative_length += input_length cumulative_length += input_length
# Set values in batch
batch.input_ids = next_input_ids batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1]
batch.position_ids = next_position_ids + 1 batch.speculative_ids = speculative_ids
batch.input_lengths_tensor += 1 batch.position_ids = next_position_ids + accepted_ids
batch.slot_indices += 1 batch.input_lengths_tensor += accepted_ids
batch.slot_indices += accepted_ids
if prefill and prefill_logprobs: if prefill and prefill_logprobs:
# Get prefill logprobs # Get prefill logprobs
...@@ -851,7 +919,7 @@ class FlashCausalLM(Model): ...@@ -851,7 +919,7 @@ class FlashCausalLM(Model):
# GPU <-> CPU sync # GPU <-> CPU sync
next_token_logprobs = next_token_logprobs.tolist() next_token_logprobs = next_token_logprobs.tolist()
next_token_ids = batch.input_ids.tolist() next_token_ids = next_input_ids.tolist()
# Zipped iterator # Zipped iterator
iterator = zip( iterator = zip(
...@@ -864,13 +932,13 @@ class FlashCausalLM(Model): ...@@ -864,13 +932,13 @@ class FlashCausalLM(Model):
batch.next_token_chooser.do_sample, batch.next_token_chooser.do_sample,
batch.next_token_chooser.seeds, batch.next_token_chooser.seeds,
batch.top_n_tokens, batch.top_n_tokens,
next_token_ids, accepted_ids,
next_token_logprobs,
batch_top_token_ids, batch_top_token_ids,
batch_top_token_logprobs, batch_top_token_logprobs,
) )
# For each member of the batch # For each member of the batch
index = 0
for i, ( for i, (
request, request,
input_length, input_length,
...@@ -881,29 +949,43 @@ class FlashCausalLM(Model): ...@@ -881,29 +949,43 @@ class FlashCausalLM(Model):
do_sample, do_sample,
seed, seed,
top_n_tokens, top_n_tokens,
next_token_id, n_accepted_ids,
next_token_logprob,
top_token_ids, top_token_ids,
top_token_logprobs, top_token_logprobs,
) in enumerate(iterator): ) in enumerate(iterator):
# Append next token to all tokens # Append next token to all tokens
all_input_ids.append(next_token_id) next_token_texts = []
left = 0
before = stopping_criteria.current_tokens
current_stopped = False
for j in range(index, index + n_accepted_ids):
# Generated token # Generated token
next_token_id = next_token_ids[j]
all_input_ids.append(next_token_id)
next_token_text, prefix_offset, read_offset = self.decode_token( next_token_text, prefix_offset, read_offset = self.decode_token(
all_input_ids, all_input_ids,
prefix_offset, prefix_offset,
read_offset, read_offset,
) )
next_token_texts.append(next_token_text)
# Evaluate stopping criteria
stop, reason = stopping_criteria( stop, reason = stopping_criteria(
next_token_id, next_token_id,
next_token_text, next_token_text,
) )
if not stop: if stop:
stopped = False left = index + n_accepted_ids - j - 1
current_stopped = True
break
else:
current_stopped = False
stopped = stopped and current_stopped
_next_token_ids = next_token_ids[index: index+n_accepted_ids - left]
_next_token_logprobs = next_token_logprobs[index: index+n_accepted_ids - left]
index += n_accepted_ids
# Shard generations # Shard generations
# All generations will be appended in the rust sharded client # All generations will be appended in the rust sharded client
...@@ -943,8 +1025,9 @@ class FlashCausalLM(Model): ...@@ -943,8 +1025,9 @@ class FlashCausalLM(Model):
clean_up_tokenization_spaces=False, clean_up_tokenization_spaces=False,
skip_special_tokens=False, skip_special_tokens=False,
) )
prefill_tokens = PrefillTokens(
prefill_token_ids, request_prefill_logprobs, prefill_texts prefill_tokens = Tokens(
prefill_token_ids, request_prefill_logprobs, prefill_texts, is_special = []
) )
else: else:
prefill_tokens = None prefill_tokens = None
...@@ -958,7 +1041,7 @@ class FlashCausalLM(Model): ...@@ -958,7 +1041,7 @@ class FlashCausalLM(Model):
special_toptokens = [ special_toptokens = [
token_id in self.all_special_ids for token_id in top_token_ids token_id in self.all_special_ids for token_id in top_token_ids
] ]
top_tokens = TopTokens( top_tokens = Tokens(
top_token_ids, top_token_ids,
top_token_logprobs, top_token_logprobs,
toptoken_texts, toptoken_texts,
...@@ -970,10 +1053,12 @@ class FlashCausalLM(Model): ...@@ -970,10 +1053,12 @@ class FlashCausalLM(Model):
generation = Generation( generation = Generation(
request.id, request.id,
prefill_tokens, prefill_tokens,
next_token_id, Tokens(
next_token_logprob, _next_token_ids,
next_token_text, _next_token_logprobs,
next_token_id in self.all_special_ids, next_token_texts,
[nid in self.all_special_ids for nid in _next_token_ids],
),
generated_text, generated_text,
top_tokens, top_tokens,
) )
...@@ -981,7 +1066,9 @@ class FlashCausalLM(Model): ...@@ -981,7 +1066,9 @@ class FlashCausalLM(Model):
generations.append(generation) generations.append(generation)
# Update values # Update values
batch.input_lengths[i] = input_length + 1 batch.input_lengths[i] = input_length + n_accepted_ids.item()
if batch.input_lengths[i] > batch.max_seqlen:
batch.max_seqlen = batch.input_lengths[i]
batch.prefix_offsets[i] = prefix_offset batch.prefix_offsets[i] = prefix_offset
batch.read_offsets[i] = read_offset batch.read_offsets[i] = read_offset
batch.all_input_ids[i] = all_input_ids batch.all_input_ids[i] = all_input_ids
...@@ -994,6 +1081,5 @@ class FlashCausalLM(Model): ...@@ -994,6 +1081,5 @@ class FlashCausalLM(Model):
batch.prefill_cu_outlens = None batch.prefill_cu_outlens = None
batch.prefill_head_indices = None batch.prefill_head_indices = None
batch.prefill_next_token_indices = None batch.prefill_next_token_indices = None
batch.max_seqlen = batch.max_seqlen + 1
return generations, batch return generations, batch
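After sampling, `next_input_ids` is the flat concatenation of the tokens actually kept for each request and `accepted_ids` holds how many were kept, so `accepted_ids.cumsum(dim=-1) - 1` indexes each request's last accepted token, which becomes its next input id. A tiny sketch of that indexing with invented values:

    import torch

    next_input_ids = torch.tensor([101, 102,   # request 0 kept 2 tokens
                                   201])       # request 1 kept 1 token
    accepted_ids = torch.tensor([2, 1])

    last_accepted = next_input_ids[accepted_ids.cumsum(dim=-1) - 1]
    print(last_accepted)  # tensor([102, 201]) -> next input id of each request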
@@ -28,6 +28,7 @@ class FlashLlama(FlashCausalLM):
        quantize: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
+        use_medusa: Optional[str] = None,
    ):
        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
@@ -66,6 +67,18 @@
            weights._set_gptq_params(model_id)

        model = FlashLlamaForCausalLM(config, weights)
+        if use_medusa:
+            from text_generation_server.utils.medusa import MedusaModel
+            from huggingface_hub import hf_hub_download
+            import json
+            medusa_config = hf_hub_download(use_medusa, revision=revision, filename="config.json")
+            with open(medusa_config, "r") as f:
+                config = json.load(f)
+            medusa_head = hf_hub_download(use_medusa, revision=revision, filename="medusa_lm_head.pt")
+            medusa_sf = medusa_head[:-len(".pt")] + ".safetensors"
+            weights = Weights([medusa_sf], device, dtype, process_group=self.process_group)
+            lm_head = model.lm_head
+            model.lm_head = MedusaModel(config, weights, lm_head)

        torch.distributed.barrier(group=self.process_group)
        super(FlashLlama, self).__init__(
...
@@ -21,6 +21,7 @@ from text_generation_server.models.custom_modeling.flash_mistral_modeling import
    FlashMistralForCausalLM,
    MistralConfig,
)
+from text_generation_server.utils.speculate import get_speculate
from text_generation_server.utils import (
    initialize_torch_distributed,
    weight_files,
@@ -132,7 +133,8 @@ class FlashMistralBatch(FlashCausalLMBatch):
            # Paged attention
            # Remove one as the first token does not have a past
-            total_tokens = input_length + max_new_tokens - 1
+            speculative_length = get_speculate()
+            total_tokens = input_length + max_new_tokens - 1 + speculative_length

            # Needed blocks can not go over SLIDING_WINDOW_BLOCKS
            needed_blocks = min(
@@ -183,7 +185,7 @@
            cumulative_max_length += total_tokens
            max_seqlen = max(max_seqlen, input_length)
            max_blocks = max(max_blocks, needed_blocks)
-            max_length = max(max_length, input_length + max_new_tokens)
+            max_length = max(max_length, input_length + max_new_tokens + speculative_length)

        next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
            next_token_chooser_parameters, dtype, device
@@ -272,6 +274,7 @@
            blocks=blocks,
            max_blocks=max_blocks,
            prefill_cache_indices=prefill_cache_indices,
+            speculative_ids=None
        )
@@ -340,17 +343,55 @@
    def forward(self, batch: FlashMistralBatch) -> Tuple[torch.Tensor, torch.Tensor]:
        # Model Forward
+        if batch.speculative_ids is not None:
+            input_ids=batch.input_ids
+            position_ids=batch.position_ids
+            cu_seqlen_prefill=batch.cu_seqlen_prefill
+            kv_cache=get_cache_manager().kv_cache
+            block_tables=batch.block_tables_tensor
+            slots=batch.slots[batch.slot_indices]
+            input_lengths=batch.input_lengths_tensor
+            max_s=batch.max_seqlen
+            lm_head_indices=batch.prefill_head_indices
+
+            speculative_ids = batch.speculative_ids
+
+            B, speculative_length = speculative_ids.shape
+            new_length = speculative_length + 1
+            new_input_ids = torch.cat([input_ids.unsqueeze(-1), speculative_ids], dim=1).reshape(-1)
+            arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0)
+            arange_int = arange.to(dtype=torch.int32)
+            new_position_ids = (position_ids.unsqueeze(-1).expand(B, new_length) + arange).view(-1)
+            slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
+            input_lengths = (input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
+            # Copy the block tables for all members
+            block_tables = block_tables.unsqueeze(1).expand(B, new_length, -1).reshape(B * new_length, -1).contiguous()
+            max_s = max_s + speculative_length
+
+            input_ids = new_input_ids
+            position_ids = new_position_ids
+        else:
+            input_ids=batch.input_ids
+            position_ids=batch.position_ids
+            cu_seqlen_prefill=batch.cu_seqlen_prefill
+            kv_cache=get_cache_manager().kv_cache
+            block_tables=batch.block_tables_tensor
+            slots=batch.slots[batch.slot_indices]
+            input_lengths=batch.input_lengths_tensor
+            max_s=batch.max_seqlen
+            lm_head_indices=batch.prefill_head_indices
+
        logits = self.model.forward(
-            input_ids=batch.input_ids,
-            position_ids=batch.position_ids,
-            cu_seqlen_prefill=batch.cu_seqlen_prefill,
-            kv_cache=get_cache_manager().kv_cache,
-            block_tables=batch.block_tables_tensor,
-            slots=batch.slots[batch.slot_indices],
-            input_lengths=batch.input_lengths_tensor,
-            max_s=batch.max_seqlen,
+            input_ids=input_ids,
+            position_ids=position_ids,
+            cu_seqlen_prefill=cu_seqlen_prefill,
+            kv_cache=kv_cache,
+            block_tables=block_tables,
+            slots=slots,
+            input_lengths=input_lengths,
+            max_s=max_s,
            prefill_cache_indices=batch.prefill_cache_indices,
-            lm_head_indices=batch.prefill_head_indices,
+            lm_head_indices=lm_head_indices,
        )
        if batch.prefill_cache_indices is not None:
            batch.prefill_cache_indices = None
...
@@ -20,7 +20,7 @@ from typing import Optional, Tuple, List, Type, Dict
 from text_generation_server.models import Model
 from text_generation_server.models.types import (
     Batch,
-    PrefillTokens,
+    Tokens,
     Generation,
     GeneratedText,
 )
@@ -791,8 +791,8 @@ class IdeficsCausalLM(Model):
                    clean_up_tokenization_spaces=False,
                    skip_special_tokens=False,
                )
-                prefill_tokens = PrefillTokens(
-                    prefill_token_ids, prefill_logprobs, prefill_texts
+                prefill_tokens = Tokens(
+                    prefill_token_ids, prefill_logprobs, prefill_texts, is_special=[]
                )
            else:
                prefill_tokens = None
@@ -802,10 +802,12 @@
            generation = Generation(
                request.id,
                prefill_tokens,
-                next_token_id_squeezed,
-                next_token_logprob,
-                next_token_text,
-                next_token_id_squeezed.item() in self.all_special_ids,
+                Tokens(
+                    [next_token_id_squeezed],
+                    [next_token_logprob],
+                    [next_token_text],
+                    [next_token_id_squeezed.item() in self.all_special_ids],
+                ),
                generated_text,
                top_tokens,
            )
...
@@ -6,6 +6,7 @@ from typing import List, Tuple, Optional, TypeVar, Type
 from transformers import PreTrainedTokenizerBase, PretrainedConfig

 from text_generation_server.models.types import Batch, Generation
+from text_generation_server.utils.speculate import get_speculate
 from text_generation_server.pb.generate_pb2 import InfoResponse

 B = TypeVar("B", bound=Batch)
@@ -22,6 +23,7 @@ class Model(ABC):
        rank: int = 0,
        world_size: int = 1,
        sliding_window: Optional[int] = None,
+        speculate: Optional[int] = None,
    ):
        self.model = model.eval()
        self.tokenizer = tokenizer
@@ -33,6 +35,10 @@
        self.world_size = world_size
        self.sliding_window = sliding_window

+        if speculate is None:
+            speculate = get_speculate()
+        self.speculate = speculate
+
        self.has_position_ids = (
            inspect.signature(model.forward).parameters.get("position_ids", None)
            is not None
@@ -50,6 +56,7 @@
            dtype=str(self.dtype),
            device_type=self.device.type,
            window_size=self.sliding_window,
+            speculate=self.speculate
        )

    @property
...
@@ -11,8 +11,7 @@ from text_generation_server.models.types import (
     GeneratedText,
     Batch,
     Generation,
-    PrefillTokens,
-    TopTokens,
+    Tokens,
 )
 from text_generation_server.pb import generate_pb2
 from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
@@ -733,10 +732,11 @@ class Seq2SeqLM(Model):
            # Prefill
            if stopping_criteria.current_tokens == 1 and request.prefill_logprobs:
-                prefill_tokens = PrefillTokens(
+                prefill_tokens = Tokens(
                    [self.tokenizer.bos_token_id],
                    [float("nan")],
                    [self.tokenizer.bos_token],
+                    [False]
                )
            else:
                prefill_tokens = None
@@ -750,7 +750,7 @@ class Seq2SeqLM(Model):
                special_toptokens = [
                    token_id in self.all_special_ids for token_id in top_token_ids
                ]
-                top_tokens = TopTokens(
+                top_tokens = Tokens(
                    top_token_ids,
                    top_token_logprobs,
                    toptoken_texts,
@@ -762,10 +762,12 @@ class Seq2SeqLM(Model):
            generation = Generation(
                request.id,
                prefill_tokens,
-                next_token_id_squeezed,
-                next_token_logprob,
-                next_token_text,
-                next_token_id_squeezed.item() in self.all_special_ids,
+                Tokens(
+                    [next_token_id_squeezed],
+                    [next_token_logprob],
+                    [next_token_text],
+                    [next_token_id_squeezed.item() in self.all_special_ids],
+                ),
                generated_text,
                top_tokens,
            )
...
@@ -58,33 +58,15 @@ class GeneratedText:

 @dataclass
-class PrefillTokens:
-    token_ids: List[int]
-    logprobs: List[float]
-    texts: List[str]
-
-    def to_pb(self) -> generate_pb2.PrefillTokens:
-        return generate_pb2.PrefillTokens(
-            ids=self.token_ids, logprobs=self.logprobs, texts=self.texts
-        )
-
-    def __len__(self):
-        return len(self.token_ids)
-
-
-@dataclass
-class TopTokens:
+class Tokens:
     token_ids: List[int]
     logprobs: List[float]
     texts: List[str]
     is_special: List[bool]

-    def to_pb(self) -> generate_pb2.TopTokens:
-        return generate_pb2.TopTokens(
-            ids=self.token_ids,
-            logprobs=self.logprobs,
-            texts=self.texts,
-            is_special=self.is_special,
-        )
+    def to_pb(self) -> generate_pb2.Tokens:
+        return generate_pb2.Tokens(
+            ids=self.token_ids, logprobs=self.logprobs, texts=self.texts, is_special=self.is_special
+        )

     def __len__(self):
@@ -94,14 +76,11 @@ class TopTokens:
 @dataclass
 class Generation:
     request_id: int
-    prefill_tokens: Optional[PrefillTokens]
-    token_id: int
-    token_logprob: float
-    token_text: str
-    token_is_special: bool
+    prefill_tokens: Optional[Tokens]
+    tokens: Tokens
     generated_text: Optional[GeneratedText]
     # Optional for now, since it's not yet supported for every model.
-    top_tokens: Optional[TopTokens]
+    top_tokens: Optional[List[Tokens]]

     def to_pb(self) -> generate_pb2.Generation:
         return generate_pb2.Generation(
@@ -109,10 +88,7 @@ class Generation:
             prefill_tokens=self.prefill_tokens.to_pb()
             if self.prefill_tokens is not None
             else None,
-            token_id=self.token_id,
-            token_logprob=self.token_logprob,
-            token_text=self.token_text,
-            token_is_special=self.token_is_special,
+            tokens=self.tokens.to_pb(),
             generated_text=self.generated_text.to_pb()
             if self.generated_text is not None
             else None,
...
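With `PrefillTokens` and `TopTokens` merged into a single `Tokens` dataclass, a regular decoding step is just the length-1 case and a speculative step carries several tokens. A small usage sketch of the dataclass on its own (the `to_pb` conversion is omitted because it needs the generated `generate_pb2` module):

    from dataclasses import dataclass
    from typing import List

    @dataclass
    class Tokens:  # simplified copy of the dataclass above, kept self-contained
        token_ids: List[int]
        logprobs: List[float]
        texts: List[str]
        is_special: List[bool]

        def __len__(self):
            return len(self.token_ids)

    single = Tokens([42], [-0.12], [" the"], [False])  # regular step: one token
    burst = Tokens([42, 7, 99], [-0.12, -0.8, -1.3], [" the", " quick", " fox"], [False, False, False])
    assert len(single) == 1 and len(burst) == 3        # speculative step: several tokens at once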
@@ -132,6 +132,7 @@ def serve(
    revision: Optional[str],
    sharded: bool,
    quantize: Optional[str],
+    speculate: Optional[int],
    dtype: Optional[str],
    trust_remote_code: bool,
    uds_path: Path,
@@ -141,6 +142,7 @@ def serve(
        revision: Optional[str],
        sharded: bool = False,
        quantize: Optional[str] = None,
+        speculate: Optional[int] = None,
        dtype: Optional[str] = None,
        trust_remote_code: bool = False,
    ):
@@ -157,7 +159,7 @@ def serve(
        try:
            model = get_model(
-                model_id, revision, sharded, quantize, dtype, trust_remote_code
+                model_id, revision, sharded, quantize, speculate, dtype, trust_remote_code
            )
        except Exception:
            logger.exception("Error when initializing model")
@@ -205,5 +207,5 @@ def serve(
        await server.stop(0)

    asyncio.run(
-        serve_inner(model_id, revision, sharded, quantize, dtype, trust_remote_code)
+        serve_inner(model_id, revision, sharded, quantize, speculate, dtype, trust_remote_code)
    )
+import torch
+from dataclasses import dataclass
+from text_generation_server.utils.layers import TensorParallelHead, FastLinear
+
+
+@dataclass
+class Output:
+    logits: torch.FloatTensor = None
+    speculative_logits: torch.FloatTensor = None
+
+
+class ResBlock(torch.nn.Module):
+    def __init__(self, config, prefix, weights):
+        super().__init__()
+        self.linear = FastLinear.load(config, prefix=f"{prefix}.linear", weights=weights, bias=True)
+        self.act = torch.nn.SiLU()
+
+    def forward(self, x):
+        return x + self.act(self.linear(x))
+
+
+class MedusaModel(torch.nn.Module):
+    def __init__(
+        self,
+        config,
+        weights,
+        lm_head
+    ):
+        super().__init__()
+        self.heads = torch.nn.ModuleList(
+            [MedusaHead(config, prefix=f"{i}", weights=weights) for i in range(config["medusa_num_heads"])]
+        )
+        self.lm_head = lm_head
+
+    def forward(self, x):
+        logits = self.lm_head(x)
+        speculative_logits = torch.stack([head(x) for head in self.heads], dim=1)
+        return logits, speculative_logits
+
+
+class MedusaHead(torch.nn.Module):
+    def __init__(self, config, prefix, weights):
+        super().__init__()
+        self.blocks = torch.nn.ModuleList([ResBlock(config, prefix=f"{prefix}.{i}", weights=weights) for i in range(config["medusa_num_layers"])])
+        n = len(self.blocks)
+        self.out = FastLinear.load(config, prefix=f"{prefix}.{n}", weights=weights, bias=False)
+
+    def forward(self, x):
+        for block in self.blocks:
+            x = block(x)
+        x = self.out(x)
+        return x
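`MedusaModel` keeps the original `lm_head` for the regular next-token logits and adds one small residual head per speculative position, stacking their outputs along a new dimension. A shape-level mimic using plain `nn.Linear` in place of the repo's `FastLinear`/`Weights` loading (hidden size, vocabulary size and head count below are arbitrary):

    import torch

    class TinyHead(torch.nn.Module):
        # Loose stand-in for ResBlock followed by MedusaHead's final projection.
        def __init__(self, hidden, vocab):
            super().__init__()
            self.linear = torch.nn.Linear(hidden, hidden)
            self.act = torch.nn.SiLU()
            self.out = torch.nn.Linear(hidden, vocab, bias=False)

        def forward(self, x):
            x = x + self.act(self.linear(x))  # residual block, as in ResBlock
            return self.out(x)

    hidden, vocab, n_heads = 16, 100, 4
    lm_head = torch.nn.Linear(hidden, vocab, bias=False)
    heads = torch.nn.ModuleList([TinyHead(hidden, vocab) for _ in range(n_heads)])

    x = torch.randn(3, hidden)                                      # 3 token positions
    logits = lm_head(x)                                             # (3, vocab)
    speculative_logits = torch.stack([h(x) for h in heads], dim=1)  # (3, n_heads, vocab)
    print(logits.shape, speculative_logits.shape)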
+SPECULATE = None
+
+
+def get_speculate() -> int:
+    global SPECULATE
+    return SPECULATE
+
+
+def set_speculate(speculate: int):
+    global SPECULATE
+    SPECULATE = speculate
@@ -16,7 +16,6 @@ from text_generation_server.utils.logits_process import (
 from text_generation_server.utils.watermark import WatermarkLogitsProcessor
 from transformers import PreTrainedTokenizerBase, RepetitionPenaltyLogitsProcessor

 class NextTokenChooser:
     def __init__(
         self,
@@ -146,6 +145,20 @@ class StoppingCriteria:
            pb.ignore_eos_token,
        )

+def create_n_gram_speculation(input_ids: torch.Tensor, next_ids: torch.Tensor, accepted_ids: torch.Tensor, speculate: int, verbose: bool):
+    # Very trivial approach, find first match in the string.
+    # This is much less refined than actual n-gram but seems to work
+    # relatively OK in grounded mode and is by far much faster with
+    # much less worst case complexity as everything happens on device.
+    B = accepted_ids.shape[0]
+    device = input_ids.device
+    seeds = next_ids[accepted_ids.cumsum(dim=-1) - 1]
+    indices = (input_ids == seeds.unsqueeze(-1)).max(dim=1).indices + 1
+    all_indices = indices.unsqueeze(-1).expand(B, speculate) + torch.arange(speculate, device=device)
+    all_indices = torch.clamp(all_indices, max=input_ids.shape[1] - 1)
+
+    speculative_ids = input_ids.gather(dim=-1, index=all_indices)
+    return speculative_ids
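A concrete run of `create_n_gram_speculation` with tiny tensors: each sequence's last accepted token is used as a seed, the position right after the seed's first occurrence in the existing ids is taken as the start, and the following `speculate` tokens are proposed (indices are clamped to the sequence end). The values are invented:

    import torch

    input_ids = torch.tensor([[ 5,  9,  7,  9,  3],
                              [11, 12, 13, 14, 15]])
    next_ids = torch.tensor([9, 13])      # one accepted token per sequence
    accepted_ids = torch.tensor([1, 1])
    speculate = 2

    seeds = next_ids[accepted_ids.cumsum(dim=-1) - 1]               # tensor([ 9, 13])
    indices = (input_ids == seeds.unsqueeze(-1)).max(dim=1).indices + 1
    all_indices = indices.unsqueeze(-1).expand(2, speculate) + torch.arange(speculate)
    all_indices = torch.clamp(all_indices, max=input_ids.shape[1] - 1)

    speculative_ids = input_ids.gather(dim=-1, index=all_indices)
    print(speculative_ids)  # tensor([[ 7,  9], [14, 15]])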

 class HeterogeneousNextTokenChooser:
     def __init__(
@@ -215,20 +228,79 @@ class HeterogeneousNextTokenChooser:
        self.dtype = dtype
        self.device = device

-    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor):
-        if self.watermark_processor is not None:
-            scores = self.watermark_processor(input_ids, scores)
-        if self.repetition_processor is not None:
-            scores = self.repetition_processor(input_ids, scores)
-
-        for warper in self.warpers:
-            scores = warper(input_ids, scores)
-
-        next_ids = self.choice(scores)
+    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor, speculate: int, speculated_ids: Optional[torch.Tensor] = None, speculative_scores: Optional[torch.Tensor] = None, verbose=False):
+        if speculated_ids is not None:
+            B = scores.shape[0] // (speculated_ids.shape[1] + 1)
+            S = speculated_ids.shape[1] + 1
+            scores = scores.view(B, S, -1)
+        else:
+            B = scores.shape[0]
+            S = 1
+            scores = scores.view(B, S, -1)
+
+        next_ids = torch.zeros((B, S), device=scores.device, dtype=torch.long)
+
+        for j in range(S):
+            _scores = scores[:, j]
+            if self.watermark_processor is not None:
+                _scores = self.watermark_processor(input_ids, _scores)
+            if self.repetition_processor is not None:
+                _scores = self.repetition_processor(input_ids, _scores)
+            for warper in self.warpers:
+                _scores = warper(input_ids, _scores)
+            _next_ids = self.choice(_scores)
+            scores[:, j] = _scores
+            next_ids[:, j] = _next_ids
+
+        next_ids = next_ids.view(B * S)
+        scores = scores.view(B * S, -1)
+
+        if speculated_ids is not None:
+            accepted_ids = []
+            B = next_ids.shape[0] // (speculated_ids.shape[1] + 1)
+            S = speculated_ids.shape[1] + 1
+            indices = []
+            for i in range(B):
+                _next_ids = next_ids[i * S : (i + 1) * S]
+                _speculated_ids = speculated_ids[i]
+                validate_speculative = _next_ids[:-1] == _speculated_ids
+                index = i * S
+                accepted = 1
+                # First is always valid
+                indices.append(index)
+                for valid in validate_speculative.tolist():
+                    if valid:
+                        index += 1
+                        accepted += 1
+                        indices.append(index)
+                    else:
+                        break
+                accepted_ids.append(accepted)
+
+            accepted_ids = torch.tensor(accepted_ids, device=input_ids.device, dtype=input_ids.dtype)
+            next_ids = next_ids[indices]
+            scores = scores[indices]
+            indices = torch.arange(B, device=input_ids.device) * S
+            if speculative_scores is not None:
+                speculative_scores = speculative_scores[indices + accepted_ids - 1]
+        else:
+            accepted_ids = torch.ones_like(next_ids)

        logprobs = torch.log_softmax(scores, -1)
        next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)

-        return next_ids, next_logprobs, logprobs
+        if speculate > 0:
+            if speculative_scores is not None:
+                # Medusa provided some scores
+                speculative_ids = Greedy()(speculative_scores)
+            else:
+                # n-gram
+                speculative_ids = create_n_gram_speculation(input_ids, next_ids, accepted_ids, speculate, verbose)
+        else:
+            speculative_ids = None
+
+        return next_ids, next_logprobs, logprobs, accepted_ids, speculative_ids
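The verification loop above implements the accept rule used here for speculative decoding: the first sampled token is always kept, and each previously speculated id is kept only while it matches what the model sampled at that position, plus the one fresh token sampled right after the last match. A standalone sketch of the rule (not code from the diff), with invented ids:

    def accept(speculated_ids, next_ids):
        # speculated_ids: ids proposed at the previous step (length S - 1).
        # next_ids: ids sampled this step, one per query row (length S).
        kept = [next_ids[0]]                  # the first token is always valid
        for j, spec in enumerate(speculated_ids):
            if next_ids[j] != spec:           # speculation diverged at position j
                break
            kept.append(next_ids[j + 1])
        return kept

    assert accept([7, 8, 9], [7, 8, 50, 60]) == [7, 8, 50]  # two matches -> three tokens kept
    assert accept([7, 8, 9], [1, 2, 3, 4]) == [1]           # immediate mismatch -> only the fresh token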

    def filter(self, indices):
        if self.watermark_processor is not None:
...