gaoqiong / flash-attention

Commit 9f42cb6e, authored Aug 27, 2023 by Tri Dao
[Gen] Clone logits before returning when cg=True
Parent: f8aea6ea

Showing 2 changed files, with 431 additions and 4 deletions:
  flash_attn/utils/generation.py  +361 / -2
  tests/models/test_gpt.py        +70 / -2
flash_attn/utils/generation.py
@@ -4,10 +4,12 @@ import gc
import time
from collections import namedtuple
from dataclasses import dataclass, field
+from functools import partial
from typing import Callable, Optional, Sequence, Union

import torch
+import torch.nn.functional as F
-from einops import rearrange
+from einops import rearrange, repeat
from torch import Tensor
from torch.profiler import ProfilerActivity, profile, record_function
from transformers.generation import GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput

@@ -205,6 +207,363 @@ def decode(
    )


def sample_speculative(logits, logits_draft, tokens_draft, top_k=1, top_p=0.0, temperature=1.0):
    """Algorithm 1 from [1]
    [1] Fast Inference from Transformers via Speculative Decoding
    Yaniv Leviathan, Matan Kalman, Yossi Matias
    https://arxiv.org/abs/2211.17192
    Arguments:
        logits: Tensor of shape (batch_size, seqlen + 1, vocab_size)
        logits_draft: Tensor of shape (batch_size, seqlen, vocab_size)
        tokens_draft: Tensor of shape (batch_size, seqlen)
    Return:
        tokens: Tensor of shape (batch_size, seqlen + 1)
        num_generated_tokens: Tensor of shape (batch_size), with value in [1, seqlen + 1].
            For each sequence in the batch, the number of valid tokens that were sampled by
            speculative sampling.
    """
    batch, seqlen_p_1, vocab_size = logits.shape
    seqlen = seqlen_p_1 - 1
    assert logits_draft.shape == (batch, seqlen, vocab_size)
    assert tokens_draft.shape == (batch, seqlen)
    assert tokens_draft.dtype in [torch.int64, torch.int32]
    # TODO: if top_k = 1 we can simplify things and only work with indices
    if top_p > 0.0:
        assert top_p <= 1.0, "top-p should be in (0, 1]."
    # Clone so that when we modify for top_p we don't change the original logits
    logits = logits / temperature if temperature != 1.0 else logits.clone()
    logits_draft = logits_draft / temperature if temperature != 1.0 else logits_draft.clone()
    if top_k > 0:
        top_k = min(top_k, logits.size(-1))  # Safety check
        modify_logits_for_top_k_filtering(logits, top_k)
        modify_logits_for_top_k_filtering(logits_draft, top_k)
    modify_logits_for_top_p_filtering(logits, top_p)
    modify_logits_for_top_p_filtering(logits_draft, top_p)
    probs = torch.softmax(logits, dim=-1)
    probs_draft = torch.softmax(logits_draft, dim=-1)
    gather = lambda probs, tokens: rearrange(
        probs.gather(dim=-1, index=rearrange(tokens, "... -> ... 1")), "... 1 -> ..."
    )
    # (batch, seqlen)
    accepted = torch.rand(batch, seqlen, device=probs.device) * gather(
        probs_draft, tokens_draft
    ) <= gather(probs[:, :-1], tokens_draft)
    accepted_all = accepted.all(dim=-1)
    # (batch,)
    first_rejected_idx = torch.where(accepted_all, seqlen, accepted.int().argmin(dim=-1))
    probs_diff = torch.clamp(probs[:, :-1] - probs_draft, min=0.0)
    # torch.multinomial can deal with unnormalized probabilities
    # probs_diff /= probs_diff.sum(dim=-1, keepdim=True)
    resample_probs = torch.cat([probs_diff, probs[:, -1:]], dim=1)
    resample_probs = rearrange(
        resample_probs.gather(dim=1, index=repeat(first_rejected_idx, "b -> b 1 d", d=vocab_size)),
        "b 1 d -> b d",
    )
    resample = torch.multinomial(resample_probs, num_samples=1).squeeze(dim=-1)  # (batch,)
    tokens = F.pad(tokens_draft, (0, 1))
    tokens[:, first_rejected_idx] = resample
    return tokens, first_rejected_idx + 1
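
For orientation, the accept/resample rule that sample_speculative implements can be illustrated with a toy, self-contained sketch. The snippet below is not part of this commit; it re-derives the single-token accept/reject step of Algorithm 1 over a 3-token vocabulary using only plain PyTorch, and the names p_main and q_draft are illustrative.

# Illustrative sketch of the accept/resample rule used by sample_speculative above
# (single draft token, batch of 1, vocabulary of 3). Not part of flash-attention.
import torch

torch.manual_seed(0)
p_main = torch.tensor([0.5, 0.3, 0.2])   # main model's distribution at this position
q_draft = torch.tensor([0.2, 0.7, 0.1])  # draft model's distribution at the same position

token_draft = torch.multinomial(q_draft, num_samples=1).item()  # draft proposes a token
# Accept with probability min(1, p(x) / q(x)), implemented as u * q(x) <= p(x),
# mirroring the `accepted` computation above.
u = torch.rand(())
accepted = u * q_draft[token_draft] <= p_main[token_draft]
if accepted:
    token = token_draft
else:
    # On rejection, resample from the residual distribution max(p - q, 0);
    # torch.multinomial accepts unnormalized weights.
    residual = torch.clamp(p_main - q_draft, min=0.0)
    token = torch.multinomial(residual, num_samples=1).item()
print(f"draft proposed {token_draft}, accepted={bool(accepted)}, emitted {token}")

Accepting with probability min(1, p/q) and resampling rejections from max(p - q, 0) is what makes the emitted tokens exactly distributed according to the main model, which is the guarantee the function above relies on.
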

def decode_speculative(
    input_ids,
    model,
    model_draft,
    max_length,
    speculative_lookahead=3,
    top_k=1,
    top_p=0.0,
    temperature=1.0,
    eos_token_id=None,
    vocab_size=None,
    tensor_parallel=1,
    fused_ft_kernel=False,
    cg=False,
    timing=False,
    debug=False,
):
    """
    TD: WIP, for my own understanding, lightly tested. Only supports batch_size == 1 for now.

    Speculative decoding, either greedy or with top-k or top-p sampling.
    If top-k = 0, don't limit the number of candidates (pure sampling).
    Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k is applied first,
    then top-p.
    We assume that all sequences in the same batch have the same length.

    Arguments:
        input_ids: (batch, seq_len)
        max_length: int
    Returns: GreedySearchDecoderOnlyOutput or SampleDecoderOnlyOutput, with the following fields:
        sequences: (batch, max_length)
        scores: tuples of (batch, vocab_size)
    """
    batch_size, seqlen_og = input_ids.shape
    assert batch_size == 1, "Speculative decoding implementation only supports batch_size=1"
    assert eos_token_id is None, "Speculative decoding implementation doesn't support eos_token_id"
    if cg:
        assert fused_ft_kernel
        if not hasattr(model_draft, "_decoding_cache"):
            model_draft._decoding_cache = None
        model_draft._decoding_cache = update_graph_cache(
            model_draft,
            model_draft._decoding_cache,
            batch_size,
            seqlen_og,
            max_length,
            tensor_parallel=tensor_parallel,
        )
        inference_params_draft = model_draft._decoding_cache.inference_params
        inference_params_draft.max_sequence_len = max_length
        inference_params_draft.max_batch_size = batch_size
        inference_params_draft.sequence_len_offset = 0
        # fused_ft_kernel doesn't support passing in multiple tokens at once
        inference_params = InferenceParams(
            max_sequence_len=max_length, max_batch_size=batch_size, fused_ft_kernel=False
        )
    else:
        inference_params_draft = InferenceParams(
            max_sequence_len=max_length, max_batch_size=batch_size, fused_ft_kernel=fused_ft_kernel
        )
        inference_params = InferenceParams(
            max_sequence_len=max_length, max_batch_size=batch_size, fused_ft_kernel=False
        )

    def logits_forward_fn(model, input_ids, position_ids, inference_params, cg=False):
        if not cg:
            return model(
                input_ids,
                position_ids=position_ids,
                inference_params=inference_params,
                num_last_tokens=1,
            ).logits.squeeze(dim=1)
        else:
            return model._decoding_cache.run(
                input_ids, position_ids, inference_params.sequence_len_offset
            ).clone()

    logits_postprocess_fn = (
        lambda logits: logits[..., :vocab_size] if vocab_size is not None else logits
    )

    def sample_tokens(
        input_ids,
        model,
        inference_params,
        sample_fn,
        num_tokens=1,
        cg=False,
        decoding=True,
        last_token_logits=False,
    ):
        """Sample `num_tokens` tokens from the model, given the previous logits.
        Also return the logits of the sampled tokens.
        Arguments:
            input_ids: (batch, seqlen)
            decoding: whether we're in the decoding phase or the prefilling phase. Prefill doesn't
                need special position_ids.
            last_token_logits: whether to return the logits of the last token. Normally we don't
                need this. However, for speculative sampling, if the main model accepts all the
                draft tokens and then samples one new token, the draft model would, at the next
                iteration, need to evaluate both the logits of the last draft token and the logits
                of the newly sampled token. That makes the implementation more complicated, so here
                we just evaluate the logits of the last token in the draft model to simplify things.
        Return:
            tokens: (batch, num_tokens)
            scores: (batch, num_tokens), which contains @previous_logits and the logits of the next
                (num_tokens - 1) tokens. The logits of the last token aren't computed unless
                last_token_logits=True, in which case scores has shape (batch, num_tokens + 1).
        """
        batch_size, seqlen = input_ids.shape
        assert num_tokens >= 1
        sequences = []
        if decoding:
            assert seqlen == 1
            position_ids = torch.full(
                (batch_size, 1),
                inference_params.sequence_len_offset,
                dtype=torch.long,
                device=input_ids.device,
            )
        else:
            position_ids = None
        logits = logits_postprocess_fn(
            logits_forward_fn(model, input_ids, position_ids, inference_params, cg=decoding and cg)
        )
        inference_params.sequence_len_offset += input_ids.shape[1]
        scores = [logits]
        next_token = sample_fn(logits)
        sequences.append(next_token)
        for i in range(num_tokens):
            if i < num_tokens - 1 or last_token_logits:
                position_ids = torch.full(
                    (batch_size, 1),
                    inference_params_draft.sequence_len_offset,
                    dtype=torch.long,
                    device=input_ids.device,
                )
                logits = logits_postprocess_fn(
                    logits_forward_fn(
                        model, rearrange(next_token, "b -> b 1"), position_ids, inference_params, cg=cg
                    )
                )
                inference_params.sequence_len_offset += 1
                scores.append(logits)
                if i < num_tokens - 1:
                    next_token = sample_fn(logits)
                    sequences.append(next_token)
        return torch.stack(sequences, dim=1), torch.stack(scores, dim=1)

    sampling_kwargs = dict(top_k=top_k, top_p=top_p, temperature=temperature)
    sample_fn = partial(sample, **sampling_kwargs)
    sample_tokens_main = partial(
        sample_tokens, model=model, sample_fn=sample_fn, inference_params=inference_params, cg=False
    )  # main model doesn't use CUDA graph
    sample_tokens_draft = partial(
        sample_tokens,
        model=model_draft,
        sample_fn=sample_fn,
        last_token_logits=True,
        inference_params=inference_params_draft,
        cg=cg,
    )

    if debug:
        from transformers import AutoTokenizer

        tokenizer = AutoTokenizer.from_pretrained("gpt2")

    sequences = [input_ids]
    scores = []
    with torch.inference_mode():
        if timing:
            if tensor_parallel > 1:
                torch.distributed.barrier()
            torch.cuda.synchronize()
            start = time.time()
        if seqlen_og >= max_length - 1:
            # Don't do speculative sampling, just sample 1 token from the model
            tokens, scores_new = sample_tokens_main(input_ids, num_tokens=1, decoding=False)
            sequences.append(tokens)
            scores.append(scores_new)
        else:
            # Sample from the draft model, which produces @n_spec_tokens, which @model
            # will then use to produce between 1 and 1 + @n_spec_tokens tokens.
            # We want seqlen_og + 1 + @n_spec_tokens to be <= @max_length.
            n_spec_tokens = min(speculative_lookahead, max_length - seqlen_og - 1)
            tokens_draft, scores_draft = sample_tokens_draft(
                input_ids,
                num_tokens=n_spec_tokens,
                decoding=False,
            )
            if debug:
                scores_draft_ref = model_draft(
                    torch.cat([input_ids, tokens_draft], dim=1), num_last_tokens=n_spec_tokens + 1
                ).logits
                print((scores_draft[:, :-1] - scores_draft_ref[:, :-1]).abs().max())
            # Evaluate the draft tokens with the model
            logits = model(
                torch.cat([input_ids, tokens_draft], dim=1),
                inference_params=inference_params,
                num_last_tokens=n_spec_tokens + 1,
            ).logits
            logits = logits_postprocess_fn(logits)
            tokens, num_generated_tokens = sample_speculative(
                logits, scores_draft[:, :-1], tokens_draft, **sampling_kwargs
            )
            if debug:
                print(tokens)
                print(num_generated_tokens)
                # breakpoint()
            # TODO: we're using the fact that batch_size == 1
            # TODO: check eos_token_id
            sequences.append(tokens[:1, : num_generated_tokens[0]])
            scores.append(logits[:1, : num_generated_tokens[0]])
            # Note that @model has not evaluated the last sampled token yet, so we'll need to pass
            # that in the next time we call @model.
            inference_params.sequence_len_offset = seqlen_og + num_generated_tokens[0].item() - 1
            inference_params_draft.sequence_len_offset = inference_params.sequence_len_offset
            if debug:
                cur_ids = torch.cat([input_ids, sequences[-1]], dim=1)
                scores_ref = model(
                    cur_ids, num_last_tokens=num_generated_tokens[0].item() + 1
                ).logits
                print((scores[-1] - scores_ref[:, :-1]).abs().max())

        while True:
            # sequence_len_offset is total length generated - 1
            if inference_params.sequence_len_offset >= max_length - 1:
                break
            if inference_params.sequence_len_offset >= max_length - 2:
                # Don't do speculative sampling, just sample 1 token from the model
                tokens, scores_new = sample_tokens_main(sequences[-1][:, -1:], num_tokens=1)
                sequences.append(tokens)
                scores.append(scores_new)
                break
            # Sample from the draft model
            n_spec_tokens = min(
                speculative_lookahead,
                max_length - inference_params_draft.sequence_len_offset - 2,
            )
            tokens_draft, scores_draft = sample_tokens_draft(
                sequences[-1][:, -1:], num_tokens=n_spec_tokens
            )
            if debug:
                scores_draft_ref = model_draft(
                    torch.cat([cur_ids, tokens_draft], dim=1), num_last_tokens=n_spec_tokens + 1
                ).logits
                print((scores_draft[:, :-1] - scores_draft_ref[:, :-1]).abs().max())
            # Evaluate the draft tokens with the model
            position_ids = repeat(
                torch.arange(
                    inference_params.sequence_len_offset,
                    # 1 extra token from last time that hasn't been passed through model
                    inference_params.sequence_len_offset + n_spec_tokens + 1,
                    dtype=torch.long,
                    device=input_ids.device,
                ),
                "s -> b s",
                b=batch_size,
            )
            logits = model(
                torch.cat([sequences[-1][:, -1:], tokens_draft], dim=1),
                position_ids=position_ids,
                inference_params=inference_params,
            ).logits  # (batch, n_spec_tokens, vocab_size)
            logits = logits_postprocess_fn(logits)
            inference_params.sequence_len_offset += 1
            if debug:
                logits_ref = model(
                    torch.cat([cur_ids, tokens_draft], dim=1), num_last_tokens=n_spec_tokens + 1
                ).logits
                print((logits - logits_ref).abs().max())
            tokens, num_generated_tokens = sample_speculative(
                logits, scores_draft[:, :-1], tokens_draft, **sampling_kwargs
            )
            if debug:
                print(tokens)
                print(num_generated_tokens)
            sequences.append(tokens[:1, : num_generated_tokens[0]])
            scores.append(logits[:1, : num_generated_tokens[0]])
            inference_params.sequence_len_offset += num_generated_tokens[0].item() - 1
            inference_params_draft.sequence_len_offset = inference_params.sequence_len_offset
            # breakpoint()
            if debug:
                cur_ids = torch.cat([cur_ids, sequences[-1]], dim=1)
                scores_ref = model(
                    cur_ids, num_last_tokens=num_generated_tokens[0].item() + 1
                ).logits
                print((scores[-1] - scores_ref[:, :-1]).abs().max())

        if timing:
            if tensor_parallel > 1:
                torch.distributed.barrier()
            torch.cuda.synchronize()
            print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")

    sequences = torch.cat(sequences, dim=1)
    scores = torch.cat(scores, dim=1)
    if debug:
        scores_ref = model(sequences).logits
        print((scores - scores_ref[:, seqlen_og - 1 : -1]).abs().max())
    output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput
    return output_cls(sequences=sequences, scores=scores)


class GenerationMixin:
    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        raise NotImplementedError
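
The new test added to tests/models/test_gpt.py (shown further below) exercises decode_speculative end to end. For quick orientation, a distilled call looks roughly like the sketch below; it assumes a CUDA GPU, an installed flash-attn package, and downloadable GPT-2 checkpoints, and the model choices and sampling settings simply mirror the test rather than recommended defaults.

# Distilled from the test below; illustrative only.
import torch
from transformers import GPT2Config, GPT2Tokenizer

from flash_attn.models.gpt import GPTLMHeadModel
from flash_attn.utils.generation import decode_speculative

device, dtype = "cuda", torch.float16
config = GPT2Config.from_pretrained("gpt2-xl")        # main model
config_draft = GPT2Config.from_pretrained("gpt2")     # smaller draft model
model = GPTLMHeadModel.from_pretrained("gpt2-xl", config, device=device, dtype=dtype).eval()
model_draft = GPTLMHeadModel.from_pretrained("gpt2", config_draft, device=device, dtype=dtype).eval()

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
input_ids = tokenizer("Hello, my dog is cute and he", return_tensors="pt").input_ids.to(device)

out = decode_speculative(
    input_ids, model, model_draft, max_length=100,
    top_k=5, speculative_lookahead=4, fused_ft_kernel=True, cg=True,
)
print(tokenizer.batch_decode(out.sequences))
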

@@ -394,7 +753,7 @@ def capture_graph(model, inference_params, batch_size, max_seqlen, mempool=None,
        input_ids.copy_(new_input_ids)
        position_ids.copy_(new_position_ids)
        graph.replay()
-       return logits
+       return logits.clone()

    inference_params.sequence_len_offset = sequence_len_offset_og
    return run
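
This one-line change is the fix named in the commit message: with cg=True the cached decoding graph hands back its static output buffer, and the next graph.replay() overwrites that buffer in place, so a caller that keeps the returned logits around sees them silently change. The standalone sketch below is not flash-attention code; it reproduces the failure mode with plain PyTorch CUDA graphs (requires a CUDA GPU).

# Minimal sketch of why the clone is needed (standalone, not flash-attention code).
# A CUDA graph always writes into the same static output tensor, so a result returned
# from one replay is overwritten in place by the next replay unless it is cloned.
import torch

assert torch.cuda.is_available()

static_in = torch.zeros(1, device="cuda")

# Warm up on a side stream before capture, as the PyTorch CUDA-graphs docs recommend.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    static_out = static_in * 2
torch.cuda.current_stream().wait_stream(s)

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_out = static_in * 2  # static_out becomes the graph's fixed output buffer


def run(x, clone=False):
    static_in.copy_(x)
    graph.replay()
    return static_out.clone() if clone else static_out


a = run(torch.ones(1, device="cuda"))            # expect 2.0
b = run(torch.full((1,), 3.0, device="cuda"))    # expect 6.0
print(a.item(), b.item())  # without clone: the 2nd replay has overwritten `a`, so 6.0 6.0
print(run(torch.ones(1, device="cuda"), clone=True).item())  # cloned result stays 2.0
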

tests/models/test_gpt.py

@@ -368,11 +368,79 @@ def test_gpt2_multiple_token_generation(model_name, optimized):
    logits_10 = model(input_ids[:, :10], inference_params=inference_params).logits
    inference_params.sequence_len_offset += 10
    position_ids = torch.arange(10, 14, dtype=torch.long, device=device)
    logits_1014 = model(input_ids[:, 10:14], position_ids=position_ids, inference_params=inference_params).logits
    inference_params.sequence_len_offset += 4
    position_ids = torch.arange(14, 20, dtype=torch.long, device=device)
    logits_1420 = model(input_ids[:, 14:20], position_ids=position_ids, inference_params=inference_params).logits
    logits = torch.cat([logits_10, logits_1014, logits_1420], dim=1)
    print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
    print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
    assert torch.allclose(logits, logits_ref, rtol=rtol, atol=atol)


@pytest.mark.parametrize("fused_ft_kernel, cg", [(False, False), (True, False), (True, True)])
# @pytest.mark.parametrize("fused_ft_kernel, cg", [(True, True)])
# @pytest.mark.parametrize("optimized", [False, True])
@pytest.mark.parametrize("optimized", [True])
# @pytest.mark.parametrize("model_name", ["gpt2-medium"])
@pytest.mark.parametrize("model_name", ["gpt2-xl"])
def test_gpt2_speculative_decoding(model_name, optimized, fused_ft_kernel, cg):
    dtype = torch.float16
    device = "cuda"
    rtol, atol = 3e-3, 3e-1
    config = GPT2Config.from_pretrained(model_name)
    config.residual_in_fp32 = True
    if optimized:
        config.use_flash_attn = True
        config.fused_bias_fc = True
        config.fused_mlp = True
        config.fused_dropout_add_ln = True
    config_draft = GPT2Config.from_pretrained("gpt2")
    config_draft.residual_in_fp32 = True
    if optimized:
        config_draft.use_flash_attn = True
        config_draft.fused_bias_fc = True
        config_draft.fused_mlp = True
        config_draft.fused_dropout_add_ln = True
    model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)
    model.eval()
    model_draft = GPTLMHeadModel.from_pretrained("gpt2", config_draft, device=device, dtype=dtype)
    model_draft.eval()

    torch.manual_seed(0)
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    input_ids = tokenizer("Hello, my dog is cute and he", return_tensors="pt").input_ids.to(
        device=device
    )
    max_length = 100

    from flash_attn.utils.generation import decode_speculative

    torch.manual_seed(42)
    out = decode_speculative(
        input_ids,
        model,
        model_draft,
        max_length=max_length,
        top_k=5,
        fused_ft_kernel=fused_ft_kernel,
        cg=cg,
        speculative_lookahead=4,
        timing=True,
    )
    print(tokenizer.batch_decode(out.sequences))
    out_og = model.generate(
        input_ids,
        max_length=max_length,
        top_k=5,
        fused_ft_kernel=fused_ft_kernel,
        cg=False,
        timing=True,
        return_dict_in_generate=True,
    )
    print(tokenizer.batch_decode(out_og.sequences))