OpenDAS / text-generation-inference / Commits / 9ecfa16b

Commit 9ecfa16b (unverified), authored Dec 11, 2023 by Nicolas Patry, committed by GitHub on Dec 11, 2023

Speculative (#1308)

parent 3238c491

Changes: 34 files in the full commit; this page shows 14 changed files with 457 additions and 134 deletions (+457 -134).
server/text_generation_server/cli.py                        +26   -2
server/text_generation_server/models/__init__.py            +35   -1
server/text_generation_server/models/causal_lm.py           +10   -9
server/text_generation_server/models/flash_causal_lm.py     +137  -51
server/text_generation_server/models/flash_llama.py         +13   -0
server/text_generation_server/models/flash_mistral.py       +52   -11
server/text_generation_server/models/idefics_causal_lm.py   +9    -7
server/text_generation_server/models/model.py               +7    -0
server/text_generation_server/models/seq2seq_lm.py          +10   -8
server/text_generation_server/models/types.py               +8    -32
server/text_generation_server/server.py                     +4    -2
server/text_generation_server/utils/medusa.py               +51   -0
server/text_generation_server/utils/speculate.py            +12   -0
server/text_generation_server/utils/tokens.py               +83   -11
server/text_generation_server/cli.py

@@ -32,6 +32,7 @@ def serve(
     revision: Optional[str] = None,
     sharded: bool = False,
     quantize: Optional[Quantization] = None,
+    speculate: Optional[int] = None,
     dtype: Optional[Dtype] = None,
     trust_remote_code: bool = False,
     uds_path: Path = "/tmp/text-generation-server",
@@ -81,7 +82,7 @@ def serve(
             "Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model."
         )
     server.serve(
-        model_id, revision, sharded, quantize, dtype, trust_remote_code, uds_path
+        model_id, revision, sharded, quantize, speculate, dtype, trust_remote_code, uds_path
     )
@@ -116,7 +117,7 @@ def download_weights(
         logger.info("Files are already present on the host. " "Skipping download.")
         return
     # Local files not found
-    except (utils.LocalEntryNotFoundError, FileNotFoundError):
+    except (utils.LocalEntryNotFoundError, FileNotFoundError, utils.EntryNotFoundError):
         pass
 
     is_local_model = (Path(model_id).exists() and Path(model_id).is_dir()) or os.getenv(
@@ -137,6 +138,29 @@ def download_weights(
         except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
             pass
 
+    try:
+        import json
+
+        medusa_head = hf_hub_download(
+            model_id, revision=revision, filename="medusa_lm_head.pt"
+        )
+        if auto_convert:
+            medusa_sf = Path(medusa_head[: -len(".pt")] + ".safetensors")
+            if not medusa_sf.exists():
+                utils.convert_files([Path(medusa_head)], [medusa_sf], [])
+        medusa_config = hf_hub_download(
+            model_id, revision=revision, filename="config.json"
+        )
+        with open(medusa_config, "r") as f:
+            config = json.load(f)
+
+        model_id = config["base_model_name_or_path"]
+        revision = "main"
+        try:
+            utils.weight_files(model_id, revision, extension)
+            logger.info(
+                f"Files for parent {model_id} are already present on the host. "
+                "Skipping download."
+            )
+            return
+        # Local files not found
+        except (utils.LocalEntryNotFoundError, FileNotFoundError, utils.EntryNotFoundError):
+            pass
+    except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
+        pass
+
     # Try to download weights from the hub
     try:
         filenames = utils.weight_hub_files(model_id, revision, extension)
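Illustrative sketch (not part of the commit): the download path above fetches a Medusa checkpoint's medusa_lm_head.pt and then re-points model_id at the backbone listed in the checkpoint's config.json. The sketch below isolates that resolution step; the helper name resolve_medusa_base is hypothetical, while hf_hub_download and the "config.json" / "base_model_name_or_path" keys are the ones used in the diff.

import json

from huggingface_hub import hf_hub_download


def resolve_medusa_base(model_id: str, revision: str = "main") -> str:
    """Return the base model id that a Medusa head repository points back to."""
    config_path = hf_hub_download(model_id, revision=revision, filename="config.json")
    with open(config_path, "r") as f:
        config = json.load(f)
    # A Medusa repo stores only the extra heads; the backbone lives in this other repo.
    return config["base_model_name_or_path"]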
server/text_generation_server/models/__init__.py

@@ -6,6 +6,7 @@ from transformers.configuration_utils import PretrainedConfig
 from transformers.models.auto import modeling_auto
 from typing import Optional
 
+from text_generation_server.utils.speculate import get_speculate, set_speculate
 from text_generation_server.models.model import Model
 from text_generation_server.models.causal_lm import CausalLM
 from text_generation_server.models.flash_causal_lm import FlashCausalLM
@@ -77,12 +78,12 @@ except ImportError as e:
 if MISTRAL:
     __all__.append(FlashMistral)
 
 
 def get_model(
     model_id: str,
     revision: Optional[str],
     sharded: bool,
     quantize: Optional[str],
+    speculate: Optional[int],
     dtype: Optional[str],
     trust_remote_code: bool,
 ) -> Model:
@@ -97,6 +98,11 @@ def get_model(
     else:
         raise RuntimeError(f"Unknown dtype {dtype}")
 
+    if speculate is not None:
+        set_speculate(speculate)
+    else:
+        set_speculate(0)
+
     if "facebook/galactica" in model_id:
         return GalacticaSharded(
             model_id,
@@ -131,6 +137,33 @@ def get_model(
     config_dict, _ = PretrainedConfig.get_config_dict(
         model_id, revision=revision, trust_remote_code=trust_remote_code
     )
+    use_medusa = None
+    if "medusa_num_heads" in config_dict:
+        use_medusa = model_id
+        medusa_config = config_dict
+        model_id = config_dict["base_model_name_or_path"]
+        revision = "main"
+        speculate_medusa = config_dict["medusa_num_heads"]
+        if speculate is not None:
+            if speculate > speculate_medusa:
+                raise RuntimeError(
+                    "Speculate is set to `{speculate}` but this medusa models only has `{speculate_medusa}` heads, please make them match"
+                )
+            else:
+                set_speculate(speculate)
+        else:
+            set_speculate(speculate_medusa)
+
+        config_dict, _ = PretrainedConfig.get_config_dict(
+            model_id, revision=revision, trust_remote_code=trust_remote_code
+        )
+        method = "medusa"
+    else:
+        method = "n-gram"
+
+    speculate = get_speculate()
+    if speculate > 0:
+        logger.info(f"Using speculation {method} with {speculate} input ids.")
+
     model_type = config_dict["model_type"]
 
     if model_type == "gpt_bigcode":
@@ -206,6 +239,7 @@ def get_model(
             quantize=quantize,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
+            use_medusa=use_medusa
         )
     elif sharded:
         raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
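Illustrative sketch (not part of the commit): get_model above picks the speculation method ("medusa" when the config declares medusa_num_heads, otherwise "n-gram") and caps a user-requested speculate value at the number of Medusa heads. The toy helper below reproduces that decision with a plain dict; pick_speculation is a hypothetical name introduced here for illustration only.

from typing import Optional, Tuple


def pick_speculation(config_dict: dict, requested: Optional[int]) -> Tuple[str, int]:
    if "medusa_num_heads" in config_dict:
        heads = config_dict["medusa_num_heads"]
        if requested is not None and requested > heads:
            raise RuntimeError(
                f"Speculate is set to `{requested}` but this medusa model only has `{heads}` heads"
            )
        return "medusa", requested if requested is not None else heads
    # No Medusa heads in the config: fall back to n-gram speculation (0 disables it).
    return "n-gram", requested if requested is not None else 0


print(pick_speculation({"medusa_num_heads": 4}, None))  # ('medusa', 4)
print(pick_speculation({"model_type": "llama"}, 2))     # ('n-gram', 2)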
server/text_generation_server/models/causal_lm.py

@@ -10,10 +10,9 @@ from typing import Optional, Tuple, List, Type, Dict
 from text_generation_server.models import Model
 from text_generation_server.models.types import (
     Batch,
-    PrefillTokens,
+    Tokens,
     Generation,
     GeneratedText,
-    TopTokens,
 )
 from text_generation_server.pb import generate_pb2
 from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
@@ -676,8 +675,8 @@ class CausalLM(Model):
                     clean_up_tokenization_spaces=False,
                     skip_special_tokens=False,
                 )
-                prefill_tokens = PrefillTokens(
-                    prefill_token_ids, prefill_logprobs, prefill_texts
+                prefill_tokens = Tokens(
+                    prefill_token_ids, prefill_logprobs, prefill_texts, is_special=[]
                 )
             else:
                 prefill_tokens = None
@@ -691,7 +690,7 @@ class CausalLM(Model):
                 special_toptokens = [
                     token_id in self.all_special_ids for token_id in top_token_ids
                 ]
-                top_tokens = TopTokens(
+                top_tokens = Tokens(
                     top_token_ids,
                     top_token_logprobs,
                     toptoken_texts,
@@ -703,10 +702,12 @@ class CausalLM(Model):
             generation = Generation(
                 request.id,
                 prefill_tokens,
-                next_token_id_squeezed,
-                next_token_logprob,
-                next_token_text,
-                next_token_id_squeezed.item() in self.all_special_ids,
+                Tokens(
+                    [next_token_id_squeezed],
+                    [next_token_logprob],
+                    [next_token_text],
+                    [next_token_id_squeezed.item() in self.all_special_ids],
+                ),
                 generated_text,
                 top_tokens,
             )
server/text_generation_server/models/flash_causal_lm.py

@@ -11,13 +11,13 @@ from opentelemetry import trace
 from transformers import PreTrainedTokenizerBase
 from typing import Optional, Tuple, List, Type, Union, Dict
 
 from text_generation_server.models import Model
+from text_generation_server.utils.speculate import get_speculate
 from text_generation_server.models.types import (
     Batch,
-    PrefillTokens,
+    Tokens,
     Generation,
     GeneratedText,
-    TopTokens,
 )
 from text_generation_server.models.cache_manager import (
     get_cache_manager,
@@ -41,6 +41,7 @@ class FlashCausalLMBatch(Batch):
     # Decoder values
     input_ids: torch.Tensor
     position_ids: torch.Tensor
+    speculative_ids: torch.Tensor
 
     # Flash Attention values
@@ -120,6 +121,7 @@ class FlashCausalLMBatch(Batch):
         )["input_ids"]
 
         position_ids = []
+        speculative_ids = []
         cu_seqlen_prefill = [0]
         needed_blocks_slots = []
         start_slots = []
@@ -163,6 +165,8 @@ class FlashCausalLMBatch(Batch):
             input_length = len(tokenized_input)
             input_lengths.append(input_length)
 
             prefix_offsets.append(input_length - 5)
             read_offsets.append(input_length)
@@ -186,7 +190,8 @@ class FlashCausalLMBatch(Batch):
             # Paged attention
             # Remove one as the first token des not have a past
-            total_tokens = input_length + max_new_tokens - 1
+            speculative_length = get_speculate()
+            total_tokens = input_length + max_new_tokens - 1 + speculative_length
             needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
             blocks += needed_blocks
             needed_blocks_slots.append((needed_blocks, total_tokens))
@@ -224,7 +229,7 @@ class FlashCausalLMBatch(Batch):
             cumulative_max_length += total_tokens
             max_seqlen = max(max_seqlen, input_length)
             max_blocks = max(max_blocks, needed_blocks)
-            max_length = max(max_length, input_length + max_new_tokens)
+            max_length = max(max_length, input_length + max_new_tokens + speculative_length)
 
         next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
             next_token_chooser_parameters, dtype, device
@@ -255,7 +260,6 @@ class FlashCausalLMBatch(Batch):
         cu_seqlen_prefill = torch.tensor(
             cu_seqlen_prefill, device=device, dtype=torch.int32
         )
         position_ids = position_ids.to(device)
         slot_indices = slot_indices.to(device)
         input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
@@ -309,6 +313,7 @@ class FlashCausalLMBatch(Batch):
             top_n_tokens_tensor=top_n_tokens_tensor,
             blocks=blocks,
             max_blocks=max_blocks,
+            speculative_ids=None,
         )
 
     @tracer.start_as_current_span("filter")
@@ -419,6 +424,7 @@ class FlashCausalLMBatch(Batch):
         slots = self.slots[slot_filtering_indices]
         next_token_chooser = self.next_token_chooser.filter(indices)
         top_n_tokens_tensor = self.top_n_tokens_tensor[indices]
+        speculative_ids = self.speculative_ids[indices] if self.speculative_ids is not None else None
 
         start_slots = torch.tensor(start_slots, dtype=torch.int64)
@@ -454,6 +460,7 @@ class FlashCausalLMBatch(Batch):
             top_n_tokens_tensor=top_n_tokens_tensor,
             blocks=blocks,
             max_blocks=max_blocks,
+            speculative_ids=speculative_ids,
         )
 
     @classmethod
@@ -473,6 +480,7 @@ class FlashCausalLMBatch(Batch):
             total_batch_size += len(b)
             total_slots += len(b.slots)
             blocks += b.blocks
+            speculative_length = b.speculative_ids.shape[1] if b.speculative_ids is not None else 0
             max_blocks = max(max_blocks, b.max_blocks)
             max_seqlen = max(max_seqlen, b.max_seqlen)
             max_length = max(
@@ -480,6 +488,7 @@ class FlashCausalLMBatch(Batch):
                 max(
                     input_length
                     + stopping_criteria.max_new_tokens
+                    + speculative_length
                     - stopping_criteria.current_tokens
                     for input_length, stopping_criteria in zip(
                         b.input_lengths, b.stopping_criterias
@@ -577,6 +586,8 @@ class FlashCausalLMBatch(Batch):
             device=batches[0].next_token_chooser.device,
         )
 
+        speculative_ids = torch.cat([b.speculative_ids for b in batches], dim=0) if batches[0].speculative_ids is not None else None
+
         # Needed to avoid dropping blocks when the batches will go out of scope
         for b in batches:
             b.block_tables = None
@@ -611,6 +622,7 @@ class FlashCausalLMBatch(Batch):
             top_n_tokens_tensor=top_n_tokens_tensor,
             blocks=blocks,
             max_blocks=max_blocks,
+            speculative_ids=speculative_ids
         )
 
     def __del__(self):
@@ -714,16 +726,55 @@ class FlashCausalLM(Model):
     def forward(self, batch: FlashCausalLMBatch) -> Tuple[torch.Tensor, torch.Tensor]:
         # Model Forward
-        return self.model.forward(
-            input_ids=batch.input_ids,
-            position_ids=batch.position_ids,
-            cu_seqlen_prefill=batch.cu_seqlen_prefill,
-            kv_cache=get_cache_manager().kv_cache,
-            block_tables=batch.block_tables_tensor,
-            slots=batch.slots[batch.slot_indices],
-            input_lengths=batch.input_lengths_tensor,
-            max_s=batch.max_seqlen,
-            lm_head_indices=batch.prefill_head_indices,
-        )
+        if batch.speculative_ids is not None:
+            input_ids = batch.input_ids
+            position_ids = batch.position_ids
+            cu_seqlen_prefill = batch.cu_seqlen_prefill
+            kv_cache = get_cache_manager().kv_cache
+            block_tables = batch.block_tables_tensor
+            slots = batch.slots[batch.slot_indices]
+            input_lengths = batch.input_lengths_tensor
+            max_s = batch.max_seqlen
+            lm_head_indices = batch.prefill_head_indices
+
+            speculative_ids = batch.speculative_ids
+
+            B, speculative_length = speculative_ids.shape
+            new_length = speculative_length + 1
+            new_input_ids = torch.cat([input_ids.unsqueeze(-1), speculative_ids], dim=1).reshape(-1)
+            arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0)
+            arange_int = arange.to(dtype=torch.int32)
+            new_position_ids = (position_ids.unsqueeze(-1).expand(B, new_length) + arange).view(-1)
+            slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
+            input_lengths = (input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
+
+            # Add Copy the block tables for all members
+            block_tables = block_tables.unsqueeze(1).expand(B, new_length, -1).reshape(B * new_length, -1).contiguous()
+            max_s = max_s + speculative_length
+
+            input_ids = new_input_ids
+            position_ids = new_position_ids
+        else:
+            input_ids = batch.input_ids
+            position_ids = batch.position_ids
+            cu_seqlen_prefill = batch.cu_seqlen_prefill
+            kv_cache = get_cache_manager().kv_cache
+            block_tables = batch.block_tables_tensor
+            slots = batch.slots[batch.slot_indices]
+            input_lengths = batch.input_lengths_tensor
+            max_s = batch.max_seqlen
+            lm_head_indices = batch.prefill_head_indices
+
+        return self.model.forward(
+            input_ids=input_ids,
+            position_ids=position_ids,
+            cu_seqlen_prefill=cu_seqlen_prefill,
+            kv_cache=kv_cache,
+            block_tables=block_tables,
+            slots=slots,
+            input_lengths=input_lengths,
+            max_s=max_s,
+            lm_head_indices=lm_head_indices,
+        )
 
     @tracer.start_as_current_span("generate_token")
@@ -752,21 +803,32 @@ class FlashCausalLM(Model):
             del batch
             raise e
 
+        if isinstance(out, tuple):
+            out, speculative_logits = out
+        else:
+            speculative_logits = None
+
         if prefill:
             next_token_logits = (
                 out[batch.prefill_next_token_indices] if prefill_logprobs else out
             )
+            if speculative_logits is not None:
+                speculative_logits = (
+                    speculative_logits[batch.prefill_next_token_indices] if prefill_logprobs else speculative_logits
+                )
         else:
             next_token_logits = out
 
-        next_input_ids, next_token_logprobs, logprobs = batch.next_token_chooser(
-            batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits
+        next_input_ids, next_token_logprobs, logprobs, accepted_ids, speculative_ids = batch.next_token_chooser(
+            batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits, get_speculate(), batch.speculative_ids, speculative_logits
         )
 
         batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
            batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs
        )
 
+        speculative_length = 0 if speculative_ids is None else speculative_ids.shape[1]
         if prefill:
             if len(batch) > 1 and prefill_logprobs:
                 # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
@@ -792,6 +854,7 @@ class FlashCausalLM(Model):
         iterator = zip(
             batch.input_lengths,
             batch.all_input_ids,
+            accepted_ids
         )
 
         # We do two for loops as the first one can run completely asynchronously from the GPU while for the second
@@ -799,9 +862,11 @@ class FlashCausalLM(Model):
         # It is faster if we delay this sync for the maximum amount of time
 
         # For each member of the batch
+        index = 0
         for i, (input_length, all_input_ids, n_accepted_ids) in enumerate(iterator):
             # Indexing metadata
             start_index = cumulative_length
@@ -830,15 +895,18 @@ class FlashCausalLM(Model):
                     start_index + 1 : start_index + out_length
                 ]
 
-            batch.all_input_ids_tensor[i, input_length] = next_input_ids[i]
+            for j in range(n_accepted_ids):
+                batch.all_input_ids_tensor[i, input_length + j] = next_input_ids[index]
+                index += 1
 
             cumulative_length += input_length
 
         # Set values in batch
-        batch.input_ids = next_input_ids
-        batch.position_ids = next_position_ids + 1
-        batch.input_lengths_tensor += 1
-        batch.slot_indices += 1
+        batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1]
+        batch.speculative_ids = speculative_ids
+        batch.position_ids = next_position_ids + accepted_ids
+        batch.input_lengths_tensor += accepted_ids
+        batch.slot_indices += accepted_ids
 
         if prefill and prefill_logprobs:
             # Get prefill logprobs
@@ -851,7 +919,7 @@ class FlashCausalLM(Model):
             # GPU <-> CPU sync
             next_token_logprobs = next_token_logprobs.tolist()
-            next_token_ids = batch.input_ids.tolist()
+            next_token_ids = next_input_ids.tolist()
 
         # Zipped iterator
         iterator = zip(
@@ -864,13 +932,13 @@ class FlashCausalLM(Model):
             batch.next_token_chooser.do_sample,
             batch.next_token_chooser.seeds,
             batch.top_n_tokens,
             next_token_ids,
             next_token_logprobs,
+            accepted_ids,
             batch_top_token_ids,
             batch_top_token_logprobs,
         )
 
         # For each member of the batch
+        index = 0
         for i, (
             request,
             input_length,
@@ -881,29 +949,43 @@ class FlashCausalLM(Model):
             do_sample,
             seed,
             top_n_tokens,
-            next_token_id,
-            next_token_logprob,
+            n_accepted_ids,
             top_token_ids,
             top_token_logprobs,
         ) in enumerate(iterator):
-            # Append next token to all tokens
-            all_input_ids.append(next_token_id)
-
-            # Generated token
-            next_token_text, prefix_offset, read_offset = self.decode_token(
-                all_input_ids,
-                prefix_offset,
-                read_offset,
-            )
-
-            # Evaluate stopping criteria
-            stop, reason = stopping_criteria(
-                next_token_id,
-                next_token_text,
-            )
-
-            if not stop:
-                stopped = False
+            next_token_texts = []
+            left = 0
+            before = stopping_criteria.current_tokens
+
+            current_stopped = False
+            for j in range(index, index + n_accepted_ids):
+                # Generated token
+                next_token_id = next_token_ids[j]
+                all_input_ids.append(next_token_id)
+                next_token_text, prefix_offset, read_offset = self.decode_token(
+                    all_input_ids,
+                    prefix_offset,
+                    read_offset,
+                )
+                next_token_texts.append(next_token_text)
+
+                stop, reason = stopping_criteria(
+                    next_token_id,
+                    next_token_text,
+                )
+
+                if stop:
+                    left = index + n_accepted_ids - j - 1
+                    current_stopped = True
+                    break
+                else:
+                    current_stopped = False
+            stopped = stopped and current_stopped
+
+            _next_token_ids = next_token_ids[index : index + n_accepted_ids - left]
+            _next_token_logprobs = next_token_logprobs[index : index + n_accepted_ids - left]
+            index += n_accepted_ids
 
             # Shard generations
             # All generations will be appended in the rust sharded client
@@ -943,8 +1025,9 @@ class FlashCausalLM(Model):
                     clean_up_tokenization_spaces=False,
                     skip_special_tokens=False,
                 )
-                prefill_tokens = PrefillTokens(
-                    prefill_token_ids, request_prefill_logprobs, prefill_texts
+                prefill_tokens = Tokens(
+                    prefill_token_ids, request_prefill_logprobs, prefill_texts, is_special=[]
                 )
             else:
                 prefill_tokens = None
@@ -958,7 +1041,7 @@ class FlashCausalLM(Model):
                 special_toptokens = [
                     token_id in self.all_special_ids for token_id in top_token_ids
                 ]
-                top_tokens = TopTokens(
+                top_tokens = Tokens(
                     top_token_ids,
                     top_token_logprobs,
                     toptoken_texts,
@@ -970,10 +1053,12 @@ class FlashCausalLM(Model):
             generation = Generation(
                 request.id,
                 prefill_tokens,
-                next_token_id,
-                next_token_logprob,
-                next_token_text,
-                next_token_id in self.all_special_ids,
+                Tokens(
+                    _next_token_ids,
+                    _next_token_logprobs,
+                    next_token_texts,
+                    [nid in self.all_special_ids for nid in _next_token_ids],
+                ),
                 generated_text,
                 top_tokens,
             )
@@ -981,7 +1066,9 @@ class FlashCausalLM(Model):
             generations.append(generation)
 
             # Update values
-            batch.input_lengths[i] = input_length + 1
+            batch.input_lengths[i] = input_length + n_accepted_ids.item()
+            if batch.input_lengths[i] > batch.max_seqlen:
+                batch.max_seqlen = batch.input_lengths[i]
             batch.prefix_offsets[i] = prefix_offset
             batch.read_offsets[i] = read_offset
             batch.all_input_ids[i] = all_input_ids
@@ -994,6 +1081,5 @@ class FlashCausalLM(Model):
         batch.prefill_cu_outlens = None
         batch.prefill_head_indices = None
         batch.prefill_next_token_indices = None
-        batch.max_seqlen = batch.max_seqlen + 1
 
         return generations, batch
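Illustrative sketch (not part of the commit): the decode-path expansion in FlashCausalLM.forward above flattens each request's latest token plus its speculative ids into new_length consecutive query positions, shifting positions and slots with an arange. The toy below runs that same cat/expand trick on tiny tensors (values are made up).

import torch

input_ids = torch.tensor([11, 21])                  # last real token per request (B = 2)
speculative_ids = torch.tensor([[12, 13], [22, 23]])
position_ids = torch.tensor([5, 9])                 # current position per request

B, speculative_length = speculative_ids.shape
new_length = speculative_length + 1

new_input_ids = torch.cat([input_ids.unsqueeze(-1), speculative_ids], dim=1).reshape(-1)
arange = torch.arange(new_length).unsqueeze(0)
new_position_ids = (position_ids.unsqueeze(-1).expand(B, new_length) + arange).view(-1)

print(new_input_ids)     # tensor([11, 12, 13, 21, 22, 23])
print(new_position_ids)  # tensor([ 5,  6,  7,  9, 10, 11])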
server/text_generation_server/models/flash_llama.py

@@ -28,6 +28,7 @@ class FlashLlama(FlashCausalLM):
         quantize: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
+        use_medusa: Optional[str] = None,
     ):
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
@@ -66,6 +67,18 @@ class FlashLlama(FlashCausalLM):
             weights._set_gptq_params(model_id)
 
         model = FlashLlamaForCausalLM(config, weights)
+        if use_medusa:
+            from text_generation_server.utils.medusa import MedusaModel
+            from huggingface_hub import hf_hub_download
+            import json
+            medusa_config = hf_hub_download(use_medusa, revision=revision, filename="config.json")
+            with open(medusa_config, "r") as f:
+                config = json.load(f)
+            medusa_head = hf_hub_download(use_medusa, revision=revision, filename="medusa_lm_head.pt")
+            medusa_sf = medusa_head[: -len(".pt")] + ".safetensors"
+            weights = Weights([medusa_sf], device, dtype, process_group=self.process_group)
+            lm_head = model.lm_head
+            model.lm_head = MedusaModel(config, weights, lm_head)
+
         torch.distributed.barrier(group=self.process_group)
         super(FlashLlama, self).__init__(
server/text_generation_server/models/flash_mistral.py

@@ -21,6 +21,7 @@ from text_generation_server.models.custom_modeling.flash_mistral_modeling import
     FlashMistralForCausalLM,
     MistralConfig,
 )
+from text_generation_server.utils.speculate import get_speculate
 from text_generation_server.utils import (
     initialize_torch_distributed,
     weight_files,
@@ -132,7 +133,8 @@ class FlashMistralBatch(FlashCausalLMBatch):
             # Paged attention
             # Remove one as the first token des not have a past
-            total_tokens = input_length + max_new_tokens - 1
+            speculative_length = get_speculate()
+            total_tokens = input_length + max_new_tokens - 1 + speculative_length
 
             # Needed blocks can not go over SLIDING_WINDOW_BLOCKS
             needed_blocks = min(
@@ -183,7 +185,7 @@ class FlashMistralBatch(FlashCausalLMBatch):
             cumulative_max_length += total_tokens
             max_seqlen = max(max_seqlen, input_length)
             max_blocks = max(max_blocks, needed_blocks)
-            max_length = max(max_length, input_length + max_new_tokens)
+            max_length = max(max_length, input_length + max_new_tokens + speculative_length)
 
         next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
             next_token_chooser_parameters, dtype, device
@@ -272,6 +274,7 @@ class FlashMistralBatch(FlashCausalLMBatch):
             blocks=blocks,
             max_blocks=max_blocks,
             prefill_cache_indices=prefill_cache_indices,
+            speculative_ids=None
         )
@@ -340,17 +343,55 @@ class FlashMistral(FlashCausalLM):
     def forward(self, batch: FlashMistralBatch) -> Tuple[torch.Tensor, torch.Tensor]:
         # Model Forward
+        if batch.speculative_ids is not None:
+            input_ids = batch.input_ids
+            position_ids = batch.position_ids
+            cu_seqlen_prefill = batch.cu_seqlen_prefill
+            kv_cache = get_cache_manager().kv_cache
+            block_tables = batch.block_tables_tensor
+            slots = batch.slots[batch.slot_indices]
+            input_lengths = batch.input_lengths_tensor
+            max_s = batch.max_seqlen
+            lm_head_indices = batch.prefill_head_indices
+
+            speculative_ids = batch.speculative_ids
+
+            B, speculative_length = speculative_ids.shape
+            new_length = speculative_length + 1
+            new_input_ids = torch.cat([input_ids.unsqueeze(-1), speculative_ids], dim=1).reshape(-1)
+            arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0)
+            arange_int = arange.to(dtype=torch.int32)
+            new_position_ids = (position_ids.unsqueeze(-1).expand(B, new_length) + arange).view(-1)
+            slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
+            input_lengths = (input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
+
+            # Add Copy the block tables for all members
+            block_tables = block_tables.unsqueeze(1).expand(B, new_length, -1).reshape(B * new_length, -1).contiguous()
+            max_s = max_s + speculative_length
+
+            input_ids = new_input_ids
+            position_ids = new_position_ids
+        else:
+            input_ids = batch.input_ids
+            position_ids = batch.position_ids
+            cu_seqlen_prefill = batch.cu_seqlen_prefill
+            kv_cache = get_cache_manager().kv_cache
+            block_tables = batch.block_tables_tensor
+            slots = batch.slots[batch.slot_indices]
+            input_lengths = batch.input_lengths_tensor
+            max_s = batch.max_seqlen
+            lm_head_indices = batch.prefill_head_indices
+
         logits = self.model.forward(
-            input_ids=batch.input_ids,
-            position_ids=batch.position_ids,
-            cu_seqlen_prefill=batch.cu_seqlen_prefill,
-            kv_cache=get_cache_manager().kv_cache,
-            block_tables=batch.block_tables_tensor,
-            slots=batch.slots[batch.slot_indices],
-            input_lengths=batch.input_lengths_tensor,
-            max_s=batch.max_seqlen,
+            input_ids=input_ids,
+            position_ids=position_ids,
+            cu_seqlen_prefill=cu_seqlen_prefill,
+            kv_cache=kv_cache,
+            block_tables=block_tables,
+            slots=slots,
+            input_lengths=input_lengths,
+            max_s=max_s,
             prefill_cache_indices=batch.prefill_cache_indices,
-            lm_head_indices=batch.prefill_head_indices,
+            lm_head_indices=lm_head_indices,
         )
         if batch.prefill_cache_indices is not None:
             batch.prefill_cache_indices = None
server/text_generation_server/models/idefics_causal_lm.py

@@ -20,7 +20,7 @@ from typing import Optional, Tuple, List, Type, Dict
 from text_generation_server.models import Model
 from text_generation_server.models.types import (
     Batch,
-    PrefillTokens,
+    Tokens,
     Generation,
     GeneratedText,
 )
@@ -791,8 +791,8 @@ class IdeficsCausalLM(Model):
                     clean_up_tokenization_spaces=False,
                     skip_special_tokens=False,
                 )
-                prefill_tokens = PrefillTokens(
-                    prefill_token_ids, prefill_logprobs, prefill_texts
+                prefill_tokens = Tokens(
+                    prefill_token_ids, prefill_logprobs, prefill_texts, is_special=[]
                 )
             else:
                 prefill_tokens = None
@@ -802,10 +802,12 @@ class IdeficsCausalLM(Model):
             generation = Generation(
                 request.id,
                 prefill_tokens,
-                next_token_id_squeezed,
-                next_token_logprob,
-                next_token_text,
-                next_token_id_squeezed.item() in self.all_special_ids,
+                Tokens(
+                    [next_token_id_squeezed],
+                    [next_token_logprob],
+                    [next_token_text],
+                    [next_token_id_squeezed.item() in self.all_special_ids],
+                ),
                 generated_text,
                 top_tokens,
             )
server/text_generation_server/models/model.py

@@ -6,6 +6,7 @@ from typing import List, Tuple, Optional, TypeVar, Type
 from transformers import PreTrainedTokenizerBase, PretrainedConfig
 
 from text_generation_server.models.types import Batch, Generation
+from text_generation_server.utils.speculate import get_speculate
 from text_generation_server.pb.generate_pb2 import InfoResponse
 
 B = TypeVar("B", bound=Batch)
@@ -22,6 +23,7 @@ class Model(ABC):
         rank: int = 0,
         world_size: int = 1,
         sliding_window: Optional[int] = None,
+        speculate: Optional[int] = None,
     ):
         self.model = model.eval()
         self.tokenizer = tokenizer
@@ -33,6 +35,10 @@ class Model(ABC):
         self.world_size = world_size
         self.sliding_window = sliding_window
 
+        if speculate is None:
+            speculate = get_speculate()
+        self.speculate = speculate
+
         self.has_position_ids = (
             inspect.signature(model.forward).parameters.get("position_ids", None)
             is not None
@@ -50,6 +56,7 @@ class Model(ABC):
             dtype=str(self.dtype),
             device_type=self.device.type,
             window_size=self.sliding_window,
+            speculate=self.speculate
         )
 
     @property
server/text_generation_server/models/seq2seq_lm.py

@@ -11,8 +11,7 @@ from text_generation_server.models.types import (
     GeneratedText,
     Batch,
     Generation,
-    PrefillTokens,
-    TopTokens,
+    Tokens,
 )
 from text_generation_server.pb import generate_pb2
 from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
@@ -733,10 +732,11 @@ class Seq2SeqLM(Model):
             # Prefill
             if stopping_criteria.current_tokens == 1 and request.prefill_logprobs:
-                prefill_tokens = PrefillTokens(
+                prefill_tokens = Tokens(
                     [self.tokenizer.bos_token_id],
                     [float("nan")],
                     [self.tokenizer.bos_token],
+                    [False],
                 )
             else:
                 prefill_tokens = None
@@ -750,7 +750,7 @@ class Seq2SeqLM(Model):
                 special_toptokens = [
                     token_id in self.all_special_ids for token_id in top_token_ids
                 ]
-                top_tokens = TopTokens(
+                top_tokens = Tokens(
                     top_token_ids,
                     top_token_logprobs,
                     toptoken_texts,
@@ -762,10 +762,12 @@ class Seq2SeqLM(Model):
             generation = Generation(
                 request.id,
                 prefill_tokens,
-                next_token_id_squeezed,
-                next_token_logprob,
-                next_token_text,
-                next_token_id_squeezed.item() in self.all_special_ids,
+                Tokens(
+                    [next_token_id_squeezed],
+                    [next_token_logprob],
+                    [next_token_text],
+                    [next_token_id_squeezed.item() in self.all_special_ids],
+                ),
                 generated_text,
                 top_tokens,
             )
server/text_generation_server/models/types.py

@@ -58,33 +58,15 @@ class GeneratedText:
 @dataclass
-class PrefillTokens:
-    token_ids: List[int]
-    logprobs: List[float]
-    texts: List[str]
-
-    def to_pb(self) -> generate_pb2.PrefillTokens:
-        return generate_pb2.PrefillTokens(
-            ids=self.token_ids, logprobs=self.logprobs, texts=self.texts
-        )
-
-    def __len__(self):
-        return len(self.token_ids)
-
-
-@dataclass
-class TopTokens:
+class Tokens:
     token_ids: List[int]
     logprobs: List[float]
     texts: List[str]
     is_special: List[bool]
 
-    def to_pb(self) -> generate_pb2.TopTokens:
-        return generate_pb2.TopTokens(
-            ids=self.token_ids,
-            logprobs=self.logprobs,
-            texts=self.texts,
-            is_special=self.is_special,
+    def to_pb(self) -> generate_pb2.Tokens:
+        return generate_pb2.Tokens(
+            ids=self.token_ids,
+            logprobs=self.logprobs,
+            texts=self.texts,
+            is_special=self.is_special,
         )
 
     def __len__(self):
@@ -94,14 +76,11 @@ class TopTokens:
 @dataclass
 class Generation:
     request_id: int
-    prefill_tokens: Optional[PrefillTokens]
-    token_id: int
-    token_logprob: float
-    token_text: str
-    token_is_special: bool
+    prefill_tokens: Optional[Tokens]
+    tokens: Tokens
     generated_text: Optional[GeneratedText]
     # Optional for now, since it's not yet supported for every model.
-    top_tokens: Optional[TopTokens]
+    top_tokens: Optional[List[Tokens]]
 
     def to_pb(self) -> generate_pb2.Generation:
         return generate_pb2.Generation(
@@ -109,10 +88,7 @@ class Generation:
             prefill_tokens=self.prefill_tokens.to_pb()
             if self.prefill_tokens is not None
             else None,
-            token_id=self.token_id,
-            token_logprob=self.token_logprob,
-            token_text=self.token_text,
-            token_is_special=self.token_is_special,
+            tokens=self.tokens.to_pb(),
             generated_text=self.generated_text.to_pb()
             if self.generated_text is not None
             else None,
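Illustrative sketch (not part of the commit): the refactor above collapses PrefillTokens/TopTokens into a single Tokens dataclass, and Generation now carries a Tokens group instead of one scalar token, which is what lets a single decoding step report several accepted speculative tokens. The snippet re-declares a local copy of the dataclass just to show the shape; it is not imported from the repository.

from dataclasses import dataclass
from typing import List


@dataclass
class Tokens:
    token_ids: List[int]
    logprobs: List[float]
    texts: List[str]
    is_special: List[bool]

    def __len__(self):
        return len(self.token_ids)


# Two tokens accepted in a single speculative step (ids/logprobs are made up):
step = Tokens([464, 3797], [-0.12, -1.34], ["The", " cat"], [False, False])
print(len(step))  # 2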
server/text_generation_server/server.py

@@ -132,6 +132,7 @@ def serve(
     revision: Optional[str],
     sharded: bool,
     quantize: Optional[str],
+    speculate: Optional[int],
     dtype: Optional[str],
     trust_remote_code: bool,
     uds_path: Path,
@@ -141,6 +142,7 @@ def serve(
         revision: Optional[str],
         sharded: bool = False,
         quantize: Optional[str] = None,
+        speculate: Optional[int] = None,
        dtype: Optional[str] = None,
        trust_remote_code: bool = False,
    ):
@@ -157,7 +159,7 @@ def serve(
         try:
             model = get_model(
-                model_id, revision, sharded, quantize, dtype, trust_remote_code
+                model_id, revision, sharded, quantize, speculate, dtype, trust_remote_code
             )
         except Exception:
             logger.exception("Error when initializing model")
@@ -205,5 +207,5 @@ def serve(
         await server.stop(0)
 
     asyncio.run(
-        serve_inner(model_id, revision, sharded, quantize, dtype, trust_remote_code)
+        serve_inner(model_id, revision, sharded, quantize, speculate, dtype, trust_remote_code)
     )
server/text_generation_server/utils/medusa.py (new file, 0 → 100644)

import torch
from dataclasses import dataclass
from text_generation_server.utils.layers import TensorParallelHead, FastLinear


@dataclass
class Output:
    logits: torch.FloatTensor = None
    speculative_logits: torch.FloatTensor = None


class ResBlock(torch.nn.Module):
    def __init__(self, config, prefix, weights):
        super().__init__()
        self.linear = FastLinear.load(
            config, prefix=f"{prefix}.linear", weights=weights, bias=True
        )
        self.act = torch.nn.SiLU()

    def forward(self, x):
        return x + self.act(self.linear(x))


class MedusaModel(torch.nn.Module):
    def __init__(self, config, weights, lm_head):
        super().__init__()
        self.heads = torch.nn.ModuleList(
            [MedusaHead(config, prefix=f"{i}", weights=weights) for i in range(config["medusa_num_heads"])]
        )
        self.lm_head = lm_head

    def forward(self, x):
        logits = self.lm_head(x)
        speculative_logits = torch.stack([head(x) for head in self.heads], dim=1)
        return logits, speculative_logits


class MedusaHead(torch.nn.Module):
    def __init__(self, config, prefix, weights):
        super().__init__()
        self.blocks = torch.nn.ModuleList(
            [ResBlock(config, prefix=f"{prefix}.{i}", weights=weights) for i in range(config["medusa_num_layers"])]
        )
        n = len(self.blocks)
        self.out = FastLinear.load(config, prefix=f"{prefix}.{n}", weights=weights, bias=False)

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        x = self.out(x)
        return x
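Illustrative sketch (not part of the commit): MedusaModel above wraps the existing lm_head and returns a (logits, speculative_logits) pair, with one extra vocab distribution per Medusa head. The toy wrapper below uses plain nn.Linear layers instead of the repo's FastLinear/weights machinery, only to show the wrapping pattern and the resulting shapes.

import torch
import torch.nn as nn


class ToyMedusaWrapper(nn.Module):
    def __init__(self, lm_head: nn.Linear, num_heads: int):
        super().__init__()
        hidden, vocab = lm_head.in_features, lm_head.out_features
        self.lm_head = lm_head
        self.heads = nn.ModuleList(nn.Linear(hidden, vocab) for _ in range(num_heads))

    def forward(self, x):
        # Same contract as MedusaModel.forward: (logits, speculative_logits).
        return self.lm_head(x), torch.stack([h(x) for h in self.heads], dim=1)


model_lm_head = nn.Linear(16, 32)          # stands in for model.lm_head
wrapped = ToyMedusaWrapper(model_lm_head, num_heads=4)
logits, speculative_logits = wrapped(torch.randn(3, 16))
print(logits.shape, speculative_logits.shape)  # torch.Size([3, 32]) torch.Size([3, 4, 32])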
server/text_generation_server/utils/speculate.py
0 → 100644
View file @
9ecfa16b
SPECULATE
=
None
def
get_speculate
()
->
int
:
global
SPECULATE
return
SPECULATE
def
set_speculate
(
speculate
:
int
):
global
SPECULATE
SPECULATE
=
speculate
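Illustrative sketch (not part of the commit): speculate.py is just a process-wide setting, written once in get_model and read wherever batches are built. The numbers below are made up; the total_tokens formula mirrors the flash_causal_lm.py change above.

from text_generation_server.utils.speculate import get_speculate, set_speculate

set_speculate(2)          # done once in get_model (0 disables speculation)
assert get_speculate() == 2

# Batch-building code then reserves extra KV-cache room per request:
max_new_tokens = 20
input_length = 7
total_tokens = input_length + max_new_tokens - 1 + get_speculate()
print(total_tokens)  # 28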
server/text_generation_server/utils/tokens.py

@@ -16,7 +16,6 @@ from text_generation_server.utils.logits_process import (
 from text_generation_server.utils.watermark import WatermarkLogitsProcessor
 from transformers import PreTrainedTokenizerBase, RepetitionPenaltyLogitsProcessor
 
 
 class NextTokenChooser:
     def __init__(
         self,
@@ -146,6 +145,20 @@ class StoppingCriteria:
             pb.ignore_eos_token,
         )
 
 
+def create_n_gram_speculation(
+    input_ids: torch.Tensor,
+    next_ids: torch.Tensor,
+    accepted_ids: torch.Tensor,
+    speculate: int,
+    verbose: bool,
+):
+    # Very trivial approach, find first match in the string.
+    # This is much less refined than actual n-gram but seems to work
+    # relatively OK in grounded mode and is by far much faster with
+    # much less worst case complexity as everything happens on device.
+    B = accepted_ids.shape[0]
+    device = input_ids.device
+    seeds = next_ids[accepted_ids.cumsum(dim=-1) - 1]
+    indices = (input_ids == seeds.unsqueeze(-1)).max(dim=1).indices + 1
+    all_indices = indices.unsqueeze(-1).expand(B, speculate) + torch.arange(speculate, device=device)
+    all_indices = torch.clamp(all_indices, max=input_ids.shape[1] - 1)
+
+    speculative_ids = input_ids.gather(dim=-1, index=all_indices)
+    return speculative_ids
+
+
 class HeterogeneousNextTokenChooser:
     def __init__(
@@ -215,20 +228,79 @@ class HeterogeneousNextTokenChooser:
         self.dtype = dtype
         self.device = device
 
-    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor):
-        if self.watermark_processor is not None:
-            scores = self.watermark_processor(input_ids, scores)
-        if self.repetition_processor is not None:
-            scores = self.repetition_processor(input_ids, scores)
-
-        for warper in self.warpers:
-            scores = warper(input_ids, scores)
+    def __call__(
+        self,
+        input_ids: torch.Tensor,
+        scores: torch.Tensor,
+        speculate: int,
+        speculated_ids: Optional[torch.Tensor] = None,
+        speculative_scores: Optional[torch.Tensor] = None,
+        verbose=False,
+    ):
+        if speculated_ids is not None:
+            B = scores.shape[0] // (speculated_ids.shape[1] + 1)
+            S = speculated_ids.shape[1] + 1
+            scores = scores.view(B, S, -1)
+        else:
+            B = scores.shape[0]
+            S = 1
+            scores = scores.view(B, S, -1)
+
+        next_ids = torch.zeros((B, S), device=scores.device, dtype=torch.long)
+        for j in range(S):
+            _scores = scores[:, j]
+            if self.watermark_processor is not None:
+                _scores = self.watermark_processor(input_ids, _scores)
+            if self.repetition_processor is not None:
+                _scores = self.repetition_processor(input_ids, _scores)
+            for warper in self.warpers:
+                _scores = warper(input_ids, _scores)
+            _next_ids = self.choice(_scores)
+            scores[:, j] = _scores
+            next_ids[:, j] = _next_ids
+        next_ids = next_ids.view(B * S)
+        scores = scores.view(B * S, -1)
+
+        if speculated_ids is not None:
+            accepted_ids = []
+            B = next_ids.shape[0] // (speculated_ids.shape[1] + 1)
+            S = speculated_ids.shape[1] + 1
+            indices = []
+            for i in range(B):
+                _next_ids = next_ids[i * S : (i + 1) * S]
+                _speculated_ids = speculated_ids[i]
+                validate_speculative = _next_ids[:-1] == _speculated_ids
+                index = i * S
+                accepted = 1
+                # First is always valid
+                indices.append(index)
+                for valid in validate_speculative.tolist():
+                    if valid:
+                        index += 1
+                        accepted += 1
+                        indices.append(index)
+                    else:
+                        break
+                accepted_ids.append(accepted)
+            accepted_ids = torch.tensor(accepted_ids, device=input_ids.device, dtype=input_ids.dtype)
+            next_ids = next_ids[indices]
+            scores = scores[indices]
+            indices = torch.arange(B, device=input_ids.device) * S
+            if speculative_scores is not None:
+                speculative_scores = speculative_scores[indices + accepted_ids - 1]
+        else:
+            accepted_ids = torch.ones_like(next_ids)
 
-        next_ids = self.choice(scores)
         logprobs = torch.log_softmax(scores, -1)
         next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)
 
-        return next_ids, next_logprobs, logprobs
+        if speculate > 0:
+            if speculative_scores is not None:
+                # Medusa provided some scores
+                speculative_ids = Greedy()(speculative_scores)
+            else:
+                # n-gram
+                speculative_ids = create_n_gram_speculation(
+                    input_ids, next_ids, accepted_ids, speculate, verbose
+                )
+        else:
+            speculative_ids = None
+
+        return next_ids, next_logprobs, logprobs, accepted_ids, speculative_ids
 
     def filter(self, indices):
         if self.watermark_processor is not None:
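Illustrative sketch (not part of the commit): the two pieces of the new __call__ above on toy data, with made-up ids. The first part shows the acceptance rule (the first freshly-sampled token is always kept, then speculated ids are accepted left-to-right while they match what the model itself sampled); the second part shows the n-gram proposal gather used by create_n_gram_speculation.

import torch

# 1) Acceptance rule on one request with S - 1 = 3 speculated ids and S = 4 verification samples.
speculated = torch.tensor([5, 9, 7])   # ids proposed at the previous step
sampled = torch.tensor([5, 9, 8, 4])   # ids sampled from the verification logits
accepted = 1                           # first sampled token is always valid
for valid in (sampled[:-1] == speculated).tolist():
    if valid:
        accepted += 1
    else:
        break
print(accepted)  # 3 -> tokens 5 and 9 are accepted, plus the corrected token 8

# 2) n-gram proposal: find where the latest accepted id first appears in the sequence and copy
#    the `speculate` ids that follow it (the same gather as create_n_gram_speculation).
input_ids = torch.tensor([[3, 8, 1, 2, 8, 1]])   # B = 1 sequence
seed = torch.tensor([8])                         # latest accepted id
speculate = 2
start = (input_ids == seed.unsqueeze(-1)).max(dim=1).indices + 1
all_indices = start.unsqueeze(-1).expand(1, speculate) + torch.arange(speculate)
all_indices = torch.clamp(all_indices, max=input_ids.shape[1] - 1)
print(input_ids.gather(dim=-1, index=all_indices))  # tensor([[1, 2]])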