Commit dfe29f5e in gaoqiong/flash-attention
Authored Sep 18, 2023 by Tri Dao
Parent: 3250ff3d

[Gen] Don't use ft_attention, use flash_attn_with_kvcache instead

Showing 12 changed files with 154 additions and 313 deletions.
csrc/ft_attention/README.md                    +6    -0
flash_attn/modules/mha.py                      +111  -192
flash_attn/utils/generation.py                 +9    -35
setup.py                                       +4    -4
tests/models/test_baichuan.py                  +5    -12
tests/models/test_bigcode.py                   +0    -2
tests/models/test_falcon.py                    +2    -10
tests/models/test_gpt.py                       +7    -26
tests/models/test_gpt_generation_parallel.py   +4    -8
tests/models/test_gptj.py                      +2    -7
tests/models/test_llama.py                     +2    -10
tests/models/test_opt.py                       +2    -7
csrc/ft_attention/README.md  (view file @ dfe29f5e)

````diff
@@ -6,3 +6,9 @@ FasterTransformer v5.2.1 for benchmarking purpose.
 ```sh
 cd csrc/ft_attention && pip install .
 ```
+
+As of 2023-09-17, this extension is no longer used in the FlashAttention repo.
+FlashAttention now has implemented
+[`flash_attn_with_kvcache`](https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/flash_attention_interface.py)
+with all the features of this `ft_attention` kernel (and more).
````
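For orientation, here is a minimal sketch of a single decoding step with `flash_attn_with_kvcache` (the shapes, dummy tensors, and the top-level import path are illustrative assumptions, not part of this commit; the linked interface file is the authoritative reference):

```python
# Hedged sketch: one decode step with an in-place KV-cache update.
import torch
from flash_attn import flash_attn_with_kvcache  # assumed import path

batch, nheads, headdim, max_seqlen, cur_len = 2, 16, 64, 256, 10
dev, dt = "cuda", torch.float16

q = torch.randn(batch, 1, nheads, headdim, device=dev, dtype=dt)      # new query token
k_new = torch.randn(batch, 1, nheads, headdim, device=dev, dtype=dt)  # new key token
v_new = torch.randn(batch, 1, nheads, headdim, device=dev, dtype=dt)  # new value token
k_cache = torch.zeros(batch, max_seqlen, nheads, headdim, device=dev, dtype=dt)
v_cache = torch.zeros(batch, max_seqlen, nheads, headdim, device=dev, dtype=dt)
cache_seqlens = torch.full((batch,), cur_len, dtype=torch.int32, device=dev)

# k_new/v_new are written into the caches at position cache_seqlens, then q attends
# over the cached keys/values up to the updated lengths.
out = flash_attn_with_kvcache(
    q, k_cache, v_cache, k_new, v_new, cache_seqlens=cache_seqlens, causal=True
)
print(out.shape)  # (batch, 1, nheads, headdim)
```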
flash_attn/modules/mha.py  (view file @ dfe29f5e)
```diff
@@ -32,11 +32,6 @@ try:
 except ImportError:
     RotaryEmbedding = None
 
-try:
-    import ft_attention
-except ImportError:
-    ft_attention = None
-
 
 class FlashSelfAttention(nn.Module):
     """Implement the scaled dot product attention with softmax.
```
```diff
@@ -314,14 +309,7 @@ def _update_kv_cache(kv, inference_params, layer_idx):
         )
         inference_params.key_value_memory_dict[layer_idx] = kv_cache
     else:
-        if not inference_params.fused_ft_kernel:
-            kv_cache = inference_params.key_value_memory_dict[layer_idx]
-        else:
-            # For FT, k_cache has shape (b, h, headdim / packsize, s, packsize)
-            # where packsize = 4 if fp32, 8 if fp16 or bf16.
-            # v_cache has shape (b, h, s, headdim)
-            k_cache, v_cache = inference_params.key_value_memory_dict[layer_idx]
-            kv_cache = None
+        kv_cache = inference_params.key_value_memory_dict[layer_idx]
     # Adjust key and value for inference
     batch_start = inference_params.batch_size_offset
     batch_end = batch_start + kv.shape[0]
```
```diff
@@ -329,79 +317,9 @@ def _update_kv_cache(kv, inference_params, layer_idx):
     sequence_end = sequence_start + kv.shape[1]
     assert batch_end <= (kv_cache.shape[0] if kv_cache is not None else v_cache.shape[0])
     assert sequence_end <= (kv_cache.shape[1] if kv_cache is not None else v_cache.shape[2])
-    # Copy key and values.
-    if not inference_params.fused_ft_kernel:
-        assert kv_cache is not None
-        kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
-        kv = kv_cache[batch_start:batch_end, :sequence_end, ...]
-        return kv
-    else:
-        assert inference_params.sequence_len_offset == 0
-        # FT kernel requires different layouts for the k_cache and v_cache.
-        assert kv.dtype in [torch.float16, torch.bfloat16, torch.float32]
-        packsize = 4 if kv.dtype == torch.float32 else 8
-        if kv_cache is not None:
-            kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
-            k_cache = rearrange(
-                kv_cache[:, :, 0], "b s h (d packsize) -> b h d s packsize", packsize=packsize
-            ).contiguous()
-            v_cache = rearrange(kv_cache[:, :, 1], "b s h d -> b h s d").contiguous()
-            inference_params.key_value_memory_dict[layer_idx] = (k_cache, v_cache)
-        else:
-            k_cache[batch_start:batch_end, :, :, :sequence_end, :] = rearrange(
-                kv[:, :, 0], "b s h (d packsize) -> b h d s packsize", packsize=packsize
-            )
-            v_cache[batch_start:batch_end, :, :sequence_end, :] = rearrange(
-                kv[:, :, 1], "b s h d -> b h s d"
-            )
-        return kv
-
-
-def _apply_rotary_single_query_attention(
-    qkv,
-    inference_params,
-    layer_idx,
-    rotary_emb_dim,
-    rotary_emb_base,
-    kv=None,
-    rotary_emb_interleaved=False,
-):
-    """
-    qkv: (batch_size, 1, 3, nheads, head_dim) if kv is None else it's just
-        q of shape (batch_size, 1, nheads, head_dim)
-    kv: (batch_size, 1, 2, nheads_kv, head_dim)
-    """
-    assert inference_params.fused_ft_kernel
-    assert ft_attention is not None
-    if kv is None:
-        q, k, v = rearrange(qkv, "b 1 three h d -> b three h d").unbind(dim=1)
-    else:
-        q = rearrange(qkv, "b 1 h d -> b h d")
-        k, v = rearrange(kv, "b 1 two h d -> b two h d").unbind(dim=1)
-    batch_start = inference_params.batch_size_offset
-    batch_end = batch_start + q.shape[0]
-    k_cache, v_cache = inference_params.key_value_memory_dict[layer_idx]
-    lengths_per_sample = (
-        inference_params.lengths_per_sample[batch_start:batch_end]
-        if inference_params.lengths_per_sample is not None
-        else None
-    )
-    context = ft_attention.single_query_attention(
-        q,
-        k,
-        v,
-        k_cache[batch_start:batch_end],
-        v_cache[batch_start:batch_end],
-        lengths_per_sample,
-        None,  # rotary_cos_
-        None,  # rotary_sin_
-        None,  # nnz_head_idx
-        inference_params.sequence_len_offset,
-        rotary_emb_dim,
-        rotary_emb_base,
-        not rotary_emb_interleaved,  # neox_rotary_style
-    )
-    return rearrange(context, "b h d -> b 1 h d")
+    assert kv_cache is not None
+    kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
+    return kv_cache[batch_start:batch_end, :sequence_end, ...]
 
 
 class MHA(nn.Module):
```
```diff
@@ -502,36 +420,18 @@ class MHA(nn.Module):
         )
         self.out_proj = linear_cls(embed_dim, embed_dim, bias=out_proj_bias, **factory_kwargs)
 
-    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, fused_ft_kernel=True):
+    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
         dtype = self.out_proj.weight.dtype if dtype is None else dtype
         device = self.out_proj.weight.device
-        if not fused_ft_kernel:
-            return torch.empty(
-                batch_size, max_seqlen, 2, self.num_heads_kv, self.head_dim,
-                dtype=dtype, device=device,
-            )
-        else:
-            assert dtype in [torch.float16, torch.bfloat16, torch.float32]
-            packsize = 4 if dtype == torch.float32 else 8
-            assert self.head_dim % packsize == 0
-            k_cache = torch.empty(
-                batch_size, self.num_heads_kv, self.head_dim // packsize, max_seqlen, packsize,
-                dtype=dtype, device=device,
-            )
-            v_cache = torch.empty(
-                batch_size, self.num_heads_kv, max_seqlen, self.head_dim, dtype=dtype, device=device
-            )
-            return k_cache, v_cache
+        return torch.empty(
+            batch_size, max_seqlen, 2, self.num_heads_kv, self.head_dim,
+            dtype=dtype, device=device,
+        )
 
     def _update_kv_cache(self, kv, inference_params):
         """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
```
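The removed branch shows why the FT path needed its own cache: a packed key layout plus a transposed value layout, restricted to fp16/bf16/fp32. After this change every path shares the single `(batch_size, max_seqlen, 2, nheads_kv, head_dim)` tensor that `flash_attn_with_kvcache` reads directly. A rough side-by-side sketch of the two layouts (tensor names and sizes here are made up for illustration):

```python
# Illustrative only: old FT-packed cache layout vs. the unified layout kept after this commit.
import torch

batch, max_seqlen, nheads_kv, head_dim = 2, 256, 16, 64
dtype = torch.float16
packsize = 4 if dtype == torch.float32 else 8  # same rule as the removed code

# Old ft_attention layout: separate packed K cache plus a transposed V cache.
k_cache_ft = torch.empty(batch, nheads_kv, head_dim // packsize, max_seqlen, packsize, dtype=dtype)
v_cache_ft = torch.empty(batch, nheads_kv, max_seqlen, head_dim, dtype=dtype)

# New unified layout: one tensor holding K and V side by side, with no dtype restriction.
kv_cache = torch.empty(batch, max_seqlen, 2, nheads_kv, head_dim, dtype=dtype)
k_cache, v_cache = kv_cache[:, :, 0], kv_cache[:, :, 1]  # views passed to flash_attn_with_kvcache
```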
```diff
@@ -539,27 +439,46 @@ class MHA(nn.Module):
         assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
         return _update_kv_cache(kv, inference_params, self.layer_idx)
 
-    def _apply_rotary_single_query_attention(self, qkv, inference_params, kv=None):
+    def _apply_rotary_update_kvcache_attention(self, q, kv, inference_params):
         """
-        qkv: (batch_size, 1, 3, nheads, head_dim) if kv is None else it's just
-            q of shape (batch_size, 1, nheads, head_dim)
-        kv: (batch_size, 1, 2, nheads_kv, head_dim)
+        Fast path that combine 3 steps: apply rotary to Q and K, update kv cache, and apply attention.
+        q: (batch_size, seqlen_q, nheads, head_dim)
+        kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim)
         """
-        rotary_emb_base = self.rotary_emb.base if self.rotary_emb_dim > 0 else 0
-        return _apply_rotary_single_query_attention(
-            qkv,
-            inference_params,
-            self.layer_idx,
-            self.rotary_emb_dim,
-            rotary_emb_base,
-            kv=kv,
-            rotary_emb_interleaved=self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False,
+        assert inference_params is not None and inference_params.sequence_len_offset > 0
+        assert self.use_flash_attn
+        if self.rotary_emb_dim > 0:
+            assert self.rotary_emb.scale is None, "This code path does not support xPos"
+            self.rotary_emb._update_cos_sin_cache(
+                inference_params.max_sequence_len, device=q.device, dtype=q.dtype
+            )
+            rotary_cos, rotary_sin = self.rotary_emb._cos_cached, self.rotary_emb._sin_cached
+        else:
+            rotary_cos, rotary_sin = None, None
+        batch = q.shape[0]
+        kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
+        cache_seqlens = (
+            inference_params.lengths_per_sample[:batch]
+            if inference_params.lengths_per_sample is not None
+            else inference_params.sequence_len_offset
+        )
+        context = flash_attn_with_kvcache(
+            q,
+            kv_cache[:, :, 0],
+            kv_cache[:, :, 1],
+            kv[:, :, 0],
+            kv[:, :, 1],
+            rotary_cos=rotary_cos,
+            rotary_sin=rotary_sin,
+            cache_seqlens=cache_seqlens,
+            softmax_scale=self.inner_cross_attn.softmax_scale,
+            causal=self.inner_cross_attn.causal,
+            rotary_interleaved=self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False,
         )
+        return context
 
     def _update_kvcache_attention(self, q, kv, inference_params):
-        """Write kv to inference_params, then do attention
-        """
+        """Write kv to inference_params, then do attention"""
         if (
             inference_params.sequence_len_offset == 0
             or flash_attn_with_kvcache is None
```
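The docstring above calls this a fast path that fuses three steps. For comparison, a hedged sketch of the equivalent unfused sequence, written with the helpers that already exist in this file (the wrapper function itself is hypothetical; it mirrors what the slow path in `forward()` does):

```python
# Illustrative sketch of the three steps that flash_attn_with_kvcache fuses above.
def _unfused_decode_step(self, q, kv, inference_params):
    # 1) Apply rotary embedding to the new Q and K at the current sequence offset.
    if self.rotary_emb_dim > 0:
        q, kv = self.rotary_emb(q, kv, seqlen_offset=inference_params.sequence_len_offset)
    # 2) Write the new K/V into the per-layer cache and get back all cached K/V so far.
    kv = self._update_kv_cache(kv, inference_params)
    # 3) Attend the query against the cached keys/values.
    return self.inner_cross_attn(q, kv)
```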
```diff
@@ -663,7 +582,8 @@ class MHA(nn.Module):
             if (
                 inference_params is None
                 or inference_params.sequence_len_offset == 0
-                or not inference_params.fused_ft_kernel
+                or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
+                or not self.use_flash_attn
             ):
                 if self.rotary_emb_dim > 0:
                     qkv = self.rotary_emb(
@@ -679,7 +599,9 @@ class MHA(nn.Module):
                         qkv[:, :, 0], qkv[:, :, 1:], inference_params
                     )
             else:
-                context = self._apply_rotary_single_query_attention(qkv, inference_params)
+                context = self._apply_rotary_update_kvcache_attention(
+                    qkv[:, :, 0], qkv[:, :, 1:], inference_params
+                )
         else:
             if self.cross_attn:
                 if not self.return_residual:
@@ -711,7 +633,8 @@ class MHA(nn.Module):
             if (
                 inference_params is None
                 or inference_params.sequence_len_offset == 0
-                or not inference_params.fused_ft_kernel
+                or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
+                or not self.use_flash_attn
             ):
                 if self.rotary_emb_dim > 0:
                     q, kv = self.rotary_emb(
@@ -727,7 +650,7 @@ class MHA(nn.Module):
                 else:
                     context = self._update_kvcache_attention(q, kv, inference_params)
             else:
-                context = self._apply_rotary_single_query_attention(q, inference_params, kv=kv)
+                context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)
         out = self.out_proj(rearrange(context, "... h d -> ... (h d)"))
         return out if not self.return_residual else (out, x)
```
```diff
@@ -825,73 +748,65 @@ class ParallelMHA(nn.Module):
             **factory_kwargs,
         )
 
-    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, fused_ft_kernel=True):
+    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
         dtype = self.out_proj.weight.dtype if dtype is None else dtype
         device = self.out_proj.weight.device
-        if not fused_ft_kernel:
-            return torch.empty(
-                batch_size, max_seqlen, 2, self.num_heads_kv_per_rank, self.head_dim,
-                dtype=dtype, device=device,
-            )
-        else:
-            assert dtype in [torch.float16, torch.bfloat16, torch.float32]
-            packsize = 4 if dtype == torch.float32 else 8
-            assert self.head_dim % packsize == 0
-            k_cache = torch.empty(
-                batch_size, self.num_heads_kv_per_rank, self.head_dim // packsize, max_seqlen, packsize,
-                dtype=dtype, device=device,
-            )
-            v_cache = torch.empty(
-                batch_size, self.num_heads_kv_per_rank, max_seqlen, self.head_dim,
-                dtype=dtype, device=device,
-            )
-            return k_cache, v_cache
+        return torch.empty(
+            batch_size, max_seqlen, 2, self.num_heads_kv_per_rank, self.head_dim,
+            dtype=dtype, device=device,
+        )
 
     def _update_kv_cache(self, kv, inference_params):
         """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
         assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
         return _update_kv_cache(kv, inference_params, self.layer_idx)
 
-    def _apply_rotary_single_query_attention(self, qkv, inference_params, kv=None):
+    def _apply_rotary_update_kvcache_attention(self, q, kv, inference_params):
         """
-        qkv: (batch_size, 1, 3, nheads, head_dim) if kv is None else it's just
-            q of shape (batch_size, 1, nheads, head_dim)
-        kv: (batch_size, 1, 2, nheads_kv, head_dim)
+        Fast path that combine 3 steps: apply rotary to Q and K, update kv cache, and apply attention.
+        q: (batch_size, seqlen_q, nheads, head_dim)
+        kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim)
         """
-        rotary_emb_base = self.rotary_emb.base if self.rotary_emb_dim > 0 else 0
-        return _apply_rotary_single_query_attention(
-            qkv,
-            inference_params,
-            self.layer_idx,
-            self.rotary_emb_dim,
-            rotary_emb_base,
-            kv=kv,
-            rotary_emb_interleaved=self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False,
+        assert inference_params is not None and inference_params.sequence_len_offset > 0
+        assert self.use_flash_attn
+        if self.rotary_emb_dim > 0:
+            assert self.rotary_emb.scale is None, "This code path does not support xPos"
+            self.rotary_emb._update_cos_sin_cache(
+                inference_params.max_sequence_len, device=q.device, dtype=q.dtype
+            )
+            rotary_cos, rotary_sin = self.rotary_emb._cos_cached, self.rotary_emb._sin_cached
+        else:
+            rotary_cos, rotary_sin = None, None
+        batch = q.shape[0]
+        kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
+        cache_seqlens = (
+            inference_params.lengths_per_sample[:batch]
+            if inference_params.lengths_per_sample is not None
+            else inference_params.sequence_len_offset
+        )
+        context = flash_attn_with_kvcache(
+            q,
+            kv_cache[:, :, 0],
+            kv_cache[:, :, 1],
+            kv[:, :, 0],
+            kv[:, :, 1],
+            rotary_cos=rotary_cos,
+            rotary_sin=rotary_sin,
+            cache_seqlens=cache_seqlens,
+            softmax_scale=self.inner_cross_attn.softmax_scale,
+            causal=self.inner_cross_attn.causal,
+            rotary_interleaved=self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False,
         )
+        return context
 
     def _update_kvcache_attention(self, q, kv, inference_params):
-        """Write kv to inference_params, then do attention """
-        if (
-            inference_params.sequence_len_offset == 0
-            or flash_attn_with_kvcache is None
-            or not self.use_flash_attn
-        ):
+        """Write kv to inference_params, then do attention"""
+        if inference_params.sequence_len_offset == 0 or not self.use_flash_attn:
             # TODO: this only uses sequence_len_offset and not lengths_per_sample.
             kv = self._update_kv_cache(kv, inference_params)
             return self.inner_cross_attn(q, kv)
```
```diff
@@ -943,7 +858,8 @@ class ParallelMHA(nn.Module):
             if (
                 inference_params is None
                 or inference_params.sequence_len_offset == 0
-                or not inference_params.fused_ft_kernel
+                or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
+                or not self.use_flash_attn
             ):
                 if self.rotary_emb_dim > 0:
                     qkv = self.rotary_emb(
@@ -959,7 +875,9 @@ class ParallelMHA(nn.Module):
                         qkv[:, :, 0], qkv[:, :, 1:], inference_params
                     )
             else:
-                context = self._apply_rotary_single_query_attention(qkv, inference_params)
+                context = self._apply_rotary_update_kvcache_attention(
+                    qkv[:, :, 0], qkv[:, :, 1:], inference_params
+                )
         else:
             q = rearrange(
                 qkv[..., : self.num_heads_per_rank * self.head_dim],
@@ -975,7 +893,8 @@ class ParallelMHA(nn.Module):
             if (
                 inference_params is None
                 or inference_params.sequence_len_offset == 0
-                or not inference_params.fused_ft_kernel
+                or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
+                or not self.use_flash_attn
            ):
                 if self.rotary_emb_dim > 0:
                     q, kv = self.rotary_emb(
@@ -991,7 +910,7 @@ class ParallelMHA(nn.Module):
                 else:
                     context = self._update_kvcache_attention(q, kv, inference_params)
             else:
-                context = self._apply_rotary_single_query_attention(q, inference_params, kv=kv)
+                context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)
         context = rearrange(context, "b s h d -> b s (h d)")
         if seqlen is not None:
             context = rearrange(context, "b s d -> (b s) d")
```
flash_attn/utils/generation.py  (view file @ dfe29f5e)
```diff
@@ -25,7 +25,6 @@ class InferenceParams:
     sequence_len_offset: int = 0
     batch_size_offset: int = 0
     key_value_memory_dict: dict = field(default_factory=dict)
-    fused_ft_kernel: bool = False
     lengths_per_sample: Optional[Tensor] = None
@@ -96,7 +95,6 @@ def decode(
     teacher_outputs=None,
     vocab_size=None,
     tensor_parallel=1,
-    fused_ft_kernel=False,
     cg=False,
     enable_timing=False,
 ):
@@ -127,7 +125,6 @@ def decode(
             seqlen_og,
             max_length,
             tensor_parallel=tensor_parallel,
-            fused_ft_kernel=fused_ft_kernel,
         )
         inference_params = model._decoding_cache.inference_params
         inference_params.max_sequence_len = max_length
@@ -135,9 +132,7 @@ def decode(
         inference_params.sequence_len_offset = 0
         inference_params.lengths_per_sample.zero_()
     else:
-        inference_params = InferenceParams(
-            max_sequence_len=max_length, max_batch_size=batch_size, fused_ft_kernel=fused_ft_kernel
-        )
+        inference_params = InferenceParams(max_sequence_len=max_length, max_batch_size=batch_size)
 
     def get_logits(input_ids, inference_params):
         decoding = inference_params.sequence_len_offset > 0
```
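The practical effect on callers: `generate`/`decode` no longer take a `fused_ft_kernel` argument, and CUDA-graph decoding is requested with `cg=True` alone. A condensed usage sketch assembled from the updated tests further below (model name, prompt, and the exact config flags are illustrative assumptions, not part of this diff):

```python
# Hedged end-to-end sketch of generation after this commit.
import torch
from transformers import GPT2Config, GPT2Tokenizer
from flash_attn.models.gpt import GPTLMHeadModel

config = GPT2Config.from_pretrained("gpt2")
config.use_flash_attn = True  # required for the kv-cache / CUDA-graph decode path now
# Optional fused kernels (need the extra CUDA extensions): fused_bias_fc, fused_mlp, ...
model = GPTLMHeadModel.from_pretrained("gpt2", config, device="cuda", dtype=torch.float16)
model.eval()

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
input_ids = tokenizer("Hello, my dog is cute and", return_tensors="pt").input_ids.to("cuda")

# No fused_ft_kernel argument anymore; cg=True is enough for CUDA-graph decoding.
out = model.generate(
    input_ids=input_ids, max_length=40, cg=True,
    return_dict_in_generate=True, output_scores=True,
)
print(tokenizer.batch_decode(out.sequences.tolist()))
```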
```diff
@@ -273,7 +268,6 @@ def decode_speculative(
     eos_token_id=None,
     vocab_size=None,
     tensor_parallel=1,
-    fused_ft_kernel=False,
     cg=False,
     enable_timing=False,
     debug=False,
@@ -307,23 +301,17 @@ def decode_speculative(
             seqlen_og,
             max_length,
             tensor_parallel=tensor_parallel,
-            fused_ft_kernel=fused_ft_kernel,
         )
         inference_params_draft = model_draft._decoding_cache.inference_params
         inference_params_draft.max_sequence_len = max_length
         inference_params_draft.max_batch_size = batch_size
         inference_params_draft.sequence_len_offset = 0
-        # fused_ft_kernel doesn't support passing in multiple tokens at once
-        inference_params = InferenceParams(
-            max_sequence_len=max_length, max_batch_size=batch_size, fused_ft_kernel=False
-        )
+        inference_params = InferenceParams(max_sequence_len=max_length, max_batch_size=batch_size)
     else:
-        inference_params_draft = InferenceParams(
-            max_sequence_len=max_length, max_batch_size=batch_size, fused_ft_kernel=fused_ft_kernel
-        )
-        inference_params = InferenceParams(
-            max_sequence_len=max_length, max_batch_size=batch_size, fused_ft_kernel=False
-        )
+        inference_params_draft = InferenceParams(
+            max_sequence_len=max_length, max_batch_size=batch_size
+        )
+        inference_params = InferenceParams(max_sequence_len=max_length, max_batch_size=batch_size)
 
     def logits_forward_fn(model, input_ids, position_ids, inference_params, cg=False):
         if not cg:
```
```diff
@@ -606,7 +594,6 @@ def allocate_inference_cache(
     layers: Union[int, Sequence],
     device,
     dtype=torch.float16,
-    fused_ft_kernel=False,
 ):
     assert dtype in [torch.float16, torch.bfloat16, torch.float32]
     packsize = 4 if dtype == torch.float32 else 8
@@ -616,15 +603,7 @@ def allocate_inference_cache(
     kv_cache_shape = (max_batch_size, max_seqlen, 2, nheads, headdim)
     if isinstance(layers, int):
         layers = range(layers)
-    return {
-        i: (
-            torch.empty(k_cache_shape, device=device, dtype=dtype),
-            torch.empty(v_cache_shape, device=device, dtype=dtype),
-        )
-        if fused_ft_kernel
-        else torch.empty(kv_cache_sahpe, device=device, dtype=dtype)
-        for i in layers
-    }
+    return {i: torch.empty(kv_cache_shape, device=device, dtype=dtype) for i in layers}
 
 
 def seqlen_to_seqlen_type(seqlen: int) -> int:
@@ -633,12 +612,12 @@ def seqlen_to_seqlen_type(seqlen: int) -> int:
     Arguments:
         seqlen: int
     """
-    return 0 if seqlen < 32 else (1 if seqlen < 2048 else 2)
+    return 0
 
 
 def seqlen_type_to_max_seqlen(seqlen_type: int) -> int:
-    assert seqlen_type in [0, 1, 2]
-    return 32 if seqlen_type == 0 else (2048 if seqlen_type == 1 else 2**32)
+    assert seqlen_type in [0]
+    return 2**32
 
 
 @dataclass
```
```diff
@@ -663,7 +642,6 @@ def update_graph_cache(
     tensor_parallel=1,
     dtype=None,
     n_warmups=2,
-    fused_ft_kernel=False,
 ):
     if cache is None:
         cache = DecodingCGCache()
@@ -683,9 +661,7 @@ def update_graph_cache(
         cache.device, cache.dtype = device, dtype
         cache.max_batch_size, cache.max_seqlen = batch_size, max_seqlen
         if hasattr(model, "allocate_inference_cache"):
-            inf_cache = model.allocate_inference_cache(
-                batch_size, max_seqlen, dtype, fused_ft_kernel=fused_ft_kernel
-            )
+            inf_cache = model.allocate_inference_cache(batch_size, max_seqlen, dtype)
         else:
             headdim = getattr(
                 model.config,
@@ -700,7 +676,6 @@ def update_graph_cache(
                 model.config.num_hidden_layers,
                 device,
                 dtype,
-                fused_ft_kernel=fused_ft_kernel,
             )
         lengths_per_sample = torch.full((batch_size,), seqlen_og, dtype=torch.int32, device=device)
         cache.inference_params = InferenceParams(
@@ -708,7 +683,6 @@ def update_graph_cache(
             max_batch_size=batch_size,
             sequence_len_offset=seqlen_og,
             key_value_memory_dict=inf_cache,
-            fused_ft_kernel=fused_ft_kernel,
             lengths_per_sample=lengths_per_sample,
         )
         cache.mempool = torch.cuda.graphs.graph_pool_handle()
```
setup.py  (view file @ dfe29f5e)

```diff
@@ -122,10 +122,10 @@ if not SKIP_CUDA_BUILD:
     # cc_flag.append("arch=compute_75,code=sm_75")
     cc_flag.append("-gencode")
     cc_flag.append("arch=compute_80,code=sm_80")
-    if CUDA_HOME is not None:
-        if bare_metal_version >= Version("11.8"):
-            cc_flag.append("-gencode")
-            cc_flag.append("arch=compute_90,code=sm_90")
+    # if CUDA_HOME is not None:
+    #     if bare_metal_version >= Version("11.8"):
+    #         cc_flag.append("-gencode")
+    #         cc_flag.append("arch=compute_90,code=sm_90")
 
     # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as
     # torch._C._GLIBCXX_USE_CXX11_ABI
```
tests/models/test_baichuan.py  (view file @ dfe29f5e)

```diff
@@ -217,9 +217,8 @@ def test_baichuan_parallel_forward(model_name, world_size):
     ).abs().max().item()
 
 
-@pytest.mark.parametrize("fused_ft_kernel", [False, True])
 @pytest.mark.parametrize("model_name", ["baichuan-inc/Baichuan-7B"])
-def test_baichuan_generation(model_name, fused_ft_kernel):
+def test_baichuan_generation(model_name):
     dtype = torch.float16
     device = "cuda"
     config = baichuan_config_to_gpt2_config(
@@ -236,8 +235,8 @@ def test_baichuan_generation(model_name):
     torch.manual_seed(0)
     batch_size = 1
-    seqlen = 100
-    max_length = 150
+    seqlen = 2048
+    max_length = 2048 + 150
     input_ids = torch.randint(
         0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, device=device
     )
@@ -285,7 +284,6 @@ def test_baichuan_generation(model_name):
         input_ids=input_ids,
         max_length=max_length,
         eos_token_id=eos_token_id,
-        fused_ft_kernel=fused_ft_kernel,
         return_dict_in_generate=True,
         output_scores=True,
         enable_timing=True,
@@ -296,16 +294,13 @@ def test_baichuan_generation(model_name):
     # Capture graph outside the timing loop
     batch_size, seqlen_og = input_ids.shape
-    model._decoding_cache = update_graph_cache(
-        model, None, batch_size, seqlen_og, max_length, fused_ft_kernel=fused_ft_kernel
-    )
+    model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
     print("With CUDA graph")
     torch.cuda.synchronize()
     start = time.time()
     out_cg = model.generate(
         input_ids=input_ids,
         max_length=max_length,
-        fused_ft_kernel=fused_ft_kernel,
         cg=True,
         return_dict_in_generate=True,
         output_scores=True,
@@ -403,9 +398,7 @@ def test_baichuan_parallel_generation(model_name, world_size):
     # Capture graph outside the timing loop
     batch_size, seqlen_og = input_ids.shape
-    model._decoding_cache = update_graph_cache(
-        model, None, batch_size, seqlen_og, max_length, fused_ft_kernel=False
-    )
+    model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
     print("With CUDA graph")
     out_cg = model.generate(
         input_ids=input_ids,
```
tests/models/test_bigcode.py  (view file @ dfe29f5e)

```diff
@@ -141,7 +141,6 @@ def test_bigcode_generation(model_name):
         input_ids=input_ids,
         max_length=max_length,
         eos_token_id=eos_token_id,
-        fused_ft_kernel=True,
         return_dict_in_generate=True,
         output_scores=True,
         enable_timing=True,
@@ -159,7 +158,6 @@ def test_bigcode_generation(model_name):
     out_cg = model.generate(
         input_ids=input_ids,
         max_length=max_length,
-        fused_ft_kernel=True,
         cg=True,
         return_dict_in_generate=True,
         output_scores=True,
```
tests/models/test_falcon.py  (view file @ dfe29f5e)

```diff
@@ -242,7 +242,6 @@ def test_falcon_generation(model_name):
         input_ids=input_ids,
         max_length=max_length,
         eos_token_id=eos_token_id,
-        fused_ft_kernel=True,
         return_dict_in_generate=True,
         output_scores=True,
         enable_timing=True,
@@ -253,16 +252,13 @@ def test_falcon_generation(model_name):
     # Capture graph outside the timing loop
     batch_size, seqlen_og = input_ids.shape
-    model._decoding_cache = update_graph_cache(
-        model, None, batch_size, seqlen_og, max_length, fused_ft_kernel=True
-    )
+    model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
     print("With CUDA graph")
     torch.cuda.synchronize()
     start = time.time()
     out_cg = model.generate(
         input_ids=input_ids,
         max_length=max_length,
-        fused_ft_kernel=True,
         cg=True,
         return_dict_in_generate=True,
         output_scores=True,
@@ -349,7 +345,6 @@ def test_falcon_parallel_generation(model_name, world_size):
         max_length=max_length,
         tensor_parallel=world_size,
         vocab_size=config.vocab_size,
-        fused_ft_kernel=True,
         # teacher_outputs=out_hf.sequences,
         return_dict_in_generate=True,
         output_scores=True,
@@ -358,16 +353,13 @@ def test_falcon_parallel_generation(model_name, world_size):
     # Capture graph outside the timing loop
     batch_size, seqlen_og = input_ids.shape
-    model._decoding_cache = update_graph_cache(
-        model, None, batch_size, seqlen_og, max_length, fused_ft_kernel=True
-    )
+    model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
     print("With CUDA graph")
     out_cg = model.generate(
         input_ids=input_ids,
         max_length=max_length,
         tensor_parallel=world_size,
         vocab_size=config.vocab_size,
-        fused_ft_kernel=True,
         cg=True,
         # teacher_outputs=out_hf.sequences,
         return_dict_in_generate=True,
```
tests/models/test_gpt.py  (view file @ dfe29f5e)

```diff
@@ -134,14 +134,12 @@ def test_gpt2_optimized(model_name):
     ).abs().max().item()
 
 
-@pytest.mark.parametrize("fused_ft_kernel", [False, True])
-# @pytest.mark.parametrize('fused_ft_kernel', [True])
 @pytest.mark.parametrize("optimized", [False, True])
 # @pytest.mark.parametrize('optimized', [True])
 @pytest.mark.parametrize("rotary", [False, True])
 # @pytest.mark.parametrize('rotary', [False])
 @pytest.mark.parametrize("model_name", ["gpt2"])
-def test_gpt2_generation(model_name, rotary, optimized, fused_ft_kernel):
+def test_gpt2_generation(model_name, rotary, optimized):
     """Check that our implementation of GPT2 generation matches the HF implementation:
     the scores in fp16 should be around the same as the HF scores in fp16, when compared to
     the HF scores in fp32.
@@ -202,18 +200,16 @@ def test_gpt2_generation(model_name, rotary, optimized):
     out = model.generate(
         input_ids=input_ids,
         max_length=max_length,
-        fused_ft_kernel=fused_ft_kernel,
         return_dict_in_generate=True,
         output_scores=True,
         enable_timing=True,
     )
     print(out.sequences)
     print(tokenizer.batch_decode(out.sequences.tolist()))
-    if fused_ft_kernel or getattr(config, "use_flash_attn", False):
+    if getattr(config, "use_flash_attn", False):
         out_cg = model.generate(
             input_ids=input_ids,
             max_length=max_length,
-            fused_ft_kernel=fused_ft_kernel,
             cg=True,
             return_dict_in_generate=True,
             output_scores=True,
@@ -282,10 +278,8 @@ def get_logits(model, input_ids, max_length, teacher_outputs=None, **kwargs):
 # @pytest.mark.parametrize('seqlen,maxlen', [(10, 20)])
 @pytest.mark.parametrize("rotary", [None, "interleaved", "contiguous"])
 # @pytest.mark.parametrize('rotary', [None])
-@pytest.mark.parametrize("fused_ft_kernel", [False, True])
-# @pytest.mark.parametrize("fused_ft_kernel", [False])
 @pytest.mark.parametrize("model_name", ["gpt2"])
-def test_gpt2_generation_cg(model_name, fused_ft_kernel, rotary, seqlen, maxlen):
+def test_gpt2_generation_cg(model_name, rotary, seqlen, maxlen):
     """Check that decoding with CUDA graph is the same as decoding without CUDA graph."""
     dtype = torch.float16
     device = "cuda"
@@ -315,17 +309,8 @@ def test_gpt2_generation_cg(model_name, rotary, seqlen, maxlen):
         0, config.vocab_size, (batch_size, maxlen), dtype=torch.long, device=device
     )
-    logits = get_logits(
-        model, input_ids, maxlen, teacher_outputs=teacher_outputs, fused_ft_kernel=fused_ft_kernel
-    )
-    logits_cg = get_logits(
-        model,
-        input_ids,
-        maxlen,
-        teacher_outputs=teacher_outputs,
-        fused_ft_kernel=fused_ft_kernel,
-        cg=True,
-    )
+    logits = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs)
+    logits_cg = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs, cg=True)
     assert torch.equal(logits, logits_cg)
 
     # Try increasing batch size and seqlen, then decrease them to see if it's still correct
@@ -369,7 +354,6 @@ def test_gpt2_multiple_token_generation(model_name, optimized):
         config.fused_bias_fc = True
         config.fused_mlp = True
         config.fused_dropout_add_ln = True
-    # fused_ft_kernel currently doesn't work with multiple tokens at a time
     model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)
     model.eval()
@@ -398,13 +382,12 @@ def test_gpt2_multiple_token_generation(model_name, optimized):
     assert torch.allclose(logits, logits_ref, rtol=rtol, atol=atol)
 
 
-@pytest.mark.parametrize("fused_ft_kernel, cg", [(False, False), (True, False), (True, True)])
-# @pytest.mark.parametrize("fused_ft_kernel, cg", [(True, True)])
+@pytest.mark.parametrize("cg", [False, True])
 # @pytest.mark.parametrize("optimized", [False, True])
 @pytest.mark.parametrize("optimized", [True])
 # @pytest.mark.parametrize("model_name", ["gpt2-medium"])
 @pytest.mark.parametrize("model_name", ["gpt2-xl"])
-def test_gpt2_speculative_decoding(model_name, optimized, fused_ft_kernel, cg):
+def test_gpt2_speculative_decoding(model_name, optimized, cg):
     dtype = torch.float16
     device = "cuda"
     rtol, atol = 3e-3, 3e-1
@@ -444,7 +427,6 @@ def test_gpt2_speculative_decoding(model_name, optimized, cg):
         model_draft,
         max_length=max_length,
         top_k=5,
-        fused_ft_kernel=fused_ft_kernel,
         cg=cg,
         speculative_lookahead=4,
         enable_timing=True,
@@ -454,7 +436,6 @@ def test_gpt2_speculative_decoding(model_name, optimized, cg):
         input_ids,
         max_length=max_length,
         top_k=5,
-        fused_ft_kernel=fused_ft_kernel,
         cg=False,
         enable_timing=True,
         return_dict_in_generate=True,
```
tests/models/test_gpt_generation_parallel.py  (view file @ dfe29f5e)

```diff
@@ -15,12 +15,10 @@ from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel as GPT2LMHeadModelHF
 # @pytest.mark.parametrize('world_size', [1, 2, 4, 8])
 @pytest.mark.parametrize("world_size", [2])
-# @pytest.mark.parametrize('fused_ft_kernel', [False, True])
-@pytest.mark.parametrize("fused_ft_kernel", [True])
-# @pytest.mark.parametrize('rotary', [False, True])
-@pytest.mark.parametrize("rotary", [False])
+@pytest.mark.parametrize('rotary', [False, True])
+# @pytest.mark.parametrize("rotary", [False])
 @pytest.mark.parametrize("model_name", ["gpt2"])
-def test_tensor_parallel(model_name, rotary, fused_ft_kernel, world_size):
+def test_tensor_parallel(model_name, rotary, world_size):
     """Check that our implementation of GPT2 generation matches the HF implementation:
     the scores in fp16 should be around the same as the HF scores in fp16, when compared to
     the HF scores in fp32.
@@ -111,19 +109,17 @@ def test_tensor_parallel(model_name, rotary, world_size):
         max_length=max_length,
         tensor_parallel=world_size,
         vocab_size=config.vocab_size,
-        fused_ft_kernel=fused_ft_kernel,
         return_dict_in_generate=True,
         output_scores=True,
         enable_timing=True,
     )
     print(out.sequences)
-    if fused_ft_kernel:
+    if getattr(config, "use_flash_attn", False):
         out_cg = model.generate(
             input_ids=input_ids,
             max_length=max_length,
             tensor_parallel=world_size,
             vocab_size=config.vocab_size,
-            fused_ft_kernel=fused_ft_kernel,
             cg=True,
             return_dict_in_generate=True,
             output_scores=True,
```
tests/models/test_gptj.py  (view file @ dfe29f5e)

```diff
@@ -83,9 +83,8 @@ def test_gptj_optimized(model_name):
     ).abs().max().item()
 
 
-@pytest.mark.parametrize("fused_ft_kernel", [False, True])
 @pytest.mark.parametrize("model_name", ["EleutherAI/gpt-j-6B"])
-def test_gptj_generation(model_name, fused_ft_kernel):
+def test_gptj_generation(model_name):
     """Check that our implementation of GPT-J (with all optimizations enabled) matches the
     HF implementation: the output of our forward pass in fp16 should be around the same as the HF
     forward pass in fp16, when compared to the HF forward pass in fp32.
@@ -141,7 +140,6 @@ def test_gptj_generation(model_name):
         input_ids=input_ids,
         max_length=max_length,
         eos_token_id=eos_token_id,
-        fused_ft_kernel=fused_ft_kernel,
         return_dict_in_generate=True,
         output_scores=True,
         enable_timing=True,
@@ -152,16 +150,13 @@ def test_gptj_generation(model_name):
     # Capture graph outside the timing loop
     batch_size, seqlen_og = input_ids.shape
-    model._decoding_cache = update_graph_cache(
-        model, None, batch_size, seqlen_og, max_length, fused_ft_kernel=fused_ft_kernel
-    )
+    model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
     print("With CUDA graph")
     torch.cuda.synchronize()
     start = time.time()
     out_cg = model.generate(
         input_ids=input_ids,
         max_length=max_length,
-        fused_ft_kernel=fused_ft_kernel,
         cg=True,
         return_dict_in_generate=True,
         output_scores=True,
```
tests/models/test_llama.py  (view file @ dfe29f5e)

```diff
@@ -292,7 +292,6 @@ def test_llama_generation(model_name, checkpoint_format):
         input_ids=input_ids,
         max_length=max_length,
         eos_token_id=eos_token_id,
-        fused_ft_kernel=True,
         return_dict_in_generate=True,
         output_scores=True,
         enable_timing=True,
@@ -303,16 +302,13 @@ def test_llama_generation(model_name, checkpoint_format):
     # Capture graph outside the timing loop
     batch_size, seqlen_og = input_ids.shape
-    model._decoding_cache = update_graph_cache(
-        model, None, batch_size, seqlen_og, max_length, fused_ft_kernel=True
-    )
+    model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
     print("With CUDA graph")
     torch.cuda.synchronize()
     start = time.time()
     out_cg = model.generate(
         input_ids=input_ids,
         max_length=max_length,
-        fused_ft_kernel=True,
         cg=True,
         return_dict_in_generate=True,
         output_scores=True,
@@ -401,7 +397,6 @@ def test_llama_parallel_generation(model_name, world_size, checkpoint_format):
         max_length=max_length,
         tensor_parallel=world_size,
         vocab_size=config.vocab_size,
-        fused_ft_kernel=True,
         # teacher_outputs=out_hf.sequences,
         return_dict_in_generate=True,
         output_scores=True,
@@ -410,16 +405,13 @@ def test_llama_parallel_generation(model_name, world_size, checkpoint_format):
     # Capture graph outside the timing loop
     batch_size, seqlen_og = input_ids.shape
-    model._decoding_cache = update_graph_cache(
-        model, None, batch_size, seqlen_og, max_length, fused_ft_kernel=True
-    )
+    model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
     print("With CUDA graph")
     out_cg = model.generate(
         input_ids=input_ids,
         max_length=max_length,
         tensor_parallel=world_size,
         vocab_size=config.vocab_size,
-        fused_ft_kernel=True,
         cg=True,
         # teacher_outputs=out_hf.sequences,
         return_dict_in_generate=True,
```
tests/models/test_opt.py  (view file @ dfe29f5e)

```diff
@@ -107,7 +107,6 @@ def test_opt_generation(model_name):
     dtype = torch.float16
     device = "cuda"
     rtol, atol = 3e-3, 3e-1
-    fused_ft_kernel = True
     config = opt_config_to_gpt2_config(OPTConfig.from_pretrained(model_name))
     # Only prenorm supports residual_in_fp32
     config.residual_in_fp32 = getattr(config, "prenorm", True)
@@ -155,7 +154,6 @@ def test_opt_generation(model_name):
         input_ids=input_ids,
         max_length=max_length,
         eos_token_id=eos_token_id,
-        fused_ft_kernel=fused_ft_kernel,
         return_dict_in_generate=True,
         output_scores=True,
         enable_timing=True,
@@ -165,19 +163,16 @@ def test_opt_generation(model_name):
     if verbose:
         print(out.sequences)
         print(tokenizer.batch_decode(out.sequences.tolist()))
-    if fused_ft_kernel:
+    if getattr(config, "use_flash_attn", False):
         # Capture graph outside the timing loop
         batch_size, seqlen_og = input_ids.shape
-        model._decoding_cache = update_graph_cache(
-            model, None, batch_size, seqlen_og, max_length, fused_ft_kernel=True
-        )
+        model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
         print("With CUDA graph")
         torch.cuda.synchronize()
         start = time.time()
         out_cg = model.generate(
             input_ids=input_ids,
             max_length=max_length,
-            fused_ft_kernel=fused_ft_kernel,
             cg=True,
             return_dict_in_generate=True,
             output_scores=True,
```