gaoqiong / flash-attention · Commits

Commit dfe29f5e
Authored Sep 18, 2023 by Tri Dao

[Gen] Don't use ft_attention, use flash_attn_with_kvcache instead

Parent: 3250ff3d

Showing 12 changed files with 154 additions and 313 deletions (+154, -313)
csrc/ft_attention/README.md                    +6    -0
flash_attn/modules/mha.py                      +111  -192
flash_attn/utils/generation.py                 +9    -35
setup.py                                       +4    -4
tests/models/test_baichuan.py                  +5    -12
tests/models/test_bigcode.py                   +0    -2
tests/models/test_falcon.py                    +2    -10
tests/models/test_gpt.py                       +7    -26
tests/models/test_gpt_generation_parallel.py   +4    -8
tests/models/test_gptj.py                      +2    -7
tests/models/test_llama.py                     +2    -10
tests/models/test_opt.py                       +2    -7
csrc/ft_attention/README.md

````diff
@@ -6,3 +6,9 @@ FasterTransformer v5.2.1 for benchmarking purpose.
 ```sh
 cd csrc/ft_attention && pip install .
 ```
+
+As of 2023-09-17, this extension is no longer used in the FlashAttention repo.
+FlashAttention now has implemented
+[`flash_attn_with_kvcache`](https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/flash_attention_interface.py)
+with all the features of this `ft_attention` kernel (and more).
````
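For orientation (not part of the commit): a minimal sketch of the replacement API the README now points to. The `flash_attn_with_kvcache` call below mirrors the one introduced in `flash_attn/modules/mha.py` in this commit; the toy shapes, the number of tokens already cached, and the top-level import path are assumptions for illustration only.

```python
import torch
from flash_attn import flash_attn_with_kvcache  # assumes the top-level package re-exports it

# Assumed decoding step: fp16 tensors on CUDA, one new token per sequence, and a
# pre-allocated kv_cache of shape (batch, max_seqlen, 2, nheads_kv, headdim) --
# the layout MHA.allocate_inference_cache returns after this commit.
batch, nheads, headdim, max_seqlen = 2, 16, 64, 256
q = torch.randn(batch, 1, nheads, headdim, dtype=torch.float16, device="cuda")
kv = torch.randn(batch, 1, 2, nheads, headdim, dtype=torch.float16, device="cuda")
kv_cache = torch.zeros(batch, max_seqlen, 2, nheads, headdim, dtype=torch.float16, device="cuda")
cache_seqlens = torch.full((batch,), 128, dtype=torch.int32, device="cuda")  # tokens already cached

# The kernel appends the new k/v at position cache_seqlens, attends over the cache,
# and can optionally apply rotary embeddings in the same kernel (rotary_cos/rotary_sin).
out = flash_attn_with_kvcache(
    q,
    kv_cache[:, :, 0],  # k_cache view
    kv_cache[:, :, 1],  # v_cache view
    kv[:, :, 0],        # new k
    kv[:, :, 1],        # new v
    cache_seqlens=cache_seqlens,
    causal=True,
)
print(out.shape)  # (batch, 1, nheads, headdim)
```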
flash_attn/modules/mha.py

```diff
@@ -32,11 +32,6 @@ try:
 except ImportError:
     RotaryEmbedding = None
 
-try:
-    import ft_attention
-except ImportError:
-    ft_attention = None
-
 
 class FlashSelfAttention(nn.Module):
     """Implement the scaled dot product attention with softmax.
@@ -314,14 +309,7 @@ def _update_kv_cache(kv, inference_params, layer_idx):
         )
         inference_params.key_value_memory_dict[layer_idx] = kv_cache
     else:
-        if not inference_params.fused_ft_kernel:
-            kv_cache = inference_params.key_value_memory_dict[layer_idx]
-        else:
-            # For FT, k_cache has shape (b, h, headdim / packsize, s, packsize)
-            # where packsize = 4 if fp32, 8 if fp16 or bf16.
-            # v_cache has shape (b, h, s, headdim)
-            k_cache, v_cache = inference_params.key_value_memory_dict[layer_idx]
-            kv_cache = None
+        kv_cache = inference_params.key_value_memory_dict[layer_idx]
     # Adjust key and value for inference
     batch_start = inference_params.batch_size_offset
     batch_end = batch_start + kv.shape[0]
@@ -329,79 +317,9 @@ def _update_kv_cache(kv, inference_params, layer_idx):
     sequence_end = sequence_start + kv.shape[1]
     assert batch_end <= (kv_cache.shape[0] if kv_cache is not None else v_cache.shape[0])
     assert sequence_end <= (kv_cache.shape[1] if kv_cache is not None else v_cache.shape[2])
     # Copy key and values.
-    if not inference_params.fused_ft_kernel:
-        assert kv_cache is not None
-        kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
-        kv = kv_cache[batch_start:batch_end, :sequence_end, ...]
-        return kv
-    else:
-        assert inference_params.sequence_len_offset == 0
-        # FT kernel requires different layouts for the k_cache and v_cache.
-        assert kv.dtype in [torch.float16, torch.bfloat16, torch.float32]
-        packsize = 4 if kv.dtype == torch.float32 else 8
-        if kv_cache is not None:
-            kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
-            k_cache = rearrange(
-                kv_cache[:, :, 0], "b s h (d packsize) -> b h d s packsize", packsize=packsize
-            ).contiguous()
-            v_cache = rearrange(kv_cache[:, :, 1], "b s h d -> b h s d").contiguous()
-            inference_params.key_value_memory_dict[layer_idx] = (k_cache, v_cache)
-        else:
-            k_cache[batch_start:batch_end, :, :, :sequence_end, :] = rearrange(
-                kv[:, :, 0], "b s h (d packsize) -> b h d s packsize", packsize=packsize
-            )
-            v_cache[batch_start:batch_end, :, :sequence_end, :] = rearrange(
-                kv[:, :, 1], "b s h d -> b h s d"
-            )
-        return kv
-
-
-def _apply_rotary_single_query_attention(
-    qkv,
-    inference_params,
-    layer_idx,
-    rotary_emb_dim,
-    rotary_emb_base,
-    kv=None,
-    rotary_emb_interleaved=False,
-):
-    """
-    qkv: (batch_size, 1, 3, nheads, head_dim) if kv is None else it's just
-        q of shape (batch_size, 1, nheads, head_dim)
-    kv: (batch_size, 1, 2, nheads_kv, head_dim)
-    """
-    assert inference_params.fused_ft_kernel
-    assert ft_attention is not None
-    if kv is None:
-        q, k, v = rearrange(qkv, "b 1 three h d -> b three h d").unbind(dim=1)
-    else:
-        q = rearrange(qkv, "b 1 h d -> b h d")
-        k, v = rearrange(kv, "b 1 two h d -> b two h d").unbind(dim=1)
-    batch_start = inference_params.batch_size_offset
-    batch_end = batch_start + q.shape[0]
-    k_cache, v_cache = inference_params.key_value_memory_dict[layer_idx]
-    lengths_per_sample = (
-        inference_params.lengths_per_sample[batch_start:batch_end]
-        if inference_params.lengths_per_sample is not None
-        else None
-    )
-    context = ft_attention.single_query_attention(
-        q,
-        k,
-        v,
-        k_cache[batch_start:batch_end],
-        v_cache[batch_start:batch_end],
-        lengths_per_sample,
-        None,  # rotary_cos_
-        None,  # rotary_sin_
-        None,  # nnz_head_idx
-        inference_params.sequence_len_offset,
-        rotary_emb_dim,
-        rotary_emb_base,
-        not rotary_emb_interleaved,  # neox_rotary_style
-    )
-    return rearrange(context, "b h d -> b 1 h d")
+    assert kv_cache is not None
+    kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
+    return kv_cache[batch_start:batch_end, :sequence_end, ...]
 
 
 class MHA(nn.Module):
@@ -502,36 +420,18 @@ class MHA(nn.Module):
         )
         self.out_proj = linear_cls(embed_dim, embed_dim, bias=out_proj_bias, **factory_kwargs)
 
-    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, fused_ft_kernel=True):
+    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
         dtype = self.out_proj.weight.dtype if dtype is None else dtype
         device = self.out_proj.weight.device
-        if not fused_ft_kernel:
-            return torch.empty(
-                batch_size,
-                max_seqlen,
-                2,
-                self.num_heads_kv,
-                self.head_dim,
-                dtype=dtype,
-                device=device,
-            )
-        else:
-            assert dtype in [torch.float16, torch.bfloat16, torch.float32]
-            packsize = 4 if dtype == torch.float32 else 8
-            assert self.head_dim % packsize == 0
-            k_cache = torch.empty(
-                batch_size,
-                self.num_heads_kv,
-                self.head_dim // packsize,
-                max_seqlen,
-                packsize,
-                dtype=dtype,
-                device=device,
-            )
-            v_cache = torch.empty(
-                batch_size, self.num_heads_kv, max_seqlen, self.head_dim, dtype=dtype, device=device
-            )
-            return k_cache, v_cache
+        return torch.empty(
+            batch_size,
+            max_seqlen,
+            2,
+            self.num_heads_kv,
+            self.head_dim,
+            dtype=dtype,
+            device=device,
+        )
 
     def _update_kv_cache(self, kv, inference_params):
         """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
@@ -539,27 +439,46 @@ class MHA(nn.Module):
         assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
         return _update_kv_cache(kv, inference_params, self.layer_idx)
 
-    def _apply_rotary_single_query_attention(self, qkv, inference_params, kv=None):
+    def _apply_rotary_update_kvcache_attention(self, q, kv, inference_params):
         """
-        qkv: (batch_size, 1, 3, nheads, head_dim) if kv is None else it's just
-            q of shape (batch_size, 1, nheads, head_dim)
-        kv: (batch_size, 1, 2, nheads_kv, head_dim)
+        Fast path that combine 3 steps: apply rotary to Q and K, update kv cache, and apply attention.
+        q: (batch_size, seqlen_q, nheads, head_dim)
+        kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim)
         """
-        rotary_emb_base = self.rotary_emb.base if self.rotary_emb_dim > 0 else 0
-        return _apply_rotary_single_query_attention(
-            qkv,
-            inference_params,
-            self.layer_idx,
-            self.rotary_emb_dim,
-            rotary_emb_base,
-            kv=kv,
-            rotary_emb_interleaved=self.rotary_emb.interleaved
-            if self.rotary_emb_dim > 0
-            else False,
-        )
+        assert inference_params is not None and inference_params.sequence_len_offset > 0
+        assert self.use_flash_attn
+        if self.rotary_emb_dim > 0:
+            assert self.rotary_emb.scale is None, "This code path does not support xPos"
+            self.rotary_emb._update_cos_sin_cache(
+                inference_params.max_sequence_len, device=q.device, dtype=q.dtype
+            )
+            rotary_cos, rotary_sin = self.rotary_emb._cos_cached, self.rotary_emb._sin_cached
+        else:
+            rotary_cos, rotary_sin = None, None
+        batch = q.shape[0]
+        kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
+        cache_seqlens = (
+            inference_params.lengths_per_sample[:batch]
+            if inference_params.lengths_per_sample is not None
+            else inference_params.sequence_len_offset
+        )
+        context = flash_attn_with_kvcache(
+            q,
+            kv_cache[:, :, 0],
+            kv_cache[:, :, 1],
+            kv[:, :, 0],
+            kv[:, :, 1],
+            rotary_cos=rotary_cos,
+            rotary_sin=rotary_sin,
+            cache_seqlens=cache_seqlens,
+            softmax_scale=self.inner_cross_attn.softmax_scale,
+            causal=self.inner_cross_attn.causal,
+            rotary_interleaved=self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False,
+        )
+        return context
 
     def _update_kvcache_attention(self, q, kv, inference_params):
-        """Write kv to inference_params, then do attention
-        """
+        """Write kv to inference_params, then do attention"""
         if (
             inference_params.sequence_len_offset == 0
             or flash_attn_with_kvcache is None
@@ -663,7 +582,8 @@ class MHA(nn.Module):
             if (
                 inference_params is None
                 or inference_params.sequence_len_offset == 0
-                or not inference_params.fused_ft_kernel
+                or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
+                or not self.use_flash_attn
             ):
                 if self.rotary_emb_dim > 0:
                     qkv = self.rotary_emb(
@@ -679,7 +599,9 @@ class MHA(nn.Module):
                         qkv[:, :, 0], qkv[:, :, 1:], inference_params
                     )
             else:
-                context = self._apply_rotary_single_query_attention(qkv, inference_params)
+                context = self._apply_rotary_update_kvcache_attention(
+                    qkv[:, :, 0], qkv[:, :, 1:], inference_params
+                )
         else:
             if self.cross_attn:
                 if not self.return_residual:
@@ -711,7 +633,8 @@ class MHA(nn.Module):
             if (
                 inference_params is None
                 or inference_params.sequence_len_offset == 0
-                or not inference_params.fused_ft_kernel
+                or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
+                or not self.use_flash_attn
             ):
                 if self.rotary_emb_dim > 0:
                     q, kv = self.rotary_emb(
@@ -727,7 +650,7 @@ class MHA(nn.Module):
                 else:
                     context = self._update_kvcache_attention(q, kv, inference_params)
             else:
-                context = self._apply_rotary_single_query_attention(q, inference_params, kv=kv)
+                context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)
         out = self.out_proj(rearrange(context, "... h d -> ... (h d)"))
         return out if not self.return_residual else (out, x)
@@ -825,73 +748,65 @@ class ParallelMHA(nn.Module):
             **factory_kwargs,
         )
 
-    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, fused_ft_kernel=True):
+    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
         dtype = self.out_proj.weight.dtype if dtype is None else dtype
         device = self.out_proj.weight.device
-        if not fused_ft_kernel:
-            return torch.empty(
-                batch_size,
-                max_seqlen,
-                2,
-                self.num_heads_kv_per_rank,
-                self.head_dim,
-                dtype=dtype,
-                device=device,
-            )
-        else:
-            assert dtype in [torch.float16, torch.bfloat16, torch.float32]
-            packsize = 4 if dtype == torch.float32 else 8
-            assert self.head_dim % packsize == 0
-            k_cache = torch.empty(
-                batch_size,
-                self.num_heads_kv_per_rank,
-                self.head_dim // packsize,
-                max_seqlen,
-                packsize,
-                dtype=dtype,
-                device=device,
-            )
-            v_cache = torch.empty(
-                batch_size,
-                self.num_heads_kv_per_rank,
-                max_seqlen,
-                self.head_dim,
-                dtype=dtype,
-                device=device,
-            )
-            return k_cache, v_cache
+        return torch.empty(
+            batch_size,
+            max_seqlen,
+            2,
+            self.num_heads_kv_per_rank,
+            self.head_dim,
+            dtype=dtype,
+            device=device,
+        )
 
     def _update_kv_cache(self, kv, inference_params):
         """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
         assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
         return _update_kv_cache(kv, inference_params, self.layer_idx)
 
-    def _apply_rotary_single_query_attention(self, qkv, inference_params, kv=None):
+    def _apply_rotary_update_kvcache_attention(self, q, kv, inference_params):
         """
-        qkv: (batch_size, 1, 3, nheads, head_dim) if kv is None else it's just
-            q of shape (batch_size, 1, nheads, head_dim)
-        kv: (batch_size, 1, 2, nheads_kv, head_dim)
+        Fast path that combine 3 steps: apply rotary to Q and K, update kv cache, and apply attention.
+        q: (batch_size, seqlen_q, nheads, head_dim)
+        kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim)
         """
-        rotary_emb_base = self.rotary_emb.base if self.rotary_emb_dim > 0 else 0
-        return _apply_rotary_single_query_attention(
-            qkv,
-            inference_params,
-            self.layer_idx,
-            self.rotary_emb_dim,
-            rotary_emb_base,
-            kv=kv,
-            rotary_emb_interleaved=self.rotary_emb.interleaved
-            if self.rotary_emb_dim > 0
-            else False,
-        )
+        assert inference_params is not None and inference_params.sequence_len_offset > 0
+        assert self.use_flash_attn
+        if self.rotary_emb_dim > 0:
+            assert self.rotary_emb.scale is None, "This code path does not support xPos"
+            self.rotary_emb._update_cos_sin_cache(
+                inference_params.max_sequence_len, device=q.device, dtype=q.dtype
+            )
+            rotary_cos, rotary_sin = self.rotary_emb._cos_cached, self.rotary_emb._sin_cached
+        else:
+            rotary_cos, rotary_sin = None, None
+        batch = q.shape[0]
+        kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
+        cache_seqlens = (
+            inference_params.lengths_per_sample[:batch]
+            if inference_params.lengths_per_sample is not None
+            else inference_params.sequence_len_offset
+        )
+        context = flash_attn_with_kvcache(
+            q,
+            kv_cache[:, :, 0],
+            kv_cache[:, :, 1],
+            kv[:, :, 0],
+            kv[:, :, 1],
+            rotary_cos=rotary_cos,
+            rotary_sin=rotary_sin,
+            cache_seqlens=cache_seqlens,
+            softmax_scale=self.inner_cross_attn.softmax_scale,
+            causal=self.inner_cross_attn.causal,
+            rotary_interleaved=self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False,
+        )
+        return context
 
     def _update_kvcache_attention(self, q, kv, inference_params):
-        """Write kv to inference_params, then do attention """
-        if inference_params.sequence_len_offset == 0 or not self.use_flash_attn:
+        """Write kv to inference_params, then do attention"""
+        if (
+            inference_params.sequence_len_offset == 0
+            or flash_attn_with_kvcache is None
+            or not self.use_flash_attn
+        ):
             # TODO: this only uses sequence_len_offset and not lengths_per_sample.
             kv = self._update_kv_cache(kv, inference_params)
             return self.inner_cross_attn(q, kv)
@@ -943,7 +858,8 @@ class ParallelMHA(nn.Module):
             if (
                 inference_params is None
                 or inference_params.sequence_len_offset == 0
-                or not inference_params.fused_ft_kernel
+                or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
+                or not self.use_flash_attn
            ):
                 if self.rotary_emb_dim > 0:
                     qkv = self.rotary_emb(
@@ -959,7 +875,9 @@ class ParallelMHA(nn.Module):
                         qkv[:, :, 0], qkv[:, :, 1:], inference_params
                     )
             else:
-                context = self._apply_rotary_single_query_attention(qkv, inference_params)
+                context = self._apply_rotary_update_kvcache_attention(
+                    qkv[:, :, 0], qkv[:, :, 1:], inference_params
+                )
         else:
             q = rearrange(
                 qkv[..., : self.num_heads_per_rank * self.head_dim],
@@ -975,7 +893,8 @@ class ParallelMHA(nn.Module):
             if (
                 inference_params is None
                 or inference_params.sequence_len_offset == 0
-                or not inference_params.fused_ft_kernel
+                or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
+                or not self.use_flash_attn
             ):
                 if self.rotary_emb_dim > 0:
                     q, kv = self.rotary_emb(
@@ -991,7 +910,7 @@ class ParallelMHA(nn.Module):
             else:
                 context = self._update_kvcache_attention(q, kv, inference_params)
         else:
-            context = self._apply_rotary_single_query_attention(q, inference_params, kv=kv)
+            context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)
         context = rearrange(context, "b s h d -> b s (h d)")
         if seqlen is not None:
             context = rearrange(context, "b s d -> (b s) d")
```
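Taken together, the mha.py change collapses the two decode paths into one: the per-layer KV cache is always a single (batch_size, max_seqlen, 2, nheads_kv, head_dim) tensor, and the single-token fast path goes through `flash_attn_with_kvcache`, which applies rotary, appends to the cache, and computes attention in one kernel. A standalone toy sketch of the surviving cache-update logic is below; the function name mirrors `_update_kv_cache` in the diff, while the shapes and offsets are illustrative assumptions.

```python
import torch

# Toy illustration of the post-commit cache update in _update_kv_cache:
# one dense cache tensor per layer, shape (batch, max_seqlen, 2, nheads_kv, headdim).
def update_kv_cache(kv_cache, kv, batch_start, sequence_start):
    batch_end = batch_start + kv.shape[0]
    sequence_end = sequence_start + kv.shape[1]
    assert batch_end <= kv_cache.shape[0] and sequence_end <= kv_cache.shape[1]
    # Write the new key/value tokens into their slot in the cache.
    kv_cache[batch_start:batch_end, sequence_start:sequence_end] = kv
    # Return the cache slice covering everything written so far for these sequences.
    return kv_cache[batch_start:batch_end, :sequence_end]

kv_cache = torch.zeros(2, 128, 2, 8, 64)  # assumed toy sizes
new_kv = torch.randn(2, 1, 2, 8, 64)      # one decoded token per sequence
kv_so_far = update_kv_cache(kv_cache, new_kv, batch_start=0, sequence_start=10)
print(kv_so_far.shape)  # torch.Size([2, 11, 2, 8, 64])
```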
flash_attn/utils/generation.py

```diff
@@ -25,7 +25,6 @@ class InferenceParams:
     sequence_len_offset: int = 0
     batch_size_offset: int = 0
     key_value_memory_dict: dict = field(default_factory=dict)
-    fused_ft_kernel: bool = False
     lengths_per_sample: Optional[Tensor] = None
@@ -96,7 +95,6 @@ def decode(
     teacher_outputs=None,
     vocab_size=None,
     tensor_parallel=1,
-    fused_ft_kernel=False,
     cg=False,
     enable_timing=False,
 ):
@@ -127,7 +125,6 @@ def decode(
             seqlen_og,
             max_length,
             tensor_parallel=tensor_parallel,
-            fused_ft_kernel=fused_ft_kernel,
         )
         inference_params = model._decoding_cache.inference_params
         inference_params.max_sequence_len = max_length
@@ -135,9 +132,7 @@ def decode(
         inference_params.sequence_len_offset = 0
         inference_params.lengths_per_sample.zero_()
     else:
-        inference_params = InferenceParams(
-            max_sequence_len=max_length, max_batch_size=batch_size, fused_ft_kernel=fused_ft_kernel
-        )
+        inference_params = InferenceParams(max_sequence_len=max_length, max_batch_size=batch_size)
 
     def get_logits(input_ids, inference_params):
         decoding = inference_params.sequence_len_offset > 0
@@ -273,7 +268,6 @@ def decode_speculative(
     eos_token_id=None,
     vocab_size=None,
     tensor_parallel=1,
-    fused_ft_kernel=False,
     cg=False,
     enable_timing=False,
     debug=False,
@@ -307,23 +301,17 @@ def decode_speculative(
             seqlen_og,
             max_length,
             tensor_parallel=tensor_parallel,
-            fused_ft_kernel=fused_ft_kernel,
         )
         inference_params_draft = model_draft._decoding_cache.inference_params
         inference_params_draft.max_sequence_len = max_length
         inference_params_draft.max_batch_size = batch_size
         inference_params_draft.sequence_len_offset = 0
-        # fused_ft_kernel doesn't support passing in multiple tokens at once
-        inference_params = InferenceParams(
-            max_sequence_len=max_length, max_batch_size=batch_size, fused_ft_kernel=False
-        )
+        inference_params = InferenceParams(max_sequence_len=max_length, max_batch_size=batch_size)
     else:
         inference_params_draft = InferenceParams(
-            max_sequence_len=max_length, max_batch_size=batch_size, fused_ft_kernel=fused_ft_kernel
+            max_sequence_len=max_length, max_batch_size=batch_size
         )
-        inference_params = InferenceParams(
-            max_sequence_len=max_length, max_batch_size=batch_size, fused_ft_kernel=False
-        )
+        inference_params = InferenceParams(max_sequence_len=max_length, max_batch_size=batch_size)
 
     def logits_forward_fn(model, input_ids, position_ids, inference_params, cg=False):
         if not cg:
@@ -606,7 +594,6 @@ def allocate_inference_cache(
     layers: Union[int, Sequence],
     device,
     dtype=torch.float16,
-    fused_ft_kernel=False,
 ):
     assert dtype in [torch.float16, torch.bfloat16, torch.float32]
     packsize = 4 if dtype == torch.float32 else 8
@@ -616,15 +603,7 @@ def allocate_inference_cache(
     kv_cache_shape = (max_batch_size, max_seqlen, 2, nheads, headdim)
     if isinstance(layers, int):
         layers = range(layers)
-    return {
-        i: (
-            torch.empty(k_cache_shape, device=device, dtype=dtype),
-            torch.empty(v_cache_shape, device=device, dtype=dtype),
-        )
-        if fused_ft_kernel
-        else torch.empty(kv_cache_sahpe, device=device, dtype=dtype)
-        for i in layers
-    }
+    return {i: torch.empty(kv_cache_shape, device=device, dtype=dtype) for i in layers}
 
 
 def seqlen_to_seqlen_type(seqlen: int) -> int:
@@ -633,12 +612,12 @@ def seqlen_to_seqlen_type(seqlen: int) -> int:
     Arguments:
         seqlen: int
     """
-    return 0 if seqlen < 32 else (1 if seqlen < 2048 else 2)
+    return 0
 
 
 def seqlen_type_to_max_seqlen(seqlen_type: int) -> int:
-    assert seqlen_type in [0, 1, 2]
-    return 32 if seqlen_type == 0 else (2048 if seqlen_type == 1 else 2**32)
+    assert seqlen_type in [0]
+    return 2**32
 
 
 @dataclass
@@ -663,7 +642,6 @@ def update_graph_cache(
     tensor_parallel=1,
     dtype=None,
     n_warmups=2,
-    fused_ft_kernel=False,
 ):
     if cache is None:
         cache = DecodingCGCache()
@@ -683,9 +661,7 @@ def update_graph_cache(
         cache.device, cache.dtype = device, dtype
         cache.max_batch_size, cache.max_seqlen = batch_size, max_seqlen
         if hasattr(model, "allocate_inference_cache"):
-            inf_cache = model.allocate_inference_cache(
-                batch_size, max_seqlen, dtype, fused_ft_kernel=fused_ft_kernel
-            )
+            inf_cache = model.allocate_inference_cache(batch_size, max_seqlen, dtype)
         else:
             headdim = getattr(
                 model.config,
@@ -700,7 +676,6 @@ def update_graph_cache(
                 model.config.num_hidden_layers,
                 device,
                 dtype,
-                fused_ft_kernel=fused_ft_kernel,
             )
         lengths_per_sample = torch.full((batch_size,), seqlen_og, dtype=torch.int32, device=device)
         cache.inference_params = InferenceParams(
@@ -708,7 +683,6 @@ def update_graph_cache(
             max_batch_size=batch_size,
             sequence_len_offset=seqlen_og,
             key_value_memory_dict=inf_cache,
-            fused_ft_kernel=fused_ft_kernel,
            lengths_per_sample=lengths_per_sample,
         )
         cache.mempool = torch.cuda.graphs.graph_pool_handle()
```
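With `fused_ft_kernel` removed from `InferenceParams`, `decode`, and the cache helpers, every layer's inference cache is a single dense tensor rather than an FT-layout (k_cache, v_cache) pair. A self-contained sketch of the simplified allocator follows; it mirrors the new `allocate_inference_cache` in the diff above, but the concrete sizes, the CPU device, and the standalone function form are assumptions for illustration.

```python
import torch

# Mirrors the simplified allocate_inference_cache in flash_attn/utils/generation.py
# after this commit: one (batch, seqlen, 2, nheads, headdim) tensor per layer.
def allocate_inference_cache(max_batch_size, max_seqlen, nheads, headdim,
                             layers, device, dtype=torch.float16):
    kv_cache_shape = (max_batch_size, max_seqlen, 2, nheads, headdim)
    if isinstance(layers, int):
        layers = range(layers)
    return {i: torch.empty(kv_cache_shape, device=device, dtype=dtype) for i in layers}

# Assumed toy sizes: 2 sequences, 256-token budget, 8 KV heads of dim 64, 4 layers.
cache = allocate_inference_cache(2, 256, 8, 64, layers=4, device="cpu")
print(cache[0].shape)  # torch.Size([2, 256, 2, 8, 64])
```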
setup.py

```diff
@@ -122,10 +122,10 @@ if not SKIP_CUDA_BUILD:
     # cc_flag.append("arch=compute_75,code=sm_75")
     cc_flag.append("-gencode")
     cc_flag.append("arch=compute_80,code=sm_80")
-    # if CUDA_HOME is not None:
-    #     if bare_metal_version >= Version("11.8"):
-    #         cc_flag.append("-gencode")
-    #         cc_flag.append("arch=compute_90,code=sm_90")
+    if CUDA_HOME is not None:
+        if bare_metal_version >= Version("11.8"):
+            cc_flag.append("-gencode")
+            cc_flag.append("arch=compute_90,code=sm_90")
 
     # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as
     # torch._C._GLIBCXX_USE_CXX11_ABI
```
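The setup.py hunk simply re-enables the previously commented-out sm_90 target when a CUDA toolkit of at least 11.8 is detected. A small sketch of the flag list that branch produces is below; the CUDA path and version are assumptions chosen only to exercise the condition.

```python
from packaging.version import Version

CUDA_HOME = "/usr/local/cuda"        # assumption for illustration
bare_metal_version = Version("12.1")  # assumption for illustration

# Same logic as the uncommented block: always target sm_80, add sm_90 for CUDA >= 11.8.
cc_flag = ["-gencode", "arch=compute_80,code=sm_80"]
if CUDA_HOME is not None:
    if bare_metal_version >= Version("11.8"):
        cc_flag.append("-gencode")
        cc_flag.append("arch=compute_90,code=sm_90")
print(cc_flag)
# ['-gencode', 'arch=compute_80,code=sm_80', '-gencode', 'arch=compute_90,code=sm_90']
```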
tests/models/test_baichuan.py

```diff
@@ -217,9 +217,8 @@ def test_baichuan_parallel_forward(model_name, world_size):
     ).abs().max().item()
 
 
-@pytest.mark.parametrize("fused_ft_kernel", [False, True])
 @pytest.mark.parametrize("model_name", ["baichuan-inc/Baichuan-7B"])
-def test_baichuan_generation(model_name, fused_ft_kernel):
+def test_baichuan_generation(model_name):
     dtype = torch.float16
     device = "cuda"
     config = baichuan_config_to_gpt2_config(
@@ -236,8 +235,8 @@ def test_baichuan_generation(model_name, fused_ft_kernel):
     torch.manual_seed(0)
     batch_size = 1
-    seqlen = 100
-    max_length = 150
+    seqlen = 2048
+    max_length = 2048 + 150
     input_ids = torch.randint(
         0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, device=device
     )
@@ -285,7 +284,6 @@ def test_baichuan_generation(model_name, fused_ft_kernel):
         input_ids=input_ids,
         max_length=max_length,
         eos_token_id=eos_token_id,
-        fused_ft_kernel=fused_ft_kernel,
         return_dict_in_generate=True,
         output_scores=True,
         enable_timing=True,
@@ -296,16 +294,13 @@ def test_baichuan_generation(model_name, fused_ft_kernel):
     # Capture graph outside the timing loop
     batch_size, seqlen_og = input_ids.shape
-    model._decoding_cache = update_graph_cache(
-        model, None, batch_size, seqlen_og, max_length, fused_ft_kernel=fused_ft_kernel
-    )
+    model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
     print("With CUDA graph")
     torch.cuda.synchronize()
     start = time.time()
     out_cg = model.generate(
         input_ids=input_ids,
         max_length=max_length,
-        fused_ft_kernel=fused_ft_kernel,
         cg=True,
         return_dict_in_generate=True,
         output_scores=True,
@@ -403,9 +398,7 @@ def test_baichuan_parallel_generation(model_name, world_size):
     # Capture graph outside the timing loop
     batch_size, seqlen_og = input_ids.shape
-    model._decoding_cache = update_graph_cache(
-        model, None, batch_size, seqlen_og, max_length, fused_ft_kernel=False
-    )
+    model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
     print("With CUDA graph")
     out_cg = model.generate(
         input_ids=input_ids,
```
tests/models/test_bigcode.py

```diff
@@ -141,7 +141,6 @@ def test_bigcode_generation(model_name):
         input_ids=input_ids,
         max_length=max_length,
         eos_token_id=eos_token_id,
-        fused_ft_kernel=True,
         return_dict_in_generate=True,
         output_scores=True,
         enable_timing=True,
@@ -159,7 +158,6 @@ def test_bigcode_generation(model_name):
     out_cg = model.generate(
         input_ids=input_ids,
         max_length=max_length,
-        fused_ft_kernel=True,
         cg=True,
         return_dict_in_generate=True,
         output_scores=True,
```
tests/models/test_falcon.py

```diff
@@ -242,7 +242,6 @@ def test_falcon_generation(model_name):
         input_ids=input_ids,
         max_length=max_length,
         eos_token_id=eos_token_id,
-        fused_ft_kernel=True,
         return_dict_in_generate=True,
         output_scores=True,
         enable_timing=True,
@@ -253,16 +252,13 @@ def test_falcon_generation(model_name):
     # Capture graph outside the timing loop
     batch_size, seqlen_og = input_ids.shape
-    model._decoding_cache = update_graph_cache(
-        model, None, batch_size, seqlen_og, max_length, fused_ft_kernel=True
-    )
+    model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
     print("With CUDA graph")
     torch.cuda.synchronize()
     start = time.time()
     out_cg = model.generate(
         input_ids=input_ids,
         max_length=max_length,
-        fused_ft_kernel=True,
         cg=True,
         return_dict_in_generate=True,
         output_scores=True,
@@ -349,7 +345,6 @@ def test_falcon_parallel_generation(model_name, world_size):
         max_length=max_length,
         tensor_parallel=world_size,
         vocab_size=config.vocab_size,
-        fused_ft_kernel=True,
         # teacher_outputs=out_hf.sequences,
         return_dict_in_generate=True,
         output_scores=True,
@@ -358,16 +353,13 @@ def test_falcon_parallel_generation(model_name, world_size):
     # Capture graph outside the timing loop
     batch_size, seqlen_og = input_ids.shape
-    model._decoding_cache = update_graph_cache(
-        model, None, batch_size, seqlen_og, max_length, fused_ft_kernel=True
-    )
+    model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
     print("With CUDA graph")
     out_cg = model.generate(
         input_ids=input_ids,
         max_length=max_length,
         tensor_parallel=world_size,
         vocab_size=config.vocab_size,
-        fused_ft_kernel=True,
         cg=True,
         # teacher_outputs=out_hf.sequences,
         return_dict_in_generate=True,
```
tests/models/test_gpt.py

```diff
@@ -134,14 +134,12 @@ def test_gpt2_optimized(model_name):
     ).abs().max().item()
 
 
-@pytest.mark.parametrize("fused_ft_kernel", [False, True])
-# @pytest.mark.parametrize('fused_ft_kernel', [True])
 @pytest.mark.parametrize("optimized", [False, True])
 # @pytest.mark.parametrize('optimized', [True])
 @pytest.mark.parametrize("rotary", [False, True])
 # @pytest.mark.parametrize('rotary', [False])
 @pytest.mark.parametrize("model_name", ["gpt2"])
-def test_gpt2_generation(model_name, rotary, optimized, fused_ft_kernel):
+def test_gpt2_generation(model_name, rotary, optimized):
     """Check that our implementation of GPT2 generation matches the HF implementation:
     the scores in fp16 should be around the same as the HF scores in fp16, when compared to
     the HF scores in fp32.
@@ -202,18 +200,16 @@ def test_gpt2_generation(model_name, rotary, optimized, fused_ft_kernel):
     out = model.generate(
         input_ids=input_ids,
         max_length=max_length,
-        fused_ft_kernel=fused_ft_kernel,
         return_dict_in_generate=True,
         output_scores=True,
         enable_timing=True,
     )
     print(out.sequences)
     print(tokenizer.batch_decode(out.sequences.tolist()))
-    if fused_ft_kernel or getattr(config, "use_flash_attn", False):
+    if getattr(config, "use_flash_attn", False):
         out_cg = model.generate(
             input_ids=input_ids,
             max_length=max_length,
-            fused_ft_kernel=fused_ft_kernel,
             cg=True,
             return_dict_in_generate=True,
             output_scores=True,
@@ -282,10 +278,8 @@ def get_logits(model, input_ids, max_length, teacher_outputs=None, **kwargs):
 # @pytest.mark.parametrize('seqlen,maxlen', [(10, 20)])
 @pytest.mark.parametrize("rotary", [None, "interleaved", "contiguous"])
 # @pytest.mark.parametrize('rotary', [None])
-@pytest.mark.parametrize("fused_ft_kernel", [False, True])
-# @pytest.mark.parametrize("fused_ft_kernel", [False])
 @pytest.mark.parametrize("model_name", ["gpt2"])
-def test_gpt2_generation_cg(model_name, fused_ft_kernel, rotary, seqlen, maxlen):
+def test_gpt2_generation_cg(model_name, rotary, seqlen, maxlen):
     """Check that decoding with CUDA graph is the same as decoding without CUDA graph."""
     dtype = torch.float16
     device = "cuda"
@@ -315,17 +309,8 @@ def test_gpt2_generation_cg(model_name, fused_ft_kernel, rotary, seqlen, maxlen)
         0, config.vocab_size, (batch_size, maxlen), dtype=torch.long, device=device
     )
-    logits = get_logits(
-        model, input_ids, maxlen, teacher_outputs=teacher_outputs, fused_ft_kernel=fused_ft_kernel
-    )
-    logits_cg = get_logits(
-        model,
-        input_ids,
-        maxlen,
-        teacher_outputs=teacher_outputs,
-        fused_ft_kernel=fused_ft_kernel,
-        cg=True,
-    )
+    logits = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs)
+    logits_cg = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs, cg=True)
     assert torch.equal(logits, logits_cg)
 
     # Try increasing batch size and seqlen, then decrease them to see if it's still correct
@@ -369,7 +354,6 @@ def test_gpt2_multiple_token_generation(model_name, optimized):
     config.fused_bias_fc = True
     config.fused_mlp = True
     config.fused_dropout_add_ln = True
-    # fused_ft_kernel currently doesn't work with multiple tokens at a time
     model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)
     model.eval()
@@ -398,13 +382,12 @@ def test_gpt2_multiple_token_generation(model_name, optimized):
     assert torch.allclose(logits, logits_ref, rtol=rtol, atol=atol)
 
 
-@pytest.mark.parametrize("fused_ft_kernel, cg", [(False, False), (True, False), (True, True)])
-# @pytest.mark.parametrize("fused_ft_kernel, cg", [(True, True)])
+@pytest.mark.parametrize("cg", [False, True])
 # @pytest.mark.parametrize("optimized", [False, True])
 @pytest.mark.parametrize("optimized", [True])
 # @pytest.mark.parametrize("model_name", ["gpt2-medium"])
 @pytest.mark.parametrize("model_name", ["gpt2-xl"])
-def test_gpt2_speculative_decoding(model_name, optimized, fused_ft_kernel, cg):
+def test_gpt2_speculative_decoding(model_name, optimized, cg):
     dtype = torch.float16
     device = "cuda"
     rtol, atol = 3e-3, 3e-1
@@ -444,7 +427,6 @@ def test_gpt2_speculative_decoding(model_name, optimized, fused_ft_kernel, cg):
         model_draft,
         max_length=max_length,
         top_k=5,
-        fused_ft_kernel=fused_ft_kernel,
         cg=cg,
         speculative_lookahead=4,
         enable_timing=True,
@@ -454,7 +436,6 @@ def test_gpt2_speculative_decoding(model_name, optimized, fused_ft_kernel, cg):
         input_ids,
         max_length=max_length,
         top_k=5,
-        fused_ft_kernel=fused_ft_kernel,
         cg=False,
         enable_timing=True,
         return_dict_in_generate=True,
```
tests/models/test_gpt_generation_parallel.py

```diff
@@ -15,12 +15,10 @@ from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel as GPT2LMHead
 # @pytest.mark.parametrize('world_size', [1, 2, 4, 8])
 @pytest.mark.parametrize("world_size", [2])
-# @pytest.mark.parametrize('fused_ft_kernel', [False, True])
-@pytest.mark.parametrize("fused_ft_kernel", [True])
-# @pytest.mark.parametrize('rotary', [False, True])
-@pytest.mark.parametrize("rotary", [False])
+@pytest.mark.parametrize('rotary', [False, True])
+# @pytest.mark.parametrize("rotary", [False])
 @pytest.mark.parametrize("model_name", ["gpt2"])
-def test_tensor_parallel(model_name, rotary, fused_ft_kernel, world_size):
+def test_tensor_parallel(model_name, rotary, world_size):
     """Check that our implementation of GPT2 generation matches the HF implementation:
     the scores in fp16 should be around the same as the HF scores in fp16, when compared to
     the HF scores in fp32.
@@ -111,19 +109,17 @@ def test_tensor_parallel(model_name, rotary, fused_ft_kernel, world_size):
         max_length=max_length,
         tensor_parallel=world_size,
         vocab_size=config.vocab_size,
-        fused_ft_kernel=fused_ft_kernel,
         return_dict_in_generate=True,
         output_scores=True,
         enable_timing=True,
     )
     print(out.sequences)
-    if fused_ft_kernel:
+    if getattr(config, "use_flash_attn", False):
         out_cg = model.generate(
             input_ids=input_ids,
             max_length=max_length,
             tensor_parallel=world_size,
             vocab_size=config.vocab_size,
-            fused_ft_kernel=fused_ft_kernel,
             cg=True,
             return_dict_in_generate=True,
             output_scores=True,
```
tests/models/test_gptj.py

```diff
@@ -83,9 +83,8 @@ def test_gptj_optimized(model_name):
     ).abs().max().item()
 
 
-@pytest.mark.parametrize("fused_ft_kernel", [False, True])
 @pytest.mark.parametrize("model_name", ["EleutherAI/gpt-j-6B"])
-def test_gptj_generation(model_name, fused_ft_kernel):
+def test_gptj_generation(model_name):
     """Check that our implementation of GPT-J (with all optimizations enabled) matches the
     HF implementation: the output of our forward pass in fp16 should be around the same as the HF
     forward pass in fp16, when compared to the HF forward pass in fp32.
@@ -141,7 +140,6 @@ def test_gptj_generation(model_name, fused_ft_kernel):
         input_ids=input_ids,
         max_length=max_length,
         eos_token_id=eos_token_id,
-        fused_ft_kernel=fused_ft_kernel,
         return_dict_in_generate=True,
         output_scores=True,
         enable_timing=True,
@@ -152,16 +150,13 @@ def test_gptj_generation(model_name, fused_ft_kernel):
     # Capture graph outside the timing loop
     batch_size, seqlen_og = input_ids.shape
-    model._decoding_cache = update_graph_cache(
-        model, None, batch_size, seqlen_og, max_length, fused_ft_kernel=fused_ft_kernel
-    )
+    model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
     print("With CUDA graph")
     torch.cuda.synchronize()
     start = time.time()
     out_cg = model.generate(
         input_ids=input_ids,
         max_length=max_length,
-        fused_ft_kernel=fused_ft_kernel,
         cg=True,
         return_dict_in_generate=True,
         output_scores=True,
```
tests/models/test_llama.py

```diff
@@ -292,7 +292,6 @@ def test_llama_generation(model_name, checkpoint_format):
         input_ids=input_ids,
         max_length=max_length,
         eos_token_id=eos_token_id,
-        fused_ft_kernel=True,
         return_dict_in_generate=True,
         output_scores=True,
         enable_timing=True,
@@ -303,16 +302,13 @@ def test_llama_generation(model_name, checkpoint_format):
     # Capture graph outside the timing loop
     batch_size, seqlen_og = input_ids.shape
-    model._decoding_cache = update_graph_cache(
-        model, None, batch_size, seqlen_og, max_length, fused_ft_kernel=True
-    )
+    model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
     print("With CUDA graph")
     torch.cuda.synchronize()
     start = time.time()
     out_cg = model.generate(
         input_ids=input_ids,
         max_length=max_length,
-        fused_ft_kernel=True,
         cg=True,
         return_dict_in_generate=True,
         output_scores=True,
@@ -401,7 +397,6 @@ def test_llama_parallel_generation(model_name, world_size, checkpoint_format):
         max_length=max_length,
         tensor_parallel=world_size,
         vocab_size=config.vocab_size,
-        fused_ft_kernel=True,
         # teacher_outputs=out_hf.sequences,
         return_dict_in_generate=True,
         output_scores=True,
@@ -410,16 +405,13 @@ def test_llama_parallel_generation(model_name, world_size, checkpoint_format):
     # Capture graph outside the timing loop
     batch_size, seqlen_og = input_ids.shape
-    model._decoding_cache = update_graph_cache(
-        model, None, batch_size, seqlen_og, max_length, fused_ft_kernel=True
-    )
+    model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
     print("With CUDA graph")
     out_cg = model.generate(
         input_ids=input_ids,
         max_length=max_length,
         tensor_parallel=world_size,
         vocab_size=config.vocab_size,
-        fused_ft_kernel=True,
         cg=True,
         # teacher_outputs=out_hf.sequences,
         return_dict_in_generate=True,
```
tests/models/test_opt.py

```diff
@@ -107,7 +107,6 @@ def test_opt_generation(model_name):
     dtype = torch.float16
     device = "cuda"
     rtol, atol = 3e-3, 3e-1
-    fused_ft_kernel = True
     config = opt_config_to_gpt2_config(OPTConfig.from_pretrained(model_name))
     # Only prenorm supports residual_in_fp32
     config.residual_in_fp32 = getattr(config, "prenorm", True)
@@ -155,7 +154,6 @@ def test_opt_generation(model_name):
         input_ids=input_ids,
         max_length=max_length,
         eos_token_id=eos_token_id,
-        fused_ft_kernel=fused_ft_kernel,
         return_dict_in_generate=True,
         output_scores=True,
         enable_timing=True,
@@ -165,19 +163,16 @@ def test_opt_generation(model_name):
     if verbose:
         print(out.sequences)
         print(tokenizer.batch_decode(out.sequences.tolist()))
-    if fused_ft_kernel:
+    if getattr(config, "use_flash_attn", False):
         # Capture graph outside the timing loop
         batch_size, seqlen_og = input_ids.shape
-        model._decoding_cache = update_graph_cache(
-            model, None, batch_size, seqlen_og, max_length, fused_ft_kernel=True
-        )
+        model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
         print("With CUDA graph")
         torch.cuda.synchronize()
         start = time.time()
         out_cg = model.generate(
             input_ids=input_ids,
             max_length=max_length,
-            fused_ft_kernel=fused_ft_kernel,
             cg=True,
             return_dict_in_generate=True,
             output_scores=True,
```