gaoqiong / flash-attention / Commits

Commit e6a80264, authored Sep 18, 2023 by Tri Dao

    [Gen] Rename max_sequence_len->max_seqlen, sequence_len_offset->seqlen_offset

Parent: 42832575

Showing 3 changed files with 72 additions and 75 deletions (+72, -75)
Changed files:
    flash_attn/modules/mha.py        +22  -26
    flash_attn/utils/generation.py   +47  -46
    tests/models/test_gpt.py          +3   -3
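The commit is a mechanical rename of two public InferenceParams fields (max_sequence_len -> max_seqlen, sequence_len_offset -> seqlen_offset), plus a new InferenceParams.reset helper and an @torch.inference_mode() decorator on decode_speculative. A minimal before/after sketch of what changes at a call site (the sizes used here are arbitrary and not taken from the diff):

    from flash_attn.utils.generation import InferenceParams

    # Before this commit:
    #     params = InferenceParams(max_sequence_len=512, max_batch_size=4)
    #     params.sequence_len_offset += 10

    # After this commit:
    params = InferenceParams(max_seqlen=512, max_batch_size=4)
    params.seqlen_offset += 10  # cache position advances by 10 tokens
    print(params.max_seqlen, params.seqlen_offset)  # 512 10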
flash_attn/modules/mha.py

@@ -300,7 +300,7 @@ def _update_kv_cache(kv, inference_params, layer_idx):
     if layer_idx not in inference_params.key_value_memory_dict:
         kv_cache = torch.empty(
             inference_params.max_batch_size,
-            inference_params.max_sequence_len,
+            inference_params.max_seqlen,
             2,
             num_heads,
             head_dim,

@@ -313,7 +313,7 @@ def _update_kv_cache(kv, inference_params, layer_idx):
     # Adjust key and value for inference
     batch_start = inference_params.batch_size_offset
     batch_end = batch_start + kv.shape[0]
-    sequence_start = inference_params.sequence_len_offset
+    sequence_start = inference_params.seqlen_offset
     sequence_end = sequence_start + kv.shape[1]
     assert batch_end <= (kv_cache.shape[0] if kv_cache is not None else v_cache.shape[0])
     assert sequence_end <= (kv_cache.shape[1] if kv_cache is not None else v_cache.shape[2])
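The two hunks above only change how _update_kv_cache reads the renamed fields; the cache layout itself is unchanged. As a standalone illustration (a simplified sketch, not the library's implementation), the allocation and slice-write look roughly like this, with the cache shaped (max_batch_size, max_seqlen, 2, nheads, head_dim) as in the torch.empty call above:

    import torch

    def toy_update_kv_cache(kv_cache, kv, seqlen_offset, batch_size_offset=0):
        """Write a (batch, seqlen, 2, nheads, head_dim) chunk into a preallocated cache."""
        batch_start = batch_size_offset
        batch_end = batch_start + kv.shape[0]
        sequence_start = seqlen_offset
        sequence_end = sequence_start + kv.shape[1]
        # Mirror the bounds checks in the hunk above.
        assert batch_end <= kv_cache.shape[0] and sequence_end <= kv_cache.shape[1]
        kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
        return kv_cache[batch_start:batch_end, :sequence_end, ...]

    # Example: cache sized for max_seqlen=16, write 4 tokens starting at offset 0.
    cache = torch.zeros(2, 16, 2, 8, 64)
    chunk = torch.randn(2, 4, 2, 8, 64)
    out = toy_update_kv_cache(cache, chunk, seqlen_offset=0)
    print(out.shape)  # torch.Size([2, 4, 2, 8, 64])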
@@ -445,12 +445,12 @@ class MHA(nn.Module):
         q: (batch_size, seqlen_q, nheads, head_dim)
         kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim)
         """
-        assert inference_params is not None and inference_params.sequence_len_offset > 0
+        assert inference_params is not None and inference_params.seqlen_offset > 0
         assert self.use_flash_attn
         if self.rotary_emb_dim > 0:
             assert self.rotary_emb.scale is None, "This code path does not support xPos"
             self.rotary_emb._update_cos_sin_cache(
-                inference_params.max_sequence_len, device=q.device, dtype=q.dtype
+                inference_params.max_seqlen, device=q.device, dtype=q.dtype
             )
             rotary_cos, rotary_sin = self.rotary_emb._cos_cached, self.rotary_emb._sin_cached
         else:

@@ -460,7 +460,7 @@ class MHA(nn.Module):
         cache_seqlens = (
             inference_params.lengths_per_sample[:batch]
             if inference_params.lengths_per_sample is not None
-            else inference_params.sequence_len_offset
+            else inference_params.seqlen_offset
         )
         context = flash_attn_with_kvcache(
             q,

@@ -480,11 +480,11 @@ class MHA(nn.Module):
     def _update_kvcache_attention(self, q, kv, inference_params):
         """Write kv to inference_params, then do attention"""
         if (
-            inference_params.sequence_len_offset == 0
+            inference_params.seqlen_offset == 0
             or flash_attn_with_kvcache is None
             or not self.use_flash_attn
         ):
-            # TODO: this only uses sequence_len_offset and not lengths_per_sample.
+            # TODO: this only uses seqlen_offset and not lengths_per_sample.
             kv = self._update_kv_cache(kv, inference_params)
             return self.inner_cross_attn(q, kv)
         else:

@@ -493,7 +493,7 @@ class MHA(nn.Module):
             cache_seqlens = (
                 inference_params.lengths_per_sample[:batch]
                 if inference_params.lengths_per_sample is not None
-                else inference_params.sequence_len_offset
+                else inference_params.seqlen_offset
             )
             return flash_attn_with_kvcache(
                 q,

@@ -561,12 +561,10 @@ class MHA(nn.Module):
                 else (
                     inference_params.lengths_per_sample
                     if inference_params.lengths_per_sample is not None
-                    else inference_params.sequence_len_offset
+                    else inference_params.seqlen_offset
                 )
             )
-        rotary_max_seqlen = (
-            inference_params.max_sequence_len if inference_params is not None else None
-        )
+        rotary_max_seqlen = inference_params.max_seqlen if inference_params is not None else None
         batch, seqlen = x.shape[:2]
         if not self.cross_attn and self.num_heads_kv == self.num_heads:
             assert x_kv is None and mixer_subset is None

@@ -581,7 +579,7 @@ class MHA(nn.Module):
             qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)
             if (
                 inference_params is None
-                or inference_params.sequence_len_offset == 0
+                or inference_params.seqlen_offset == 0
                 or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
                 or not self.use_flash_attn
             ):

@@ -632,7 +630,7 @@ class MHA(nn.Module):
             ).contiguous()
             if (
                 inference_params is None
-                or inference_params.sequence_len_offset == 0
+                or inference_params.seqlen_offset == 0
                 or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
                 or not self.use_flash_attn
             ):

@@ -772,12 +770,12 @@ class ParallelMHA(nn.Module):
         q: (batch_size, seqlen_q, nheads, head_dim)
         kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim)
         """
-        assert inference_params is not None and inference_params.sequence_len_offset > 0
+        assert inference_params is not None and inference_params.seqlen_offset > 0
         assert self.use_flash_attn
         if self.rotary_emb_dim > 0:
             assert self.rotary_emb.scale is None, "This code path does not support xPos"
             self.rotary_emb._update_cos_sin_cache(
-                inference_params.max_sequence_len, device=q.device, dtype=q.dtype
+                inference_params.max_seqlen, device=q.device, dtype=q.dtype
             )
             rotary_cos, rotary_sin = self.rotary_emb._cos_cached, self.rotary_emb._sin_cached
         else:

@@ -787,7 +785,7 @@ class ParallelMHA(nn.Module):
         cache_seqlens = (
             inference_params.lengths_per_sample[:batch]
             if inference_params.lengths_per_sample is not None
-            else inference_params.sequence_len_offset
+            else inference_params.seqlen_offset
         )
         context = flash_attn_with_kvcache(
             q,

@@ -806,8 +804,8 @@ class ParallelMHA(nn.Module):
     def _update_kvcache_attention(self, q, kv, inference_params):
         """Write kv to inference_params, then do attention"""
-        if inference_params.sequence_len_offset == 0 or not self.use_flash_attn:
-            # TODO: this only uses sequence_len_offset and not lengths_per_sample.
+        if inference_params.seqlen_offset == 0 or not self.use_flash_attn:
+            # TODO: this only uses seqlen_offset and not lengths_per_sample.
             kv = self._update_kv_cache(kv, inference_params)
             return self.inner_cross_attn(q, kv)
         else:

@@ -816,7 +814,7 @@ class ParallelMHA(nn.Module):
             cache_seqlens = (
                 inference_params.lengths_per_sample[:batch]
                 if inference_params.lengths_per_sample is not None
-                else inference_params.sequence_len_offset
+                else inference_params.seqlen_offset
             )
             context = flash_attn_with_kvcache(
                 q,

@@ -847,17 +845,15 @@ class ParallelMHA(nn.Module):
                 else (
                     inference_params.lengths_per_sample
                     if inference_params.lengths_per_sample is not None
-                    else inference_params.sequence_len_offset
+                    else inference_params.seqlen_offset
                 )
             )
-        rotary_max_seqlen = (
-            inference_params.max_sequence_len if inference_params is not None else None
-        )
+        rotary_max_seqlen = inference_params.max_seqlen if inference_params is not None else None
         if self.num_heads_kv == self.num_heads:
             qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, d=self.head_dim)
             if (
                 inference_params is None
-                or inference_params.sequence_len_offset == 0
+                or inference_params.seqlen_offset == 0
                 or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
                 or not self.use_flash_attn
             ):

@@ -892,7 +888,7 @@ class ParallelMHA(nn.Module):
                 )
                 if (
                     inference_params is None
-                    or inference_params.sequence_len_offset == 0
+                    or inference_params.seqlen_offset == 0
                     or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
                     or not self.use_flash_attn
                 ):
flash_attn/utils/generation.py

@@ -20,13 +20,20 @@ class InferenceParams:
     """Inference parameters that are passed to the main model in order
     to efficienly calculate and store the context during inference."""

-    max_sequence_len: int
+    max_seqlen: int
     max_batch_size: int
-    sequence_len_offset: int = 0
+    seqlen_offset: int = 0
     batch_size_offset: int = 0
     key_value_memory_dict: dict = field(default_factory=dict)
     lengths_per_sample: Optional[Tensor] = None
+
+    def reset(self, max_seqlen, max_batch_size):
+        self.max_seqlen = max_seqlen
+        self.max_batch_size = max_batch_size
+        self.seqlen_offset = 0
+        if self.lengths_per_sample is not None:
+            self.lengths_per_sample.zero_()

 # https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
 # https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L231
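The new reset method packages the per-request re-initialization that callers previously did field by field (see the decode hunk just below): it clears seqlen_offset and zeroes lengths_per_sample, but leaves key_value_memory_dict, i.e. the preallocated KV buffers, in place. A minimal usage sketch:

    import torch
    from flash_attn.utils.generation import InferenceParams

    params = InferenceParams(max_seqlen=256, max_batch_size=2)
    params.seqlen_offset = 100  # pretend 100 tokens were already generated
    params.lengths_per_sample = torch.full((2,), 100, dtype=torch.int32)

    # Reuse the same object for a new request: offset and per-sample lengths are
    # cleared, while the KV buffers in key_value_memory_dict stay allocated.
    params.reset(max_seqlen=256, max_batch_size=2)
    assert params.seqlen_offset == 0
    assert int(params.lengths_per_sample.sum()) == 0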
@@ -127,19 +134,16 @@ def decode(
             tensor_parallel=tensor_parallel,
         )
         inference_params = model._decoding_cache.inference_params
-        inference_params.max_sequence_len = max_length
-        inference_params.max_batch_size = batch_size
-        inference_params.sequence_len_offset = 0
-        inference_params.lengths_per_sample.zero_()
+        inference_params.reset(max_length, batch_size)
     else:
-        inference_params = InferenceParams(max_sequence_len=max_length, max_batch_size=batch_size)
+        inference_params = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size)

     def get_logits(input_ids, inference_params):
-        decoding = inference_params.sequence_len_offset > 0
+        decoding = inference_params.seqlen_offset > 0
         if decoding:
             position_ids = torch.full(
                 (batch_size, 1),
-                inference_params.sequence_len_offset,
+                inference_params.seqlen_offset,
                 dtype=torch.long,
                 device=input_ids.device,
             )

@@ -154,24 +158,24 @@ def decode(
             ).logits.squeeze(dim=1)
         else:
             logits = model._decoding_cache.run(
-                input_ids, position_ids, inference_params.sequence_len_offset
+                input_ids, position_ids, inference_params.seqlen_offset
             ).clone()
         return logits[..., :vocab_size] if vocab_size is not None else logits

     def sample_tokens(logits, inference_params):
-        if teacher_outputs is None or teacher_output_len <= inference_params.sequence_len_offset:
+        if teacher_outputs is None or teacher_output_len <= inference_params.seqlen_offset:
             token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature)
         else:
-            token = teacher_outputs[:, inference_params.sequence_len_offset]
+            token = teacher_outputs[:, inference_params.seqlen_offset]
         # return rearrange(token, "b -> b 1")
         return token.unsqueeze(1)

     def should_stop(current_token, inference_params):
-        if inference_params.sequence_len_offset == 0:
+        if inference_params.seqlen_offset == 0:
             return False
         if eos_token_id is not None and (current_token == eos_token_id).all():
             return True
-        if inference_params.sequence_len_offset >= max_length - 1:
+        if inference_params.seqlen_offset >= max_length - 1:
             return True
         return False

@@ -185,7 +189,7 @@ def decode(
     scores, sequences = [], [input_ids]
     while not should_stop(sequences[-1], inference_params):
         scores.append(get_logits(sequences[-1], inference_params))
-        inference_params.sequence_len_offset += sequences[-1].shape[1]
+        inference_params.seqlen_offset += sequences[-1].shape[1]
         sequences.append(sample_tokens(scores[-1], inference_params))
     if enable_timing:
         end.record()

@@ -256,6 +260,7 @@ def sample_speculative(logits, logits_draft, tokens_draft, top_k=1, top_p=0.0, t
     return tokens, first_rejected_idx + 1


+@torch.inference_mode()
 def decode_speculative(
     input_ids,
     model,
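Aside from the rename, the only addition in the hunk above is the @torch.inference_mode() decorator on decode_speculative. torch.inference_mode is standard PyTorch: the decorated function runs with autograd tracking disabled, which is slightly cheaper than torch.no_grad. A tiny illustration, unrelated to the diff's own code:

    import torch

    @torch.inference_mode()
    def greedy_step(logits):
        # Tensors created here are inference tensors: no autograd graph is recorded.
        return logits.argmax(dim=-1)

    logits = torch.randn(2, 32)
    token = greedy_step(logits)
    print(token.requires_grad)  # False, even if logits required grad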
@@ -303,15 +308,11 @@ def decode_speculative(
             tensor_parallel=tensor_parallel,
         )
         inference_params_draft = model_draft._decoding_cache.inference_params
-        inference_params_draft.max_sequence_len = max_length
-        inference_params_draft.max_batch_size = batch_size
-        inference_params_draft.sequence_len_offset = 0
-        inference_params = InferenceParams(max_sequence_len=max_length, max_batch_size=batch_size)
+        inference_params_draft.reset(max_length, batch_size)
+        inference_params = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size)
     else:
-        inference_params_draft = InferenceParams(
-            max_sequence_len=max_length, max_batch_size=batch_size
-        )
-        inference_params = InferenceParams(max_sequence_len=max_length, max_batch_size=batch_size)
+        inference_params_draft = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size)
+        inference_params = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size)

     def logits_forward_fn(model, input_ids, position_ids, inference_params, cg=False):
         if not cg:

@@ -323,7 +324,7 @@ def decode_speculative(
             ).logits.squeeze(dim=1)
         else:
             return model._decoding_cache.run(
-                input_ids, position_ids, inference_params.sequence_len_offset
+                input_ids, position_ids, inference_params.seqlen_offset
             ).clone()

     logits_postprocess_fn = (

@@ -365,13 +366,13 @@ def decode_speculative(
         assert seqlen == 1
         position_ids = repeat(
             torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
-            + inference_params.sequence_len_offset,
+            + inference_params.seqlen_offset,
             "s -> b s",
             b=batch_size,
         )
         # position_ids = torch.full(
         #     (batch_size, 1),
-        #     inference_params.sequence_len_offset,
+        #     inference_params.seqlen_offset,
         #     dtype=torch.long,
         #     device=input_ids.device,
         # )

@@ -380,7 +381,7 @@ def decode_speculative(
         logits = logits_postprocess_fn(
             logits_forward_fn(model, input_ids, position_ids, inference_params, cg=decoding and cg)
         )
-        inference_params.sequence_len_offset += input_ids.shape[1]
+        inference_params.seqlen_offset += input_ids.shape[1]
         scores = [logits]
         next_token = sample_fn(logits)
         sequences.append(next_token)

@@ -388,7 +389,7 @@ def decode_speculative(
         if i < num_tokens - 1 or last_token_logits:
             position_ids = torch.full(
                 (batch_size, 1),
-                inference_params_draft.sequence_len_offset,
+                inference_params_draft.seqlen_offset,
                 dtype=torch.long,
                 device=input_ids.device,
             )

@@ -401,7 +402,7 @@ def decode_speculative(
                     cg=cg,
                 )
             )
-            inference_params.sequence_len_offset += 1
+            inference_params.seqlen_offset += 1
             scores.append(logits)
             if i < num_tokens - 1:
                 next_token = sample_fn(logits)

@@ -476,8 +477,8 @@ def decode_speculative(
     scores.append(logits[:1, :num_generated_tokens[0]])
     # Note that @model has not evaluated the last sampled token yet, so we'll need to pass
     # that in the next time we call @model.
-    inference_params.sequence_len_offset = seqlen_og + num_generated_tokens[0].item() - 1
-    inference_params_draft.sequence_len_offset = inference_params.sequence_len_offset
+    inference_params.seqlen_offset = seqlen_og + num_generated_tokens[0].item() - 1
+    inference_params_draft.seqlen_offset = inference_params.seqlen_offset
     if debug:
         cur_ids = torch.cat([input_ids, sequences[-1]], dim=1)
         scores_ref = model(

@@ -486,10 +487,10 @@ def decode_speculative(
         print((scores[-1] - scores_ref[:, :-1]).abs().max())

     while True:
-        # sequence_len_offset is total length generated - 1
-        if inference_params.sequence_len_offset >= max_length - 1:
+        # seqlen_offset is total length generated - 1
+        if inference_params.seqlen_offset >= max_length - 1:
             break
-        if inference_params.sequence_len_offset >= max_length - 2:
+        if inference_params.seqlen_offset >= max_length - 2:
             # Don't do speculative sampling, just sample 1 token from the model
             tokens, scores_new = sample_tokens_main(sequences[-1][:, -1:], num_tokens=1)
             sequences.append(tokens)

@@ -497,7 +498,7 @@ def decode_speculative(
             break
         # Sample from draft model
         n_spec_tokens = min(
-            speculative_lookahead, max_length - inference_params_draft.sequence_len_offset - 2
+            speculative_lookahead, max_length - inference_params_draft.seqlen_offset - 2
         )
         tokens_draft, scores_draft = sample_tokens_draft(
             sequences[-1][:, -1:], num_tokens=n_spec_tokens

@@ -510,9 +511,9 @@ def decode_speculative(
         # Evaluate the draft tokens with the model
         position_ids = repeat(
             torch.arange(
-                inference_params.sequence_len_offset,
+                inference_params.seqlen_offset,
                 # 1 extra token from last time that hasn't been passed through model
-                inference_params.sequence_len_offset + n_spec_tokens + 1,
+                inference_params.seqlen_offset + n_spec_tokens + 1,
                 dtype=torch.long,
                 device=input_ids.device,
             ),

@@ -525,7 +526,7 @@ def decode_speculative(
             inference_params=inference_params,
         ).logits  # (batch, n_spec_tokens, vocab_size)
         logits = logits_postprocess_fn(logits)
-        inference_params.sequence_len_offset += 1
+        inference_params.seqlen_offset += 1
         if debug:
             logits_ref = model(
                 torch.cat([cur_ids, tokens_draft], dim=1), num_last_tokens=n_spec_tokens + 1

@@ -539,8 +540,8 @@ def decode_speculative(
             print(num_generated_tokens)
         sequences.append(tokens[:1, :num_generated_tokens[0]])
         scores.append(logits[:1, :num_generated_tokens[0]])
-        inference_params.sequence_len_offset += num_generated_tokens[0].item() - 1
-        inference_params_draft.sequence_len_offset = inference_params.sequence_len_offset
+        inference_params.seqlen_offset += num_generated_tokens[0].item() - 1
+        inference_params_draft.seqlen_offset = inference_params.seqlen_offset
         # breakpoint()
         if debug:
             cur_ids = torch.cat([cur_ids, sequences[-1]], dim=1)

@@ -679,9 +680,9 @@ def update_graph_cache(
         )
         lengths_per_sample = torch.full((batch_size,), seqlen_og, dtype=torch.int32, device=device)
         cache.inference_params = InferenceParams(
-            max_sequence_len=max_seqlen,
+            max_seqlen=max_seqlen,
             max_batch_size=batch_size,
-            sequence_len_offset=seqlen_og,
+            seqlen_offset=seqlen_og,
             key_value_memory_dict=inf_cache,
             lengths_per_sample=lengths_per_sample,
         )

@@ -705,7 +706,7 @@ def update_graph_cache(
     )
     cache.run = dispatch
-    cache.inference_params.sequence_len_offset = 0  # Reset so it's not confusing
+    cache.inference_params.seqlen_offset = 0  # Reset so it's not confusing
     return cache

@@ -713,10 +714,10 @@ def capture_graph(model, inference_params, batch_size, max_seqlen, mempool=None,
     device = next(iter(model.parameters())).device
     input_ids = torch.full((batch_size, 1), 0, dtype=torch.long, device=device)
     position_ids = torch.full((batch_size, 1), 0, dtype=torch.long, device=device)
-    sequence_len_offset_og = inference_params.sequence_len_offset
+    seqlen_offset_og = inference_params.seqlen_offset
     # TD [2023-04-14]: important for correctness of the FT's attention kernel, as seqlen_cpu is
     # used to determine the size of smem. Hence seqlen_cpu must be >= lengths_per_sample.
-    inference_params.sequence_len_offset = max_seqlen - 1
+    inference_params.seqlen_offset = max_seqlen - 1
     inference_params.lengths_per_sample[:] = max_seqlen - 1
     # Warmup before capture

@@ -755,5 +756,5 @@ def capture_graph(model, inference_params, batch_size, max_seqlen, mempool=None,
         graph.replay()
         return logits.clone()

-    inference_params.sequence_len_offset = sequence_len_offset_og
+    inference_params.seqlen_offset = seqlen_offset_og
     return run
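The capture_graph hunks keep the existing save/restore pattern around CUDA-graph capture, just under the new name: the real offset is stashed, seqlen_offset and lengths_per_sample are pinned to max_seqlen - 1 so the captured attention kernels size shared memory for the worst case, and the offset is restored afterwards. A simplified sketch of that pattern; warmup_and_capture is a placeholder for the actual warmup and torch.cuda.graph capture, not something from the diff:

    def capture_with_max_offset(inference_params, max_seqlen, warmup_and_capture):
        """Temporarily pin the decoding offset to the worst case while capturing."""
        seqlen_offset_og = inference_params.seqlen_offset
        inference_params.seqlen_offset = max_seqlen - 1
        if inference_params.lengths_per_sample is not None:
            inference_params.lengths_per_sample[:] = max_seqlen - 1
        try:
            run = warmup_and_capture()  # placeholder for warmup + graph capture
        finally:
            inference_params.seqlen_offset = seqlen_offset_og
        return run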
tests/models/test_gpt.py

@@ -364,14 +364,14 @@ def test_gpt2_multiple_token_generation(model_name, optimized):
     logits_ref = model(input_ids).logits
     # Run 10 tokens, then pass in another 4, then another 6, to see if we get the same logits
-    inference_params = InferenceParams(max_sequence_len=20, max_batch_size=1)
+    inference_params = InferenceParams(max_seqlen=20, max_batch_size=1)
     logits_10 = model(input_ids[:, :10], inference_params=inference_params).logits
-    inference_params.sequence_len_offset += 10
+    inference_params.seqlen_offset += 10
     position_ids = torch.arange(10, 14, dtype=torch.long, device=device)
     logits_1014 = model(
         input_ids[:, 10:14], position_ids=position_ids, inference_params=inference_params
     ).logits
-    inference_params.sequence_len_offset += 4
+    inference_params.seqlen_offset += 4
     position_ids = torch.arange(14, 20, dtype=torch.long, device=device)
     logits_1420 = model(
         input_ids[:, 14:20], position_ids=position_ids, inference_params=inference_params
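The test feeds a 20-token prompt in chunks of 10, 4, and 6, advancing seqlen_offset between calls, and checks the cached logits against a single full forward pass. A condensed sketch of the same pattern with the renamed fields (model stands for any flash-attn GPT-style model that accepts inference_params; this is not the test body itself):

    import torch
    from flash_attn.utils.generation import InferenceParams

    def chunked_forward(model, input_ids, chunks=(10, 4, 6), device="cuda"):
        """Feed input_ids in chunks, advancing seqlen_offset so the KV cache is reused."""
        params = InferenceParams(max_seqlen=input_ids.shape[1], max_batch_size=input_ids.shape[0])
        logits, start = [], 0
        for n in chunks:
            position_ids = torch.arange(start, start + n, dtype=torch.long, device=device)
            out = model(input_ids[:, start:start + n],
                        position_ids=position_ids,
                        inference_params=params).logits
            logits.append(out)
            params.seqlen_offset += n
            start += n
        # Should match model(input_ids).logits up to numerical tolerance.
        return torch.cat(logits, dim=1)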