gaoqiong / flash-attention · Commit e6a80264
Authored Sep 18, 2023 by Tri Dao
Parent: 42832575

[Gen] Rename max_sequence_len->max_seqlen, sequence_len_offset->seqlen_offset

Showing 3 changed files with 72 additions and 75 deletions:
  flash_attn/modules/mha.py        +22  -26
  flash_attn/utils/generation.py   +47  -46
  tests/models/test_gpt.py          +3   -3
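The change is a mechanical rename of two InferenceParams attributes (max_sequence_len -> max_seqlen, sequence_len_offset -> seqlen_offset) plus a new reset() helper on the dataclass. As a quick orientation before the diffs, here is a minimal sketch of a call site after this commit; the class, argument, and method names come from the diff, while the concrete numbers are illustrative only:

    from flash_attn.utils.generation import InferenceParams

    # Construct with the shortened field names introduced by this commit.
    inference_params = InferenceParams(max_seqlen=512, max_batch_size=4)

    # After prefilling a 32-token prompt, advance the offset before decoding
    # one token at a time (same bookkeeping as in tests/models/test_gpt.py).
    inference_params.seqlen_offset += 32

    # Re-use the same object for a fresh generation instead of rebuilding it.
    inference_params.reset(max_seqlen=512, max_batch_size=4)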
flash_attn/modules/mha.py
@@ -300,7 +300,7 @@ def _update_kv_cache(kv, inference_params, layer_idx):
     if layer_idx not in inference_params.key_value_memory_dict:
         kv_cache = torch.empty(
             inference_params.max_batch_size,
-            inference_params.max_sequence_len,
+            inference_params.max_seqlen,
             2,
             num_heads,
             head_dim,
@@ -313,7 +313,7 @@ def _update_kv_cache(kv, inference_params, layer_idx):
     # Adjust key and value for inference
     batch_start = inference_params.batch_size_offset
     batch_end = batch_start + kv.shape[0]
-    sequence_start = inference_params.sequence_len_offset
+    sequence_start = inference_params.seqlen_offset
     sequence_end = sequence_start + kv.shape[1]
     assert batch_end <= (kv_cache.shape[0] if kv_cache is not None else v_cache.shape[0])
     assert sequence_end <= (kv_cache.shape[1] if kv_cache is not None else v_cache.shape[2])
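The hunk above only renames the attribute, but this is where seqlen_offset earns its keep: it marks the row at which the next chunk of keys/values is written into the pre-allocated cache. A standalone sketch of that index arithmetic with toy tensors (plain PyTorch, shapes as documented in the diff; the final assignment is my assumption about what the rest of _update_kv_cache does, since the hunk only shows the bounds checks):

    import torch

    # Pre-allocated cache: (max_batch_size, max_seqlen, 2, num_heads, head_dim)
    max_batch_size, max_seqlen, num_heads, head_dim = 2, 16, 4, 8
    kv_cache = torch.zeros(max_batch_size, max_seqlen, 2, num_heads, head_dim)

    # Keys/values from the current forward pass: (batch, seqlen_new, 2, num_heads, head_dim)
    kv = torch.randn(2, 3, 2, num_heads, head_dim)

    seqlen_offset = 5                      # inference_params.seqlen_offset
    batch_start = 0                        # inference_params.batch_size_offset
    batch_end = batch_start + kv.shape[0]
    sequence_start = seqlen_offset
    sequence_end = sequence_start + kv.shape[1]
    assert batch_end <= kv_cache.shape[0]
    assert sequence_end <= kv_cache.shape[1]

    # Assumed continuation: copy the new chunk into its slot in the cache.
    kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv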
@@ -445,12 +445,12 @@ class MHA(nn.Module):
         q: (batch_size, seqlen_q, nheads, head_dim)
         kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim)
         """
-        assert inference_params is not None and inference_params.sequence_len_offset > 0
+        assert inference_params is not None and inference_params.seqlen_offset > 0
         assert self.use_flash_attn
         if self.rotary_emb_dim > 0:
             assert self.rotary_emb.scale is None, "This code path does not support xPos"
             self.rotary_emb._update_cos_sin_cache(
-                inference_params.max_sequence_len, device=q.device, dtype=q.dtype
+                inference_params.max_seqlen, device=q.device, dtype=q.dtype
             )
             rotary_cos, rotary_sin = self.rotary_emb._cos_cached, self.rotary_emb._sin_cached
         else:
@@ -460,7 +460,7 @@ class MHA(nn.Module):
         cache_seqlens = (
             inference_params.lengths_per_sample[:batch]
             if inference_params.lengths_per_sample is not None
-            else inference_params.sequence_len_offset
+            else inference_params.seqlen_offset
         )
         context = flash_attn_with_kvcache(
             q,
@@ -480,11 +480,11 @@ class MHA(nn.Module):
     def _update_kvcache_attention(self, q, kv, inference_params):
         """Write kv to inference_params, then do attention"""
         if (
-            inference_params.sequence_len_offset == 0
+            inference_params.seqlen_offset == 0
             or flash_attn_with_kvcache is None
             or not self.use_flash_attn
         ):
-            # TODO: this only uses sequence_len_offset and not lengths_per_sample.
+            # TODO: this only uses seqlen_offset and not lengths_per_sample.
             kv = self._update_kv_cache(kv, inference_params)
             return self.inner_cross_attn(q, kv)
         else:
@@ -493,7 +493,7 @@ class MHA(nn.Module):
             cache_seqlens = (
                 inference_params.lengths_per_sample[:batch]
                 if inference_params.lengths_per_sample is not None
-                else inference_params.sequence_len_offset
+                else inference_params.seqlen_offset
             )
             return flash_attn_with_kvcache(
                 q,
@@ -561,12 +561,10 @@ class MHA(nn.Module):
             else (
                 inference_params.lengths_per_sample
                 if inference_params.lengths_per_sample is not None
-                else inference_params.sequence_len_offset
+                else inference_params.seqlen_offset
             )
         )
-        rotary_max_seqlen = (
-            inference_params.max_sequence_len if inference_params is not None else None
-        )
+        rotary_max_seqlen = inference_params.max_seqlen if inference_params is not None else None
         batch, seqlen = x.shape[:2]
         if not self.cross_attn and self.num_heads_kv == self.num_heads:
             assert x_kv is None and mixer_subset is None
@@ -581,7 +579,7 @@ class MHA(nn.Module):
             qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)
             if (
                 inference_params is None
-                or inference_params.sequence_len_offset == 0
+                or inference_params.seqlen_offset == 0
                 or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
                 or not self.use_flash_attn
             ):
@@ -632,7 +630,7 @@ class MHA(nn.Module):
             ).contiguous()
             if (
                 inference_params is None
-                or inference_params.sequence_len_offset == 0
+                or inference_params.seqlen_offset == 0
                 or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
                 or not self.use_flash_attn
             ):
@@ -772,12 +770,12 @@ class ParallelMHA(nn.Module):
         q: (batch_size, seqlen_q, nheads, head_dim)
         kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim)
         """
-        assert inference_params is not None and inference_params.sequence_len_offset > 0
+        assert inference_params is not None and inference_params.seqlen_offset > 0
         assert self.use_flash_attn
         if self.rotary_emb_dim > 0:
             assert self.rotary_emb.scale is None, "This code path does not support xPos"
             self.rotary_emb._update_cos_sin_cache(
-                inference_params.max_sequence_len, device=q.device, dtype=q.dtype
+                inference_params.max_seqlen, device=q.device, dtype=q.dtype
             )
             rotary_cos, rotary_sin = self.rotary_emb._cos_cached, self.rotary_emb._sin_cached
         else:
@@ -787,7 +785,7 @@ class ParallelMHA(nn.Module):
         cache_seqlens = (
             inference_params.lengths_per_sample[:batch]
             if inference_params.lengths_per_sample is not None
-            else inference_params.sequence_len_offset
+            else inference_params.seqlen_offset
         )
         context = flash_attn_with_kvcache(
             q,
@@ -806,8 +804,8 @@ class ParallelMHA(nn.Module):
     def _update_kvcache_attention(self, q, kv, inference_params):
         """Write kv to inference_params, then do attention"""
-        if inference_params.sequence_len_offset == 0 or not self.use_flash_attn:
-            # TODO: this only uses sequence_len_offset and not lengths_per_sample.
+        if inference_params.seqlen_offset == 0 or not self.use_flash_attn:
+            # TODO: this only uses seqlen_offset and not lengths_per_sample.
             kv = self._update_kv_cache(kv, inference_params)
             return self.inner_cross_attn(q, kv)
         else:
@@ -816,7 +814,7 @@ class ParallelMHA(nn.Module):
             cache_seqlens = (
                 inference_params.lengths_per_sample[:batch]
                 if inference_params.lengths_per_sample is not None
-                else inference_params.sequence_len_offset
+                else inference_params.seqlen_offset
             )
             context = flash_attn_with_kvcache(
                 q,
@@ -847,17 +845,15 @@ class ParallelMHA(nn.Module):
             else (
                 inference_params.lengths_per_sample
                 if inference_params.lengths_per_sample is not None
-                else inference_params.sequence_len_offset
-            )
+                else inference_params.seqlen_offset
+            )
         )
-        rotary_max_seqlen = (
-            inference_params.max_sequence_len if inference_params is not None else None
-        )
+        rotary_max_seqlen = inference_params.max_seqlen if inference_params is not None else None
         if self.num_heads_kv == self.num_heads:
             qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, d=self.head_dim)
             if (
                 inference_params is None
-                or inference_params.sequence_len_offset == 0
+                or inference_params.seqlen_offset == 0
                 or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
                 or not self.use_flash_attn
             ):
@@ -892,7 +888,7 @@ class ParallelMHA(nn.Module):
             )
             if (
                 inference_params is None
-                or inference_params.sequence_len_offset == 0
+                or inference_params.seqlen_offset == 0
                 or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
                 or not self.use_flash_attn
             ):
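Several of the hunks above repeat the same selection for cache_seqlens: use the per-sample lengths tensor when it exists, otherwise fall back to the scalar seqlen_offset. Pulled out into a standalone helper for clarity (the helper name and free-standing form are mine, not the repository's; the logic is the ternary from the diff):

    from typing import Optional, Union
    import torch

    def pick_cache_seqlens(
        lengths_per_sample: Optional[torch.Tensor],  # per-sample lengths, or None
        seqlen_offset: int,
        batch: int,
    ) -> Union[torch.Tensor, int]:
        # Mirrors the expression computed before calling flash_attn_with_kvcache.
        if lengths_per_sample is not None:
            return lengths_per_sample[:batch]
        return seqlen_offset

    print(pick_cache_seqlens(None, seqlen_offset=17, batch=2))  # -> 17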
flash_attn/utils/generation.py
@@ -20,13 +20,20 @@ class InferenceParams:
     """Inference parameters that are passed to the main model in order
     to efficienly calculate and store the context during inference."""

-    max_sequence_len: int
+    max_seqlen: int
     max_batch_size: int
-    sequence_len_offset: int = 0
+    seqlen_offset: int = 0
     batch_size_offset: int = 0
     key_value_memory_dict: dict = field(default_factory=dict)
     lengths_per_sample: Optional[Tensor] = None

+    def reset(self, max_seqlen, max_batch_size):
+        self.max_seqlen = max_seqlen
+        self.max_batch_size = max_batch_size
+        self.seqlen_offset = 0
+        if self.lengths_per_sample is not None:
+            self.lengths_per_sample.zero_()

 # https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
 # https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L231
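For reference, after this hunk the dataclass reads roughly as follows; this is reconstructed from the diff and its context lines, so treat it as a sketch rather than the canonical source:

    from dataclasses import dataclass, field
    from typing import Optional

    from torch import Tensor

    @dataclass
    class InferenceParams:
        """Inference parameters that are passed to the main model in order
        to efficiently calculate and store the context during inference."""

        max_seqlen: int
        max_batch_size: int
        seqlen_offset: int = 0
        batch_size_offset: int = 0
        key_value_memory_dict: dict = field(default_factory=dict)
        lengths_per_sample: Optional[Tensor] = None

        def reset(self, max_seqlen, max_batch_size):
            # Re-arm the object for a new generation without reallocating the KV cache.
            self.max_seqlen = max_seqlen
            self.max_batch_size = max_batch_size
            self.seqlen_offset = 0
            if self.lengths_per_sample is not None:
                self.lengths_per_sample.zero_()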
@@ -127,19 +134,16 @@ def decode(
             tensor_parallel=tensor_parallel,
         )
         inference_params = model._decoding_cache.inference_params
-        inference_params.max_sequence_len = max_length
-        inference_params.max_batch_size = batch_size
-        inference_params.sequence_len_offset = 0
-        inference_params.lengths_per_sample.zero_()
+        inference_params.reset(max_length, batch_size)
     else:
-        inference_params = InferenceParams(max_sequence_len=max_length, max_batch_size=batch_size)
+        inference_params = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size)

     def get_logits(input_ids, inference_params):
-        decoding = inference_params.sequence_len_offset > 0
+        decoding = inference_params.seqlen_offset > 0
         if decoding:
             position_ids = torch.full(
                 (batch_size, 1),
-                inference_params.sequence_len_offset,
+                inference_params.seqlen_offset,
                 dtype=torch.long,
                 device=input_ids.device,
             )
@@ -154,24 +158,24 @@ def decode(
             ).logits.squeeze(dim=1)
         else:
             logits = model._decoding_cache.run(
-                input_ids, position_ids, inference_params.sequence_len_offset
+                input_ids, position_ids, inference_params.seqlen_offset
             ).clone()
         return logits[..., :vocab_size] if vocab_size is not None else logits

     def sample_tokens(logits, inference_params):
-        if teacher_outputs is None or teacher_output_len <= inference_params.sequence_len_offset:
+        if teacher_outputs is None or teacher_output_len <= inference_params.seqlen_offset:
             token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature)
         else:
-            token = teacher_outputs[:, inference_params.sequence_len_offset]
+            token = teacher_outputs[:, inference_params.seqlen_offset]
         # return rearrange(token, "b -> b 1")
         return token.unsqueeze(1)

     def should_stop(current_token, inference_params):
-        if inference_params.sequence_len_offset == 0:
+        if inference_params.seqlen_offset == 0:
             return False
         if eos_token_id is not None and (current_token == eos_token_id).all():
             return True
-        if inference_params.sequence_len_offset >= max_length - 1:
+        if inference_params.seqlen_offset >= max_length - 1:
             return True
         return False
@@ -185,7 +189,7 @@ def decode(
     scores, sequences = [], [input_ids]
     while not should_stop(sequences[-1], inference_params):
         scores.append(get_logits(sequences[-1], inference_params))
-        inference_params.sequence_len_offset += sequences[-1].shape[1]
+        inference_params.seqlen_offset += sequences[-1].shape[1]
         sequences.append(sample_tokens(scores[-1], inference_params))
     if enable_timing:
         end.record()
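The loop above is where the renamed offset is actually advanced: after every forward pass it grows by however many tokens were just fed in (the full prompt on the first iteration, a single token per step afterwards), and should_stop compares it against max_length. A self-contained toy version of that control flow, with the model and sampler stubbed out so it runs without flash-attn; only the offset bookkeeping mirrors the diff, everything else is a stand-in:

    import torch

    class ToyParams:
        def __init__(self, max_seqlen):
            self.max_seqlen = max_seqlen
            self.seqlen_offset = 0

    def get_logits(tokens, params):
        # Stand-in for the real forward pass: (batch, vocab_size) logits.
        return torch.randn(tokens.shape[0], 10)

    def sample_tokens(logits, params):
        # Stand-in for top-k/top-p sampling: greedy argmax, shape (batch, 1).
        return logits.argmax(dim=-1, keepdim=True)

    def should_stop(tokens, params, max_length=8):
        if params.seqlen_offset == 0:
            return False
        return params.seqlen_offset >= max_length - 1

    input_ids = torch.randint(0, 10, (1, 4))   # toy prompt of length 4
    params = ToyParams(max_seqlen=8)
    scores, sequences = [], [input_ids]
    while not should_stop(sequences[-1], params):
        scores.append(get_logits(sequences[-1], params))
        params.seqlen_offset += sequences[-1].shape[1]   # prompt length, then 1 per step
        sequences.append(sample_tokens(scores[-1], params))
    print(torch.cat(sequences, dim=1).shape)   # torch.Size([1, 8])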
@@ -256,6 +260,7 @@ def sample_speculative(logits, logits_draft, tokens_draft, top_k=1, top_p=0.0, t
     return tokens, first_rejected_idx + 1

+
 @torch.inference_mode()
 def decode_speculative(
     input_ids,
     model,
@@ -303,15 +308,11 @@ def decode_speculative(
             tensor_parallel=tensor_parallel,
         )
         inference_params_draft = model_draft._decoding_cache.inference_params
-        inference_params_draft.max_sequence_len = max_length
-        inference_params_draft.max_batch_size = batch_size
-        inference_params_draft.sequence_len_offset = 0
-        inference_params = InferenceParams(max_sequence_len=max_length, max_batch_size=batch_size)
+        inference_params_draft.reset(max_length, batch_size)
+        inference_params = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size)
     else:
-        inference_params_draft = InferenceParams(
-            max_sequence_len=max_length, max_batch_size=batch_size
-        )
-        inference_params = InferenceParams(max_sequence_len=max_length, max_batch_size=batch_size)
+        inference_params_draft = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size)
+        inference_params = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size)

     def logits_forward_fn(model, input_ids, position_ids, inference_params, cg=False):
         if not cg:
@@ -323,7 +324,7 @@ def decode_speculative(
             ).logits.squeeze(dim=1)
         else:
             return model._decoding_cache.run(
-                input_ids, position_ids, inference_params.sequence_len_offset
+                input_ids, position_ids, inference_params.seqlen_offset
             ).clone()

     logits_postprocess_fn = (
@@ -365,13 +366,13 @@ def decode_speculative(
         assert seqlen == 1
         position_ids = repeat(
             torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
-            + inference_params.sequence_len_offset,
+            + inference_params.seqlen_offset,
             "s -> b s",
             b=batch_size,
         )
         # position_ids = torch.full(
         #     (batch_size, 1),
-        #     inference_params.sequence_len_offset,
+        #     inference_params.seqlen_offset,
         #     dtype=torch.long,
         #     device=input_ids.device,
         # )
@@ -380,7 +381,7 @@ def decode_speculative(
         logits = logits_postprocess_fn(
             logits_forward_fn(model, input_ids, position_ids, inference_params, cg=decoding and cg)
         )
-        inference_params.sequence_len_offset += input_ids.shape[1]
+        inference_params.seqlen_offset += input_ids.shape[1]
         scores = [logits]
         next_token = sample_fn(logits)
         sequences.append(next_token)
@@ -388,7 +389,7 @@ def decode_speculative(
             if i < num_tokens - 1 or last_token_logits:
                 position_ids = torch.full(
                     (batch_size, 1),
-                    inference_params_draft.sequence_len_offset,
+                    inference_params_draft.seqlen_offset,
                     dtype=torch.long,
                     device=input_ids.device,
                 )
@@ -401,7 +402,7 @@ def decode_speculative(
                     cg=cg,
                 )
             )
-            inference_params.sequence_len_offset += 1
+            inference_params.seqlen_offset += 1
             scores.append(logits)
             if i < num_tokens - 1:
                 next_token = sample_fn(logits)
@@ -476,8 +477,8 @@ def decode_speculative(
     scores.append(logits[:1, :num_generated_tokens[0]])
     # Note that @model has not evaluated the last sampled token yet, so we'll need to pass
     # that in the next time we call @model.
-    inference_params.sequence_len_offset = seqlen_og + num_generated_tokens[0].item() - 1
-    inference_params_draft.sequence_len_offset = inference_params.sequence_len_offset
+    inference_params.seqlen_offset = seqlen_og + num_generated_tokens[0].item() - 1
+    inference_params_draft.seqlen_offset = inference_params.seqlen_offset
     if debug:
         cur_ids = torch.cat([input_ids, sequences[-1]], dim=1)
         scores_ref = model(
@@ -486,10 +487,10 @@ def decode_speculative(
         print((scores[-1] - scores_ref[:, :-1]).abs().max())

     while True:
-        # sequence_len_offset is total length generated - 1
-        if inference_params.sequence_len_offset >= max_length - 1:
+        # seqlen_offset is total length generated - 1
+        if inference_params.seqlen_offset >= max_length - 1:
             break
-        if inference_params.sequence_len_offset >= max_length - 2:
+        if inference_params.seqlen_offset >= max_length - 2:
             # Don't do speculative sampling, just sample 1 token from the model
             tokens, scores_new = sample_tokens_main(sequences[-1][:, -1:], num_tokens=1)
             sequences.append(tokens)
@@ -497,7 +498,7 @@ def decode_speculative(
             break
         # Sample from draft model
         n_spec_tokens = min(
-            speculative_lookahead, max_length - inference_params_draft.sequence_len_offset - 2
+            speculative_lookahead, max_length - inference_params_draft.seqlen_offset - 2
         )
         tokens_draft, scores_draft = sample_tokens_draft(
             sequences[-1][:, -1:], num_tokens=n_spec_tokens
@@ -510,9 +511,9 @@ def decode_speculative(
         # Evaluate the draft tokens with the model
         position_ids = repeat(
             torch.arange(
-                inference_params.sequence_len_offset,
+                inference_params.seqlen_offset,
                 # 1 extra token from last time that hasn't been passed through model
-                inference_params.sequence_len_offset + n_spec_tokens + 1,
+                inference_params.seqlen_offset + n_spec_tokens + 1,
                 dtype=torch.long,
                 device=input_ids.device,
             ),
@@ -525,7 +526,7 @@ def decode_speculative(
             inference_params=inference_params,
         ).logits  # (batch, n_spec_tokens, vocab_size)
         logits = logits_postprocess_fn(logits)
-        inference_params.sequence_len_offset += 1
+        inference_params.seqlen_offset += 1
         if debug:
             logits_ref = model(
                 torch.cat([cur_ids, tokens_draft], dim=1), num_last_tokens=n_spec_tokens + 1
@@ -539,8 +540,8 @@ def decode_speculative(
             print(num_generated_tokens)
         sequences.append(tokens[:1, :num_generated_tokens[0]])
         scores.append(logits[:1, :num_generated_tokens[0]])
-        inference_params.sequence_len_offset += num_generated_tokens[0].item() - 1
-        inference_params_draft.sequence_len_offset = inference_params.sequence_len_offset
+        inference_params.seqlen_offset += num_generated_tokens[0].item() - 1
+        inference_params_draft.seqlen_offset = inference_params.seqlen_offset
         # breakpoint()
         if debug:
             cur_ids = torch.cat([cur_ids, sequences[-1]], dim=1)
@@ -679,9 +680,9 @@ def update_graph_cache(
         )
         lengths_per_sample = torch.full((batch_size,), seqlen_og, dtype=torch.int32, device=device)
         cache.inference_params = InferenceParams(
-            max_sequence_len=max_seqlen,
+            max_seqlen=max_seqlen,
             max_batch_size=batch_size,
-            sequence_len_offset=seqlen_og,
+            seqlen_offset=seqlen_og,
             key_value_memory_dict=inf_cache,
             lengths_per_sample=lengths_per_sample,
         )
@@ -705,7 +706,7 @@ def update_graph_cache(
     )
     cache.run = dispatch
-    cache.inference_params.sequence_len_offset = 0  # Reset so it's not confusing
+    cache.inference_params.seqlen_offset = 0  # Reset so it's not confusing
     return cache
@@ -713,10 +714,10 @@ def capture_graph(model, inference_params, batch_size, max_seqlen, mempool=None,
     device = next(iter(model.parameters())).device
     input_ids = torch.full((batch_size, 1), 0, dtype=torch.long, device=device)
     position_ids = torch.full((batch_size, 1), 0, dtype=torch.long, device=device)
-    sequence_len_offset_og = inference_params.sequence_len_offset
+    seqlen_offset_og = inference_params.seqlen_offset
     # TD [2023-04-14]: important for correctness of the FT's attention kernel, as seqlen_cpu is
     # used to determine the size of smem. Hence seqlen_cpu must be >= lengths_per_sample.
-    inference_params.sequence_len_offset = max_seqlen - 1
+    inference_params.seqlen_offset = max_seqlen - 1
     inference_params.lengths_per_sample[:] = max_seqlen - 1
     # Warmup before capture
@@ -755,5 +756,5 @@ def capture_graph(model, inference_params, batch_size, max_seqlen, mempool=None,
         graph.replay()
         return logits.clone()

-    inference_params.sequence_len_offset = sequence_len_offset_og
+    inference_params.seqlen_offset = seqlen_offset_og
     return run
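The last two generation.py hunks show the CUDA-graph capture path: capture_graph saves the current offset, pins it to max_seqlen - 1 while the graph is captured (so the attention kernel sizes its shared memory for the worst case), and restores the saved value at the end. A minimal sketch of that save/pin/restore pattern in isolation; the helper name is hypothetical and the capture itself is replaced by a callback, only the bookkeeping follows the diff:

    def with_pinned_seqlen_offset(inference_params, max_seqlen, capture_fn):
        # Hypothetical helper: the repository inlines this logic in capture_graph().
        seqlen_offset_og = inference_params.seqlen_offset
        inference_params.seqlen_offset = max_seqlen - 1
        if inference_params.lengths_per_sample is not None:
            inference_params.lengths_per_sample[:] = max_seqlen - 1
        try:
            # Warmup forward passes and torch.cuda.graph capture happen here
            # in the real code; the callback stands in for both.
            return capture_fn()
        finally:
            inference_params.seqlen_offset = seqlen_offset_og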
tests/models/test_gpt.py
@@ -364,14 +364,14 @@ def test_gpt2_multiple_token_generation(model_name, optimized):
     logits_ref = model(input_ids).logits

     # Run 10 tokens, then pass in another 4, then another 6, to see if we get the same logits
-    inference_params = InferenceParams(max_sequence_len=20, max_batch_size=1)
+    inference_params = InferenceParams(max_seqlen=20, max_batch_size=1)
     logits_10 = model(input_ids[:, :10], inference_params=inference_params).logits
-    inference_params.sequence_len_offset += 10
+    inference_params.seqlen_offset += 10
     position_ids = torch.arange(10, 14, dtype=torch.long, device=device)
     logits_1014 = model(
         input_ids[:, 10:14], position_ids=position_ids, inference_params=inference_params
     ).logits
-    inference_params.sequence_len_offset += 4
+    inference_params.seqlen_offset += 4
     position_ids = torch.arange(14, 20, dtype=torch.long, device=device)
     logits_1420 = model(
         input_ids[:, 14:20], position_ids=position_ids, inference_params=inference_params