Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
flash-attention
Commits
a86442f0
"src/lib/vscode:/vscode.git/clone" did not exist on "fcfdfa3dea007a9b1adce726ce83ab8b0050223b"
Commit
a86442f0
authored
Sep 07, 2023
by
Tri Dao
Browse files
[Gen] Use flash_attn_with_kvcache in generation
parent
a1576ad1
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
134 additions
and
36 deletions
+134
-36
flash_attn/layers/rotary.py
flash_attn/layers/rotary.py
+2
-1
flash_attn/modules/mha.py
flash_attn/modules/mha.py
+96
-16
flash_attn/utils/generation.py
flash_attn/utils/generation.py
+24
-9
tests/models/test_baichuan.py
tests/models/test_baichuan.py
+8
-8
tests/models/test_gpt.py
tests/models/test_gpt.py
+4
-2
No files found.
flash_attn/layers/rotary.py
View file @
a86442f0
...
@@ -146,7 +146,8 @@ class ApplyRotaryEmbQKV_(torch.autograd.Function):
...
@@ -146,7 +146,8 @@ class ApplyRotaryEmbQKV_(torch.autograd.Function):
# Call 1 kernel instead of 2 kernels
# Call 1 kernel instead of 2 kernels
# We need qkv to be contiguous so that when we reshape to combine (3, nheads)
# We need qkv to be contiguous so that when we reshape to combine (3, nheads)
# dimensions, we get the same tensor
# dimensions, we get the same tensor
qk
=
rearrange
(
qkv
[:,
:,
:
2
],
"b s t h d -> b s (t h) d"
)
# qk = rearrange(qkv[:, :, :2], "b s t h d -> b s (t h) d")
qk
=
qkv
[:,
:,
:
2
].
reshape
(
batch
,
seqlen
,
-
1
,
headdim
)
apply_rotary
(
apply_rotary
(
qk
,
cos
,
sin
,
seqlen_offsets
=
seqlen_offsets
,
interleaved
=
interleaved
,
inplace
=
True
qk
,
cos
,
sin
,
seqlen_offsets
=
seqlen_offsets
,
interleaved
=
interleaved
,
inplace
=
True
)
)
...
...
flash_attn/modules/mha.py
View file @
a86442f0
...
@@ -15,10 +15,12 @@ try:
...
@@ -15,10 +15,12 @@ try:
flash_attn_qkvpacked_func
,
flash_attn_qkvpacked_func
,
flash_attn_varlen_kvpacked_func
,
flash_attn_varlen_kvpacked_func
,
flash_attn_varlen_qkvpacked_func
,
flash_attn_varlen_qkvpacked_func
,
flash_attn_with_kvcache
,
)
)
except
ImportError
:
except
ImportError
:
flash_attn_varlen_qkvpacked_func
,
flash_attn_varlen_kvpacked_func
=
None
,
None
flash_attn_varlen_qkvpacked_func
,
flash_attn_varlen_kvpacked_func
=
None
,
None
flash_attn_qkvpacked_func
,
flash_attn_kvpacked_func
=
None
,
None
flash_attn_qkvpacked_func
,
flash_attn_kvpacked_func
=
None
,
None
flash_attn_with_kvcache
=
None
try
:
try
:
from
flash_attn.ops.fused_dense
import
ColumnParallelLinear
,
FusedDense
,
RowParallelLinear
from
flash_attn.ops.fused_dense
import
ColumnParallelLinear
,
FusedDense
,
RowParallelLinear
...
@@ -556,6 +558,35 @@ class MHA(nn.Module):
...
@@ -556,6 +558,35 @@ class MHA(nn.Module):
else
False
,
else
False
,
)
)
def _update_kvcache_attention(self, q, kv, inference_params):
    """Store ``kv`` in the cache held by ``inference_params`` and run attention.

    Two routes are possible:

    * Fallback: taken when ``inference_params.sequence_len_offset`` is 0,
      when ``flash_attn_with_kvcache`` is ``None``, or when
      ``self.use_flash_attn`` is false. The cache is updated through
      ``self._update_kv_cache`` and attention is computed by
      ``self.inner_cross_attn``.
    * Fused: otherwise a single ``flash_attn_with_kvcache`` call receives the
      cached tensors together with the new ``kv`` slices.

    Arguments:
        q: query tensor; its leading dimension is used as the batch size.
        kv: new key/value tensor; indices 0 and 1 along dim 2 are handed to
            ``flash_attn_with_kvcache`` as its 4th and 5th positional
            arguments (presumably key and value -- confirm against that API).
        inference_params: object exposing ``sequence_len_offset``,
            ``lengths_per_sample`` and ``key_value_memory_dict``.

    Returns:
        Whatever the selected attention routine returns.
    """
    fused_path_unavailable = (
        inference_params.sequence_len_offset == 0
        or flash_attn_with_kvcache is None
        or not self.use_flash_attn
    )
    if fused_path_unavailable:
        # TODO: this only uses sequence_len_offset and not lengths_per_sample.
        updated_kv = self._update_kv_cache(kv, inference_params)
        return self.inner_cross_attn(q, updated_kv)

    n_seqs = q.shape[0]
    # Restrict this layer's cache entry to the rows of the current batch.
    layer_cache = inference_params.key_value_memory_dict[self.layer_idx][:n_seqs]

    # Use per-sample lengths when present, else the shared scalar offset.
    if inference_params.lengths_per_sample is not None:
        seqlens = inference_params.lengths_per_sample[:n_seqs]
    else:
        seqlens = inference_params.sequence_len_offset

    attn = self.inner_cross_attn
    return flash_attn_with_kvcache(
        q,
        layer_cache[:, :, 0],
        layer_cache[:, :, 1],
        kv[:, :, 0],
        kv[:, :, 1],
        cache_seqlens=seqlens,
        softmax_scale=attn.softmax_scale,
        causal=attn.causal,
    )
def
forward
(
def
forward
(
self
,
self
,
x
,
x
,
...
@@ -605,10 +636,19 @@ class MHA(nn.Module):
...
@@ -605,10 +636,19 @@ class MHA(nn.Module):
if
self
.
use_flash_attn
if
self
.
use_flash_attn
else
{
"key_padding_mask"
:
key_padding_mask
,
**
kwargs
}
else
{
"key_padding_mask"
:
key_padding_mask
,
**
kwargs
}
)
)
seqlen_offset
=
0
if
inference_params
is
None
else
inference_params
.
sequence_len_offset
seqlen_offset
=
(
0
if
inference_params
is
None
else
(
inference_params
.
lengths_per_sample
if
inference_params
.
lengths_per_sample
is
not
None
else
inference_params
.
sequence_len_offset
)
)
rotary_max_seqlen
=
(
rotary_max_seqlen
=
(
inference_params
.
max_sequence_len
if
inference_params
is
not
None
else
None
inference_params
.
max_sequence_len
if
inference_params
is
not
None
else
None
)
)
batch
,
seqlen
=
x
.
shape
[:
2
]
if
not
self
.
cross_attn
and
self
.
num_heads_kv
==
self
.
num_heads
:
if
not
self
.
cross_attn
and
self
.
num_heads_kv
==
self
.
num_heads
:
assert
x_kv
is
None
and
mixer_subset
is
None
assert
x_kv
is
None
and
mixer_subset
is
None
if
not
self
.
return_residual
:
if
not
self
.
return_residual
:
...
@@ -619,7 +659,8 @@ class MHA(nn.Module):
...
@@ -619,7 +659,8 @@ class MHA(nn.Module):
qkv
=
rearrange
(
qkv
=
rearrange
(
self
.
dwconv_qkv
(
rearrange
(
qkv
,
"b s d -> b d s"
))[...,
:
-
2
],
"b d s -> b s d"
self
.
dwconv_qkv
(
rearrange
(
qkv
,
"b s d -> b d s"
))[...,
:
-
2
],
"b d s -> b s d"
).
contiguous
()
).
contiguous
()
qkv
=
rearrange
(
qkv
,
"... (three h d) -> ... three h d"
,
three
=
3
,
d
=
self
.
head_dim
)
# qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)
qkv
=
qkv
.
reshape
(
batch
,
seqlen
,
3
,
self
.
num_heads
,
self
.
head_dim
)
if
(
if
(
inference_params
is
None
inference_params
is
None
or
inference_params
.
sequence_len_offset
==
0
or
inference_params
.
sequence_len_offset
==
0
...
@@ -635,9 +676,9 @@ class MHA(nn.Module):
...
@@ -635,9 +676,9 @@ class MHA(nn.Module):
else
:
else
:
context
=
torch
.
utils
.
checkpoint
.
checkpoint
(
self
.
inner_attn
,
qkv
,
**
kwargs
)
context
=
torch
.
utils
.
checkpoint
.
checkpoint
(
self
.
inner_attn
,
qkv
,
**
kwargs
)
else
:
else
:
q
=
qkv
[:,
:,
0
]
context
=
self
.
_update_kvcache_attention
(
kv
=
self
.
_update_kv_cache
(
qkv
[:,
:,
1
:],
inference_params
)
qkv
[:,
:,
0
],
qkv
[:,
:,
1
:],
inference_params
context
=
self
.
inner_cross_attn
(
q
,
kv
)
)
else
:
else
:
context
=
self
.
_apply_rotary_single_query_attention
(
qkv
,
inference_params
)
context
=
self
.
_apply_rotary_single_query_attention
(
qkv
,
inference_params
)
else
:
else
:
...
@@ -659,8 +700,10 @@ class MHA(nn.Module):
...
@@ -659,8 +700,10 @@ class MHA(nn.Module):
qkv
,
x
=
self
.
Wqkv
(
x
)
qkv
,
x
=
self
.
Wqkv
(
x
)
q
=
qkv
[...,
:
self
.
num_heads
*
self
.
head_dim
]
q
=
qkv
[...,
:
self
.
num_heads
*
self
.
head_dim
]
kv
=
qkv
[...,
self
.
num_heads
*
self
.
head_dim
:]
kv
=
qkv
[...,
self
.
num_heads
*
self
.
head_dim
:]
q
=
rearrange
(
q
,
"... (h d) -> ... h d"
,
d
=
self
.
head_dim
)
# q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)
kv
=
rearrange
(
kv
,
"... (two hkv d) -> ... two hkv d"
,
two
=
2
,
d
=
self
.
head_dim
)
q
=
q
.
reshape
(
batch
,
seqlen
,
-
1
,
self
.
head_dim
)
# kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim)
kv
=
kv
.
reshape
(
batch
,
seqlen
,
2
,
-
1
,
self
.
head_dim
)
if
self
.
dwconv
:
if
self
.
dwconv
:
q
=
rearrange
(
q
=
rearrange
(
self
.
dwconv_q
(
rearrange
(
q
,
"b s d -> b d s"
))[...,
:
-
2
],
"b d s -> b s d"
self
.
dwconv_q
(
rearrange
(
q
,
"b s d -> b d s"
))[...,
:
-
2
],
"b d s -> b s d"
...
@@ -685,11 +728,11 @@ class MHA(nn.Module):
...
@@ -685,11 +728,11 @@ class MHA(nn.Module):
self
.
inner_cross_attn
,
q
,
kv
,
**
kwargs
self
.
inner_cross_attn
,
q
,
kv
,
**
kwargs
)
)
else
:
else
:
kv
=
self
.
_update_kv_cache
(
kv
,
inference_params
)
context
=
self
.
_update_kvcache_attention
(
q
,
kv
,
inference_params
)
context
=
self
.
inner_cross_attn
(
q
,
kv
)
else
:
else
:
context
=
self
.
_apply_rotary_single_query_attention
(
q
,
inference_params
,
kv
=
kv
)
context
=
self
.
_apply_rotary_single_query_attention
(
q
,
inference_params
,
kv
=
kv
)
out
=
self
.
out_proj
(
rearrange
(
context
,
"... h d -> ... (h d)"
))
# out = self.out_proj(rearrange(context, "... h d -> ... (h d)"))
out
=
self
.
out_proj
(
context
.
reshape
(
batch
,
seqlen
,
-
1
))
return
out
if
not
self
.
return_residual
else
(
out
,
x
)
return
out
if
not
self
.
return_residual
else
(
out
,
x
)
...
@@ -846,6 +889,36 @@ class ParallelMHA(nn.Module):
...
@@ -846,6 +889,36 @@ class ParallelMHA(nn.Module):
else
False
,
else
False
,
)
)
def _update_kvcache_attention(self, q, kv, inference_params):
    """Write ``kv`` into the inference cache, then compute attention for ``q``.

    When ``inference_params.sequence_len_offset`` is 0, when
    ``flash_attn_with_kvcache`` is ``None``, or when ``self.use_flash_attn``
    is false, the cache is updated via ``self._update_kv_cache`` and attention
    is delegated to ``self.inner_cross_attn``. In every other case the cached
    tensors and the new ``kv`` slices go to ``flash_attn_with_kvcache`` in a
    single call.

    Arguments:
        q: query tensor; ``q.shape[0]`` is taken as the batch size.
        kv: new key/value tensor, indexed at 0 and 1 along dim 2 for the
            fused call (presumably key and value -- confirm against the
            ``flash_attn_with_kvcache`` signature).
        inference_params: carries ``sequence_len_offset``,
            ``lengths_per_sample`` and ``key_value_memory_dict``.

    Returns:
        The output of whichever attention routine was used.
    """
    use_fused = not (
        inference_params.sequence_len_offset == 0
        or flash_attn_with_kvcache is None
        or not self.use_flash_attn
    )
    if not use_fused:
        # TODO: this only uses sequence_len_offset and not lengths_per_sample.
        merged_kv = self._update_kv_cache(kv, inference_params)
        return self.inner_cross_attn(q, merged_kv)

    bsz = q.shape[0]
    # Slice this layer's cache down to the current batch.
    cache = inference_params.key_value_memory_dict[self.layer_idx][:bsz]
    lengths = inference_params.lengths_per_sample
    # Fall back to the scalar offset when no per-sample lengths are tracked.
    cache_lens = lengths[:bsz] if lengths is not None else inference_params.sequence_len_offset
    return flash_attn_with_kvcache(
        q,
        cache[:, :, 0],
        cache[:, :, 1],
        kv[:, :, 0],
        kv[:, :, 1],
        cache_seqlens=cache_lens,
        softmax_scale=self.inner_cross_attn.softmax_scale,
        causal=self.inner_cross_attn.causal,
    )
def
forward
(
self
,
x
,
seqlen
=
None
,
inference_params
=
None
,
**
kwargs
):
def
forward
(
self
,
x
,
seqlen
=
None
,
inference_params
=
None
,
**
kwargs
):
"""
"""
Arguments:
Arguments:
...
@@ -857,7 +930,15 @@ class ParallelMHA(nn.Module):
...
@@ -857,7 +930,15 @@ class ParallelMHA(nn.Module):
qkv
=
self
.
Wqkv
(
x
)
qkv
=
self
.
Wqkv
(
x
)
if
seqlen
is
not
None
:
if
seqlen
is
not
None
:
qkv
=
rearrange
(
qkv
,
"(b s) ... -> b s ..."
,
s
=
seqlen
)
qkv
=
rearrange
(
qkv
,
"(b s) ... -> b s ..."
,
s
=
seqlen
)
seqlen_offset
=
0
if
inference_params
is
None
else
inference_params
.
sequence_len_offset
seqlen_offset
=
(
0
if
inference_params
is
None
else
(
inference_params
.
lengths_per_sample
if
inference_params
.
lengths_per_sample
is
not
None
else
inference_params
.
sequence_len_offset
)
)
rotary_max_seqlen
=
(
rotary_max_seqlen
=
(
inference_params
.
max_sequence_len
if
inference_params
is
not
None
else
None
inference_params
.
max_sequence_len
if
inference_params
is
not
None
else
None
)
)
...
@@ -878,9 +959,9 @@ class ParallelMHA(nn.Module):
...
@@ -878,9 +959,9 @@ class ParallelMHA(nn.Module):
else
:
else
:
context
=
torch
.
utils
.
checkpoint
.
checkpoint
(
self
.
inner_attn
,
qkv
,
**
kwargs
)
context
=
torch
.
utils
.
checkpoint
.
checkpoint
(
self
.
inner_attn
,
qkv
,
**
kwargs
)
else
:
else
:
q
=
qkv
[:,
:,
0
]
context
=
self
.
_update_kvcache_attention
(
kv
=
_update_kv_cache
(
qkv
[:,
:,
1
:],
inference_params
,
self
.
layer_idx
)
qkv
[:,
:,
0
],
qkv
[:,
:,
1
:],
inference_params
context
=
self
.
inner_cross_attn
(
q
,
kv
)
)
else
:
else
:
context
=
self
.
_apply_rotary_single_query_attention
(
qkv
,
inference_params
)
context
=
self
.
_apply_rotary_single_query_attention
(
qkv
,
inference_params
)
else
:
else
:
...
@@ -912,8 +993,7 @@ class ParallelMHA(nn.Module):
...
@@ -912,8 +993,7 @@ class ParallelMHA(nn.Module):
self
.
inner_cross_attn
,
q
,
kv
,
**
kwargs
self
.
inner_cross_attn
,
q
,
kv
,
**
kwargs
)
)
else
:
else
:
kv
=
self
.
_update_kv_cache
(
kv
,
inference_params
)
context
=
self
.
_update_kvcache_attention
(
q
,
kv
,
inference_params
)
context
=
self
.
inner_cross_attn
(
q
,
kv
)
else
:
else
:
context
=
self
.
_apply_rotary_single_query_attention
(
q
,
inference_params
,
kv
=
kv
)
context
=
self
.
_apply_rotary_single_query_attention
(
q
,
inference_params
,
kv
=
kv
)
context
=
rearrange
(
context
,
"b s h d -> b s (h d)"
)
context
=
rearrange
(
context
,
"b s h d -> b s (h d)"
)
...
...
flash_attn/utils/generation.py
View file @
a86442f0
...
@@ -118,7 +118,6 @@ def decode(
...
@@ -118,7 +118,6 @@ def decode(
batch_size
,
seqlen_og
=
input_ids
.
shape
batch_size
,
seqlen_og
=
input_ids
.
shape
teacher_output_len
=
teacher_outputs
.
shape
[
1
]
if
teacher_outputs
is
not
None
else
0
teacher_output_len
=
teacher_outputs
.
shape
[
1
]
if
teacher_outputs
is
not
None
else
0
if
cg
:
if
cg
:
assert
fused_ft_kernel
if
not
hasattr
(
model
,
"_decoding_cache"
):
if
not
hasattr
(
model
,
"_decoding_cache"
):
model
.
_decoding_cache
=
None
model
.
_decoding_cache
=
None
model
.
_decoding_cache
=
update_graph_cache
(
model
.
_decoding_cache
=
update_graph_cache
(
...
@@ -128,11 +127,13 @@ def decode(
...
@@ -128,11 +127,13 @@ def decode(
seqlen_og
,
seqlen_og
,
max_length
,
max_length
,
tensor_parallel
=
tensor_parallel
,
tensor_parallel
=
tensor_parallel
,
fused_ft_kernel
=
fused_ft_kernel
,
)
)
inference_params
=
model
.
_decoding_cache
.
inference_params
inference_params
=
model
.
_decoding_cache
.
inference_params
inference_params
.
max_sequence_len
=
max_length
inference_params
.
max_sequence_len
=
max_length
inference_params
.
max_batch_size
=
batch_size
inference_params
.
max_batch_size
=
batch_size
inference_params
.
sequence_len_offset
=
0
inference_params
.
sequence_len_offset
=
0
inference_params
.
lengths_per_sample
.
zero_
()
else
:
else
:
inference_params
=
InferenceParams
(
inference_params
=
InferenceParams
(
max_sequence_len
=
max_length
,
max_batch_size
=
batch_size
,
fused_ft_kernel
=
fused_ft_kernel
max_sequence_len
=
max_length
,
max_batch_size
=
batch_size
,
fused_ft_kernel
=
fused_ft_kernel
...
@@ -167,7 +168,8 @@ def decode(
...
@@ -167,7 +168,8 @@ def decode(
token
=
sample
(
logits
,
top_k
=
top_k
,
top_p
=
top_p
,
temperature
=
temperature
)
token
=
sample
(
logits
,
top_k
=
top_k
,
top_p
=
top_p
,
temperature
=
temperature
)
else
:
else
:
token
=
teacher_outputs
[:,
inference_params
.
sequence_len_offset
]
token
=
teacher_outputs
[:,
inference_params
.
sequence_len_offset
]
return
rearrange
(
token
,
"b -> b 1"
)
# return rearrange(token, "b -> b 1")
return
token
.
unsqueeze
(
1
)
def
should_stop
(
current_token
,
inference_params
):
def
should_stop
(
current_token
,
inference_params
):
if
inference_params
.
sequence_len_offset
==
0
:
if
inference_params
.
sequence_len_offset
==
0
:
...
@@ -197,9 +199,7 @@ def decode(
...
@@ -197,9 +199,7 @@ def decode(
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
print
(
f
"Prompt processing + decoding time:
{
(
start
.
elapsed_time
(
end
)):.
0
f
}
ms"
)
print
(
f
"Prompt processing + decoding time:
{
(
start
.
elapsed_time
(
end
)):.
0
f
}
ms"
)
output_cls
=
GreedySearchDecoderOnlyOutput
if
top_k
==
1
else
SampleDecoderOnlyOutput
output_cls
=
GreedySearchDecoderOnlyOutput
if
top_k
==
1
else
SampleDecoderOnlyOutput
return
output_cls
(
return
output_cls
(
sequences
=
torch
.
cat
(
sequences
,
dim
=
1
),
scores
=
tuple
(
scores
))
sequences
=
torch
.
cat
(
sequences
,
dim
=
1
),
scores
=
tuple
(
scores
)
)
def
sample_speculative
(
logits
,
logits_draft
,
tokens_draft
,
top_k
=
1
,
top_p
=
0.0
,
temperature
=
1.0
):
def
sample_speculative
(
logits
,
logits_draft
,
tokens_draft
,
top_k
=
1
,
top_p
=
0.0
,
temperature
=
1.0
):
...
@@ -298,7 +298,6 @@ def decode_speculative(
...
@@ -298,7 +298,6 @@ def decode_speculative(
assert
batch_size
==
1
,
"Speculative decoding implementation only supports batch_size=1"
assert
batch_size
==
1
,
"Speculative decoding implementation only supports batch_size=1"
assert
eos_token_id
is
None
,
"Speculative decoding implementation doesn't support eos_token_id"
assert
eos_token_id
is
None
,
"Speculative decoding implementation doesn't support eos_token_id"
if
cg
:
if
cg
:
assert
fused_ft_kernel
if
not
hasattr
(
model_draft
,
"_decoding_cache"
):
if
not
hasattr
(
model_draft
,
"_decoding_cache"
):
model_draft
.
_decoding_cache
=
None
model_draft
.
_decoding_cache
=
None
model_draft
.
_decoding_cache
=
update_graph_cache
(
model_draft
.
_decoding_cache
=
update_graph_cache
(
...
@@ -308,6 +307,7 @@ def decode_speculative(
...
@@ -308,6 +307,7 @@ def decode_speculative(
seqlen_og
,
seqlen_og
,
max_length
,
max_length
,
tensor_parallel
=
tensor_parallel
,
tensor_parallel
=
tensor_parallel
,
fused_ft_kernel
=
fused_ft_kernel
,
)
)
inference_params_draft
=
model_draft
.
_decoding_cache
.
inference_params
inference_params_draft
=
model_draft
.
_decoding_cache
.
inference_params
inference_params_draft
.
max_sequence_len
=
max_length
inference_params_draft
.
max_sequence_len
=
max_length
...
@@ -606,12 +606,14 @@ def allocate_inference_cache(
...
@@ -606,12 +606,14 @@ def allocate_inference_cache(
layers
:
Union
[
int
,
Sequence
],
layers
:
Union
[
int
,
Sequence
],
device
,
device
,
dtype
=
torch
.
float16
,
dtype
=
torch
.
float16
,
fused_ft_kernel
=
False
,
):
):
assert
dtype
in
[
torch
.
float16
,
torch
.
bfloat16
,
torch
.
float32
]
assert
dtype
in
[
torch
.
float16
,
torch
.
bfloat16
,
torch
.
float32
]
packsize
=
4
if
dtype
==
torch
.
float32
else
8
packsize
=
4
if
dtype
==
torch
.
float32
else
8
assert
headdim
%
packsize
==
0
assert
headdim
%
packsize
==
0
k_cache_shape
=
(
max_batch_size
,
nheads
,
headdim
//
packsize
,
max_seqlen
,
packsize
)
k_cache_shape
=
(
max_batch_size
,
nheads
,
headdim
//
packsize
,
max_seqlen
,
packsize
)
v_cache_shape
=
(
max_batch_size
,
nheads
,
max_seqlen
,
headdim
)
v_cache_shape
=
(
max_batch_size
,
nheads
,
max_seqlen
,
headdim
)
kv_cache_shape
=
(
max_batch_size
,
max_seqlen
,
2
,
nheads
,
headdim
)
if
isinstance
(
layers
,
int
):
if
isinstance
(
layers
,
int
):
layers
=
range
(
layers
)
layers
=
range
(
layers
)
return
{
return
{
...
@@ -619,6 +621,8 @@ def allocate_inference_cache(
...
@@ -619,6 +621,8 @@ def allocate_inference_cache(
torch
.
empty
(
k_cache_shape
,
device
=
device
,
dtype
=
dtype
),
torch
.
empty
(
k_cache_shape
,
device
=
device
,
dtype
=
dtype
),
torch
.
empty
(
v_cache_shape
,
device
=
device
,
dtype
=
dtype
),
torch
.
empty
(
v_cache_shape
,
device
=
device
,
dtype
=
dtype
),
)
)
if
fused_ft_kernel
else
torch
.
empty
(
kv_cache_sahpe
,
device
=
device
,
dtype
=
dtype
)
for
i
in
layers
for
i
in
layers
}
}
...
@@ -651,7 +655,15 @@ class DecodingCGCache:
...
@@ -651,7 +655,15 @@ class DecodingCGCache:
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
update_graph_cache
(
def
update_graph_cache
(
model
,
cache
,
batch_size
,
seqlen_og
,
max_seqlen
,
tensor_parallel
=
1
,
dtype
=
None
,
n_warmups
=
2
model
,
cache
,
batch_size
,
seqlen_og
,
max_seqlen
,
tensor_parallel
=
1
,
dtype
=
None
,
n_warmups
=
2
,
fused_ft_kernel
=
False
,
):
):
if
cache
is
None
:
if
cache
is
None
:
cache
=
DecodingCGCache
()
cache
=
DecodingCGCache
()
...
@@ -671,7 +683,9 @@ def update_graph_cache(
...
@@ -671,7 +683,9 @@ def update_graph_cache(
cache
.
device
,
cache
.
dtype
=
device
,
dtype
cache
.
device
,
cache
.
dtype
=
device
,
dtype
cache
.
max_batch_size
,
cache
.
max_seqlen
=
batch_size
,
max_seqlen
cache
.
max_batch_size
,
cache
.
max_seqlen
=
batch_size
,
max_seqlen
if
hasattr
(
model
,
"allocate_inference_cache"
):
if
hasattr
(
model
,
"allocate_inference_cache"
):
inf_cache
=
model
.
allocate_inference_cache
(
batch_size
,
max_seqlen
,
dtype
)
inf_cache
=
model
.
allocate_inference_cache
(
batch_size
,
max_seqlen
,
dtype
,
fused_ft_kernel
=
fused_ft_kernel
)
else
:
else
:
headdim
=
getattr
(
headdim
=
getattr
(
model
.
config
,
model
.
config
,
...
@@ -686,6 +700,7 @@ def update_graph_cache(
...
@@ -686,6 +700,7 @@ def update_graph_cache(
model
.
config
.
num_hidden_layers
,
model
.
config
.
num_hidden_layers
,
device
,
device
,
dtype
,
dtype
,
fused_ft_kernel
=
fused_ft_kernel
,
)
)
lengths_per_sample
=
torch
.
full
((
batch_size
,),
seqlen_og
,
dtype
=
torch
.
int32
,
device
=
device
)
lengths_per_sample
=
torch
.
full
((
batch_size
,),
seqlen_og
,
dtype
=
torch
.
int32
,
device
=
device
)
cache
.
inference_params
=
InferenceParams
(
cache
.
inference_params
=
InferenceParams
(
...
@@ -693,7 +708,7 @@ def update_graph_cache(
...
@@ -693,7 +708,7 @@ def update_graph_cache(
max_batch_size
=
batch_size
,
max_batch_size
=
batch_size
,
sequence_len_offset
=
seqlen_og
,
sequence_len_offset
=
seqlen_og
,
key_value_memory_dict
=
inf_cache
,
key_value_memory_dict
=
inf_cache
,
fused_ft_kernel
=
True
,
fused_ft_kernel
=
fused_ft_kernel
,
lengths_per_sample
=
lengths_per_sample
,
lengths_per_sample
=
lengths_per_sample
,
)
)
cache
.
mempool
=
torch
.
cuda
.
graphs
.
graph_pool_handle
()
cache
.
mempool
=
torch
.
cuda
.
graphs
.
graph_pool_handle
()
...
...
tests/models/test_baichuan.py
View file @
a86442f0
...
@@ -217,8 +217,9 @@ def test_baichuan_parallel_forward(model_name, world_size):
...
@@ -217,8 +217,9 @@ def test_baichuan_parallel_forward(model_name, world_size):
).
abs
().
max
().
item
()
).
abs
().
max
().
item
()
@
pytest
.
mark
.
parametrize
(
"fused_ft_kernel"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"model_name"
,
[
"baichuan-inc/Baichuan-7B"
])
@
pytest
.
mark
.
parametrize
(
"model_name"
,
[
"baichuan-inc/Baichuan-7B"
])
def
test_baichuan_generation
(
model_name
):
def
test_baichuan_generation
(
model_name
,
fused_ft_kernel
):
dtype
=
torch
.
float16
dtype
=
torch
.
float16
device
=
"cuda"
device
=
"cuda"
config
=
baichuan_config_to_gpt2_config
(
config
=
baichuan_config_to_gpt2_config
(
...
@@ -276,6 +277,7 @@ def test_baichuan_generation(model_name):
...
@@ -276,6 +277,7 @@ def test_baichuan_generation(model_name):
model
.
load_state_dict
(
pretrained_state_dict
)
model
.
load_state_dict
(
pretrained_state_dict
)
model
.
eval
()
model
.
eval
()
model
(
input_ids
)
# Warm up
print
(
"Without CUDA graph"
)
print
(
"Without CUDA graph"
)
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
start
=
time
.
time
()
start
=
time
.
time
()
...
@@ -283,7 +285,7 @@ def test_baichuan_generation(model_name):
...
@@ -283,7 +285,7 @@ def test_baichuan_generation(model_name):
input_ids
=
input_ids
,
input_ids
=
input_ids
,
max_length
=
max_length
,
max_length
=
max_length
,
eos_token_id
=
eos_token_id
,
eos_token_id
=
eos_token_id
,
fused_ft_kernel
=
True
,
fused_ft_kernel
=
fused_ft_kernel
,
return_dict_in_generate
=
True
,
return_dict_in_generate
=
True
,
output_scores
=
True
,
output_scores
=
True
,
enable_timing
=
True
,
enable_timing
=
True
,
...
@@ -295,7 +297,7 @@ def test_baichuan_generation(model_name):
...
@@ -295,7 +297,7 @@ def test_baichuan_generation(model_name):
# Capture graph outside the timing loop
# Capture graph outside the timing loop
batch_size
,
seqlen_og
=
input_ids
.
shape
batch_size
,
seqlen_og
=
input_ids
.
shape
model
.
_decoding_cache
=
update_graph_cache
(
model
.
_decoding_cache
=
update_graph_cache
(
model
,
None
,
batch_size
,
seqlen_og
,
max_length
model
,
None
,
batch_size
,
seqlen_og
,
max_length
,
fused_ft_kernel
=
fused_ft_kernel
)
)
print
(
"With CUDA graph"
)
print
(
"With CUDA graph"
)
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
...
@@ -303,7 +305,7 @@ def test_baichuan_generation(model_name):
...
@@ -303,7 +305,7 @@ def test_baichuan_generation(model_name):
out_cg
=
model
.
generate
(
out_cg
=
model
.
generate
(
input_ids
=
input_ids
,
input_ids
=
input_ids
,
max_length
=
max_length
,
max_length
=
max_length
,
fused_ft_kernel
=
True
,
fused_ft_kernel
=
fused_ft_kernel
,
cg
=
True
,
cg
=
True
,
return_dict_in_generate
=
True
,
return_dict_in_generate
=
True
,
output_scores
=
True
,
output_scores
=
True
,
...
@@ -346,7 +348,7 @@ def test_baichuan_parallel_generation(model_name, world_size):
...
@@ -346,7 +348,7 @@ def test_baichuan_parallel_generation(model_name, world_size):
config
=
baichuan_config_to_gpt2_config
(
config
=
baichuan_config_to_gpt2_config
(
AutoConfig
.
from_pretrained
(
model_name
,
trust_remote_code
=
True
)
AutoConfig
.
from_pretrained
(
model_name
,
trust_remote_code
=
True
)
)
)
config
.
use_flash_attn
=
Fals
e
config
.
use_flash_attn
=
Tru
e
config
.
fused_bias_fc
=
True
config
.
fused_bias_fc
=
True
config
.
fused_mlp
=
False
# We don't have fused GatedMLP yet
config
.
fused_mlp
=
False
# We don't have fused GatedMLP yet
config
.
fused_dropout_add_ln
=
False
config
.
fused_dropout_add_ln
=
False
...
@@ -393,7 +395,6 @@ def test_baichuan_parallel_generation(model_name, world_size):
...
@@ -393,7 +395,6 @@ def test_baichuan_parallel_generation(model_name, world_size):
max_length
=
max_length
,
max_length
=
max_length
,
tensor_parallel
=
world_size
,
tensor_parallel
=
world_size
,
vocab_size
=
config
.
vocab_size
,
vocab_size
=
config
.
vocab_size
,
fused_ft_kernel
=
True
,
# teacher_outputs=out_hf.sequences,
# teacher_outputs=out_hf.sequences,
return_dict_in_generate
=
True
,
return_dict_in_generate
=
True
,
output_scores
=
True
,
output_scores
=
True
,
...
@@ -411,7 +412,6 @@ def test_baichuan_parallel_generation(model_name, world_size):
...
@@ -411,7 +412,6 @@ def test_baichuan_parallel_generation(model_name, world_size):
max_length
=
max_length
,
max_length
=
max_length
,
tensor_parallel
=
world_size
,
tensor_parallel
=
world_size
,
vocab_size
=
config
.
vocab_size
,
vocab_size
=
config
.
vocab_size
,
fused_ft_kernel
=
True
,
cg
=
True
,
cg
=
True
,
# teacher_outputs=out_hf.sequences,
# teacher_outputs=out_hf.sequences,
return_dict_in_generate
=
True
,
return_dict_in_generate
=
True
,
...
@@ -458,6 +458,6 @@ def test_baichuan_parallel_generation(model_name, world_size):
...
@@ -458,6 +458,6 @@ def test_baichuan_parallel_generation(model_name, world_size):
hf_error
=
(
logits_hf
-
logits_ref
).
abs
().
max
().
item
()
hf_error
=
(
logits_hf
-
logits_ref
).
abs
().
max
().
item
()
print
(
f
"HF fp16 logits max diff:
{
hf_error
}
"
)
print
(
f
"HF fp16 logits max diff:
{
hf_error
}
"
)
print
(
f
"Logits max diff:
{
(
logits
-
logits_ref
).
abs
().
max
().
item
()
}
"
)
print
(
f
"Logits max diff:
{
(
logits
-
logits_ref
).
abs
().
max
().
item
()
}
"
)
assert
(
logits
-
logits_ref
).
abs
().
max
().
item
()
<
2
*
hf_error
print
(
f
"Logits CG max diff:
{
(
logits_cg
-
logits_ref
).
abs
().
max
().
item
()
}
"
)
print
(
f
"Logits CG max diff:
{
(
logits_cg
-
logits_ref
).
abs
().
max
().
item
()
}
"
)
assert
(
logits
-
logits_ref
).
abs
().
max
().
item
()
<
2
*
hf_error
assert
torch
.
equal
(
logits_cg
,
logits
)
assert
torch
.
equal
(
logits_cg
,
logits
)
tests/models/test_gpt.py
View file @
a86442f0
...
@@ -135,7 +135,7 @@ def test_gpt2_optimized(model_name):
...
@@ -135,7 +135,7 @@ def test_gpt2_optimized(model_name):
@
pytest
.
mark
.
parametrize
(
"fused_ft_kernel"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"fused_ft_kernel"
,
[
False
,
True
])
# @pytest.mark.parametrize('fused_ft_kernel', [
Fals
e])
# @pytest.mark.parametrize('fused_ft_kernel', [
Tru
e])
@
pytest
.
mark
.
parametrize
(
"optimized"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"optimized"
,
[
False
,
True
])
# @pytest.mark.parametrize('optimized', [True])
# @pytest.mark.parametrize('optimized', [True])
@
pytest
.
mark
.
parametrize
(
"rotary"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"rotary"
,
[
False
,
True
])
...
@@ -209,7 +209,7 @@ def test_gpt2_generation(model_name, rotary, optimized, fused_ft_kernel):
...
@@ -209,7 +209,7 @@ def test_gpt2_generation(model_name, rotary, optimized, fused_ft_kernel):
)
)
print
(
out
.
sequences
)
print
(
out
.
sequences
)
print
(
tokenizer
.
batch_decode
(
out
.
sequences
.
tolist
()))
print
(
tokenizer
.
batch_decode
(
out
.
sequences
.
tolist
()))
if
fused_ft_kernel
or
config
.
use_flash_attn
:
if
fused_ft_kernel
or
getattr
(
config
,
"
use_flash_attn
"
,
False
)
:
out_cg
=
model
.
generate
(
out_cg
=
model
.
generate
(
input_ids
=
input_ids
,
input_ids
=
input_ids
,
max_length
=
max_length
,
max_length
=
max_length
,
...
@@ -220,6 +220,7 @@ def test_gpt2_generation(model_name, rotary, optimized, fused_ft_kernel):
...
@@ -220,6 +220,7 @@ def test_gpt2_generation(model_name, rotary, optimized, fused_ft_kernel):
enable_timing
=
True
,
enable_timing
=
True
,
)
)
print
(
out_cg
.
sequences
)
print
(
out_cg
.
sequences
)
assert
torch
.
equal
(
torch
.
stack
(
out
.
scores
,
dim
=
1
),
torch
.
stack
(
out_cg
.
scores
,
dim
=
1
))
if
not
rotary
:
if
not
rotary
:
out_hf
=
model_hf
.
generate
(
out_hf
=
model_hf
.
generate
(
...
@@ -282,6 +283,7 @@ def get_logits(model, input_ids, max_length, teacher_outputs=None, **kwargs):
...
@@ -282,6 +283,7 @@ def get_logits(model, input_ids, max_length, teacher_outputs=None, **kwargs):
@
pytest
.
mark
.
parametrize
(
"rotary"
,
[
None
,
"interleaved"
,
"block"
])
@
pytest
.
mark
.
parametrize
(
"rotary"
,
[
None
,
"interleaved"
,
"block"
])
# @pytest.mark.parametrize('rotary', [None])
# @pytest.mark.parametrize('rotary', [None])
@
pytest
.
mark
.
parametrize
(
"fused_ft_kernel"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"fused_ft_kernel"
,
[
False
,
True
])
# @pytest.mark.parametrize("fused_ft_kernel", [False])
@
pytest
.
mark
.
parametrize
(
"model_name"
,
[
"gpt2"
])
@
pytest
.
mark
.
parametrize
(
"model_name"
,
[
"gpt2"
])
def
test_gpt2_generation_cg
(
model_name
,
fused_ft_kernel
,
rotary
,
seqlen
,
maxlen
):
def
test_gpt2_generation_cg
(
model_name
,
fused_ft_kernel
,
rotary
,
seqlen
,
maxlen
):
"""Check that decoding with CUDA graph is the same as decoding without CUDA graph."""
"""Check that decoding with CUDA graph is the same as decoding without CUDA graph."""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment