gaoqiong / flash-attention

Commit a86442f0, authored Sep 07, 2023 by Tri Dao
[Gen] Use flash_attn_with_kvcache in generation
Parent: a1576ad1

Showing 5 changed files with 134 additions and 36 deletions (+134 -36)
flash_attn/layers/rotary.py     +2   -1
flash_attn/modules/mha.py       +96  -16
flash_attn/utils/generation.py  +24  -9
tests/models/test_baichuan.py   +8   -8
tests/models/test_gpt.py        +4   -2
flash_attn/layers/rotary.py
@@ -146,7 +146,8 @@ class ApplyRotaryEmbQKV_(torch.autograd.Function):
             # Call 1 kernel instead of 2 kernels
             # We need qkv to be contiguous so that when we reshape to combine (3, nheads)
             # dimensions, we get the same tensor
-            qk = rearrange(qkv[:, :, :2], "b s t h d -> b s (t h) d")
+            # qk = rearrange(qkv[:, :, :2], "b s t h d -> b s (t h) d")
+            qk = qkv[:, :, :2].reshape(batch, seqlen, -1, headdim)
             apply_rotary(
                 qk, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved, inplace=True
             )
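A pattern repeated throughout this commit is replacing einops rearrange calls with plain .reshape on the decoding path, which removes the einops call overhead while producing the same tensor. A quick equivalence check, assuming a contiguous qkv tensor like the one the comment above requires (einops is imported only for the check itself; values are illustrative):

import torch
from einops import rearrange

batch, seqlen, nheads, headdim = 2, 5, 4, 8
qkv = torch.randn(batch, seqlen, 3, nheads, headdim)  # contiguous (b, s, 3, h, d)

via_rearrange = rearrange(qkv[:, :, :2], "b s t h d -> b s (t h) d")
via_reshape = qkv[:, :, :2].reshape(batch, seqlen, -1, headdim)

assert torch.equal(via_rearrange, via_reshape)   # same values
assert via_reshape.data_ptr() == qkv.data_ptr()  # still a view, so in-place rotary writes back into qkv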
flash_attn/modules/mha.py
@@ -15,10 +15,12 @@ try:
         flash_attn_qkvpacked_func,
         flash_attn_varlen_kvpacked_func,
         flash_attn_varlen_qkvpacked_func,
+        flash_attn_with_kvcache,
     )
 except ImportError:
     flash_attn_varlen_qkvpacked_func, flash_attn_varlen_kvpacked_func = None, None
     flash_attn_qkvpacked_func, flash_attn_kvpacked_func = None, None
+    flash_attn_with_kvcache = None
 
 try:
     from flash_attn.ops.fused_dense import ColumnParallelLinear, FusedDense, RowParallelLinear
@@ -556,6 +558,35 @@ class MHA(nn.Module):
             else False,
         )
 
+    def _update_kvcache_attention(self, q, kv, inference_params):
+        """Write kv to inference_params, then do attention"""
+        if (
+            inference_params.sequence_len_offset == 0
+            or flash_attn_with_kvcache is None
+            or not self.use_flash_attn
+        ):
+            # TODO: this only uses sequence_len_offset and not lengths_per_sample.
+            kv = self._update_kv_cache(kv, inference_params)
+            return self.inner_cross_attn(q, kv)
+        else:
+            batch = q.shape[0]
+            kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
+            cache_seqlens = (
+                inference_params.lengths_per_sample[:batch]
+                if inference_params.lengths_per_sample is not None
+                else inference_params.sequence_len_offset
+            )
+            return flash_attn_with_kvcache(
+                q,
+                kv_cache[:, :, 0],
+                kv_cache[:, :, 1],
+                kv[:, :, 0],
+                kv[:, :, 1],
+                cache_seqlens=cache_seqlens,
+                softmax_scale=self.inner_cross_attn.softmax_scale,
+                causal=self.inner_cross_attn.causal,
+            )
+
     def forward(
         self,
         x,
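The new helper lets the decode step skip the separate Python-side cache update: flash_attn_with_kvcache writes the new key/value into the preallocated cache and attends over it in a single kernel. A minimal decode-step sketch follows, assuming flash-attn is installed with CUDA support; the shapes and values are illustrative, but the (batch, seqlen, 2, nheads, headdim) cache layout matches the one used by this commit.

import torch
from flash_attn import flash_attn_with_kvcache

batch, nheads, headdim, max_seqlen = 2, 8, 64, 256
device, dtype = "cuda", torch.float16

# Preallocated KV cache in the layout this commit uses: (b, s, 2, h, d)
kv_cache = torch.zeros(batch, max_seqlen, 2, nheads, headdim, device=device, dtype=dtype)
cache_seqlens = torch.full((batch,), 16, dtype=torch.int32, device=device)  # tokens already cached

q = torch.randn(batch, 1, nheads, headdim, device=device, dtype=dtype)      # one new query token
kv = torch.randn(batch, 1, 2, nheads, headdim, device=device, dtype=dtype)  # its new key/value

# Appends kv at position cache_seqlens inside kv_cache and computes attention
# over the cached prefix plus the new token, all in one kernel call.
out = flash_attn_with_kvcache(
    q,
    kv_cache[:, :, 0],  # k_cache view
    kv_cache[:, :, 1],  # v_cache view
    kv[:, :, 0],
    kv[:, :, 1],
    cache_seqlens=cache_seqlens,
    causal=True,
)
print(out.shape)  # (batch, 1, nheads, headdim)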
@@ -605,10 +636,19 @@ class MHA(nn.Module):
             if self.use_flash_attn
             else {"key_padding_mask": key_padding_mask, **kwargs}
         )
-        seqlen_offset = 0 if inference_params is None else inference_params.sequence_len_offset
+        seqlen_offset = (
+            0
+            if inference_params is None
+            else (
+                inference_params.lengths_per_sample
+                if inference_params.lengths_per_sample is not None
+                else inference_params.sequence_len_offset
+            )
+        )
         rotary_max_seqlen = (
             inference_params.max_sequence_len if inference_params is not None else None
         )
+        batch, seqlen = x.shape[:2]
         if not self.cross_attn and self.num_heads_kv == self.num_heads:
             assert x_kv is None and mixer_subset is None
             if not self.return_residual:
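When sequences in a batch have been decoded to different positions, the rotary offset has to be a per-sample tensor rather than a single integer, which is what the lengths_per_sample branch above provides. A small sketch of the selection logic with assumed values (not repo code):

import torch

# Hypothetical state after prefill: three sequences with different prompt lengths.
lengths_per_sample = torch.tensor([5, 9, 7], dtype=torch.int32)  # assumed values
sequence_len_offset = 9  # scalar fallback when all sequences share one offset

# Mirrors the new selection above: prefer the per-sample tensor when it exists,
# so each sequence's next token gets the correct rotary position.
seqlen_offset = (
    lengths_per_sample if lengths_per_sample is not None else sequence_len_offset
)
print(seqlen_offset)  # tensor([5, 9, 7], dtype=torch.int32)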
@@ -619,7 +659,8 @@ class MHA(nn.Module):
                 qkv = rearrange(
                     self.dwconv_qkv(rearrange(qkv, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
                 ).contiguous()
-            qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)
+            # qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)
+            qkv = qkv.reshape(batch, seqlen, 3, self.num_heads, self.head_dim)
             if (
                 inference_params is None
                 or inference_params.sequence_len_offset == 0
@@ -635,9 +676,9 @@ class MHA(nn.Module):
                 else:
                     context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs)
             else:
-                q = qkv[:, :, 0]
-                kv = self._update_kv_cache(qkv[:, :, 1:], inference_params)
-                context = self.inner_cross_attn(q, kv)
+                context = self._update_kvcache_attention(
+                    qkv[:, :, 0], qkv[:, :, 1:], inference_params
+                )
         else:
             context = self._apply_rotary_single_query_attention(qkv, inference_params)
     else:
@@ -659,8 +700,10 @@ class MHA(nn.Module):
                 qkv, x = self.Wqkv(x)
             q = qkv[..., : self.num_heads * self.head_dim]
             kv = qkv[..., self.num_heads * self.head_dim :]
-            q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)
-            kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim)
+            # q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)
+            q = q.reshape(batch, seqlen, -1, self.head_dim)
+            # kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim)
+            kv = kv.reshape(batch, seqlen, 2, -1, self.head_dim)
             if self.dwconv:
                 q = rearrange(
                     self.dwconv_q(rearrange(q, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
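In this grouped-query path the packed Wqkv output holds num_heads query heads followed by 2 * num_heads_kv key/value heads, so the reshape targets (b, s, -1, head_dim) for q and (b, s, 2, -1, head_dim) for kv. A shape-check sketch with assumed head counts:

import torch

batch, seqlen, head_dim = 2, 5, 64
num_heads, num_heads_kv = 8, 2  # assumed grouped-query configuration

# Packed projection output: q heads first, then 2 * num_heads_kv kv heads.
qkv = torch.randn(batch, seqlen, (num_heads + 2 * num_heads_kv) * head_dim)

q = qkv[..., : num_heads * head_dim]
kv = qkv[..., num_heads * head_dim :]

q = q.reshape(batch, seqlen, -1, head_dim)       # (2, 5, 8, 64)
kv = kv.reshape(batch, seqlen, 2, -1, head_dim)  # (2, 5, 2, 2, 64)

assert q.shape == (batch, seqlen, num_heads, head_dim)
assert kv.shape == (batch, seqlen, 2, num_heads_kv, head_dim)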
@@ -685,11 +728,11 @@ class MHA(nn.Module):
                         self.inner_cross_attn, q, kv, **kwargs
                     )
             else:
-                kv = self._update_kv_cache(kv, inference_params)
-                context = self.inner_cross_attn(q, kv)
+                context = self._update_kvcache_attention(q, kv, inference_params)
         else:
             context = self._apply_rotary_single_query_attention(
                 q, inference_params, kv=kv
             )
-        out = self.out_proj(rearrange(context, "... h d -> ... (h d)"))
+        # out = self.out_proj(rearrange(context, "... h d -> ... (h d)"))
+        out = self.out_proj(context.reshape(batch, seqlen, -1))
         return out if not self.return_residual else (out, x)
@@ -846,6 +889,36 @@ class ParallelMHA(nn.Module):
             else False,
         )
 
+    def _update_kvcache_attention(self, q, kv, inference_params):
+        """Write kv to inference_params, then do attention"""
+        if (
+            inference_params.sequence_len_offset == 0
+            or flash_attn_with_kvcache is None
+            or not self.use_flash_attn
+        ):
+            # TODO: this only uses sequence_len_offset and not lengths_per_sample.
+            kv = self._update_kv_cache(kv, inference_params)
+            return self.inner_cross_attn(q, kv)
+        else:
+            batch = q.shape[0]
+            kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
+            cache_seqlens = (
+                inference_params.lengths_per_sample[:batch]
+                if inference_params.lengths_per_sample is not None
+                else inference_params.sequence_len_offset
+            )
+            context = flash_attn_with_kvcache(
+                q,
+                kv_cache[:, :, 0],
+                kv_cache[:, :, 1],
+                kv[:, :, 0],
+                kv[:, :, 1],
+                cache_seqlens=cache_seqlens,
+                softmax_scale=self.inner_cross_attn.softmax_scale,
+                causal=self.inner_cross_attn.causal,
+            )
+            return context
+
     def forward(self, x, seqlen=None, inference_params=None, **kwargs):
         """
         Arguments:
@@ -857,7 +930,15 @@ class ParallelMHA(nn.Module):
         qkv = self.Wqkv(x)
         if seqlen is not None:
             qkv = rearrange(qkv, "(b s) ... -> b s ...", s=seqlen)
-        seqlen_offset = 0 if inference_params is None else inference_params.sequence_len_offset
+        seqlen_offset = (
+            0
+            if inference_params is None
+            else (
+                inference_params.lengths_per_sample
+                if inference_params.lengths_per_sample is not None
+                else inference_params.sequence_len_offset
+            )
+        )
         rotary_max_seqlen = (
             inference_params.max_sequence_len if inference_params is not None else None
         )
@@ -878,9 +959,9 @@ class ParallelMHA(nn.Module):
                 else:
                     context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs)
             else:
-                q = qkv[:, :, 0]
-                kv = _update_kv_cache(qkv[:, :, 1:], inference_params, self.layer_idx)
-                context = self.inner_cross_attn(q, kv)
+                context = self._update_kvcache_attention(
+                    qkv[:, :, 0], qkv[:, :, 1:], inference_params
+                )
         else:
             context = self._apply_rotary_single_query_attention(qkv, inference_params)
     else:
@@ -912,8 +993,7 @@ class ParallelMHA(nn.Module):
                         self.inner_cross_attn, q, kv, **kwargs
                     )
             else:
-                kv = self._update_kv_cache(kv, inference_params)
-                context = self.inner_cross_attn(q, kv)
+                context = self._update_kvcache_attention(q, kv, inference_params)
         else:
             context = self._apply_rotary_single_query_attention(q, inference_params, kv=kv)
         context = rearrange(context, "b s h d -> b s (h d)")
flash_attn/utils/generation.py
@@ -118,7 +118,6 @@ def decode(
     batch_size, seqlen_og = input_ids.shape
     teacher_output_len = teacher_outputs.shape[1] if teacher_outputs is not None else 0
     if cg:
-        assert fused_ft_kernel
         if not hasattr(model, "_decoding_cache"):
             model._decoding_cache = None
         model._decoding_cache = update_graph_cache(
@@ -128,11 +127,13 @@ def decode(
             seqlen_og,
             max_length,
             tensor_parallel=tensor_parallel,
+            fused_ft_kernel=fused_ft_kernel,
         )
         inference_params = model._decoding_cache.inference_params
         inference_params.max_sequence_len = max_length
         inference_params.max_batch_size = batch_size
         inference_params.sequence_len_offset = 0
+        inference_params.lengths_per_sample.zero_()
     else:
         inference_params = InferenceParams(
             max_sequence_len=max_length, max_batch_size=batch_size, fused_ft_kernel=fused_ft_kernel
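With the assert removed above, CUDA-graph decoding no longer requires the fused_ft_kernel path; the flash_attn_with_kvcache route works under graph capture as well. A hedged usage sketch mirroring the updated tests, where `model` (a flash-attn GenerationMixin model on GPU) and `input_ids` (a prompt tensor of shape (batch, seqlen)) are assumed to exist:

out = model.generate(
    input_ids=input_ids,
    max_length=128,
    fused_ft_kernel=False,  # no longer required for cg=True after this commit
    cg=True,                # capture the decoding step in a CUDA graph
    return_dict_in_generate=True,
    output_scores=True,
    enable_timing=True,
)
print(out.sequences.shape)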
@@ -167,7 +168,8 @@ def decode(
             token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature)
         else:
             token = teacher_outputs[:, inference_params.sequence_len_offset]
-        return rearrange(token, "b -> b 1")
+        # return rearrange(token, "b -> b 1")
+        return token.unsqueeze(1)
 
     def should_stop(current_token, inference_params):
         if inference_params.sequence_len_offset == 0:
@@ -197,9 +199,7 @@ def decode(
         torch.cuda.synchronize()
         print(f"Prompt processing + decoding time: {(start.elapsed_time(end)):.0f}ms")
     output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput
-    return output_cls(
-        sequences=torch.cat(sequences, dim=1), scores=tuple(scores)
-    )
+    return output_cls(sequences=torch.cat(sequences, dim=1), scores=tuple(scores))
 
 
 def sample_speculative(logits, logits_draft, tokens_draft, top_k=1, top_p=0.0, temperature=1.0):
@@ -298,7 +298,6 @@ def decode_speculative(
     assert batch_size == 1, "Speculative decoding implementation only supports batch_size=1"
     assert eos_token_id is None, "Speculative decoding implementation doesn't support eos_token_id"
     if cg:
-        assert fused_ft_kernel
         if not hasattr(model_draft, "_decoding_cache"):
             model_draft._decoding_cache = None
         model_draft._decoding_cache = update_graph_cache(
@@ -308,6 +307,7 @@ def decode_speculative(
             seqlen_og,
             max_length,
             tensor_parallel=tensor_parallel,
+            fused_ft_kernel=fused_ft_kernel,
         )
         inference_params_draft = model_draft._decoding_cache.inference_params
         inference_params_draft.max_sequence_len = max_length
@@ -606,12 +606,14 @@ def allocate_inference_cache(
     layers: Union[int, Sequence],
     device,
     dtype=torch.float16,
+    fused_ft_kernel=False,
 ):
     assert dtype in [torch.float16, torch.bfloat16, torch.float32]
     packsize = 4 if dtype == torch.float32 else 8
     assert headdim % packsize == 0
     k_cache_shape = (max_batch_size, nheads, headdim // packsize, max_seqlen, packsize)
     v_cache_shape = (max_batch_size, nheads, max_seqlen, headdim)
+    kv_cache_shape = (max_batch_size, max_seqlen, 2, nheads, headdim)
     if isinstance(layers, int):
         layers = range(layers)
     return {
@@ -619,6 +621,8 @@ def allocate_inference_cache(
             torch.empty(k_cache_shape, device=device, dtype=dtype),
             torch.empty(v_cache_shape, device=device, dtype=dtype),
         )
+        if fused_ft_kernel
+        else torch.empty(kv_cache_shape, device=device, dtype=dtype)
         for i in layers
     }
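allocate_inference_cache now supports both layouts: the packed per-layer K/V caches expected by the fused_ft_kernel path, and a single (batch, seqlen, 2, nheads, headdim) tensor that flash_attn_with_kvcache consumes directly. A small sketch of the shapes involved (values are illustrative):

import torch

max_batch_size, max_seqlen, nheads, headdim = 2, 256, 8, 64
dtype = torch.float16

# fused_ft_kernel=True: packed K cache plus separate V cache, per layer.
packsize = 4 if dtype == torch.float32 else 8
k_cache_shape = (max_batch_size, nheads, headdim // packsize, max_seqlen, packsize)
v_cache_shape = (max_batch_size, nheads, max_seqlen, headdim)

# fused_ft_kernel=False: one KV tensor in the layout flash_attn_with_kvcache expects.
kv_cache_shape = (max_batch_size, max_seqlen, 2, nheads, headdim)

print(k_cache_shape, v_cache_shape)  # (2, 8, 8, 256, 8) (2, 8, 256, 64)
print(kv_cache_shape)                # (2, 256, 2, 8, 64)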
@@ -651,7 +655,15 @@ class DecodingCGCache:
 
 @torch.inference_mode()
 def update_graph_cache(
-    model, cache, batch_size, seqlen_og, max_seqlen, tensor_parallel=1, dtype=None, n_warmups=2
+    model,
+    cache,
+    batch_size,
+    seqlen_og,
+    max_seqlen,
+    tensor_parallel=1,
+    dtype=None,
+    n_warmups=2,
+    fused_ft_kernel=False,
 ):
     if cache is None:
         cache = DecodingCGCache()
@@ -671,7 +683,9 @@ def update_graph_cache(
         cache.device, cache.dtype = device, dtype
         cache.max_batch_size, cache.max_seqlen = batch_size, max_seqlen
     if hasattr(model, "allocate_inference_cache"):
-        inf_cache = model.allocate_inference_cache(batch_size, max_seqlen, dtype)
+        inf_cache = model.allocate_inference_cache(
+            batch_size, max_seqlen, dtype, fused_ft_kernel=fused_ft_kernel
+        )
     else:
         headdim = getattr(
             model.config,
@@ -686,6 +700,7 @@ def update_graph_cache(
             model.config.num_hidden_layers,
             device,
             dtype,
+            fused_ft_kernel=fused_ft_kernel,
         )
     lengths_per_sample = torch.full((batch_size,), seqlen_og, dtype=torch.int32, device=device)
     cache.inference_params = InferenceParams(
@@ -693,7 +708,7 @@ def update_graph_cache(
         max_batch_size=batch_size,
         sequence_len_offset=seqlen_og,
         key_value_memory_dict=inf_cache,
-        fused_ft_kernel=True,
+        fused_ft_kernel=fused_ft_kernel,
         lengths_per_sample=lengths_per_sample,
     )
     cache.mempool = torch.cuda.graphs.graph_pool_handle()
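update_graph_cache now threads fused_ft_kernel through to the cache allocation, so the decoding graph can be captured around either cache layout. A usage sketch mirroring the updated tests, which capture the graph outside any timing loop (`model` and `input_ids` are assumed to exist):

from flash_attn.utils.generation import update_graph_cache

batch_size, seqlen_og = input_ids.shape  # `input_ids` is an assumed prompt tensor
max_length = 128

# Capture the decoding CUDA graph once, up front, with the non-fused
# (flash_attn_with_kvcache) cache layout.
model._decoding_cache = update_graph_cache(
    model, None, batch_size, seqlen_og, max_length, fused_ft_kernel=False
)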
tests/models/test_baichuan.py
@@ -217,8 +217,9 @@ def test_baichuan_parallel_forward(model_name, world_size):
     ).abs().max().item()
 
 
+@pytest.mark.parametrize("fused_ft_kernel", [False, True])
 @pytest.mark.parametrize("model_name", ["baichuan-inc/Baichuan-7B"])
-def test_baichuan_generation(model_name):
+def test_baichuan_generation(model_name, fused_ft_kernel):
     dtype = torch.float16
     device = "cuda"
     config = baichuan_config_to_gpt2_config(
@@ -276,6 +277,7 @@ def test_baichuan_generation(model_name):
     model.load_state_dict(pretrained_state_dict)
     model.eval()
 
+    model(input_ids)  # Warm up
     print("Without CUDA graph")
     torch.cuda.synchronize()
     start = time.time()
@@ -283,7 +285,7 @@ def test_baichuan_generation(model_name):
         input_ids=input_ids,
         max_length=max_length,
         eos_token_id=eos_token_id,
-        fused_ft_kernel=True,
+        fused_ft_kernel=fused_ft_kernel,
         return_dict_in_generate=True,
         output_scores=True,
         enable_timing=True,
@@ -295,7 +297,7 @@ def test_baichuan_generation(model_name):
     # Capture graph outside the timing loop
     batch_size, seqlen_og = input_ids.shape
     model._decoding_cache = update_graph_cache(
-        model, None, batch_size, seqlen_og, max_length
+        model, None, batch_size, seqlen_og, max_length, fused_ft_kernel=fused_ft_kernel
     )
     print("With CUDA graph")
     torch.cuda.synchronize()
@@ -303,7 +305,7 @@ def test_baichuan_generation(model_name):
     out_cg = model.generate(
         input_ids=input_ids,
         max_length=max_length,
-        fused_ft_kernel=True,
+        fused_ft_kernel=fused_ft_kernel,
         cg=True,
         return_dict_in_generate=True,
         output_scores=True,
@@ -346,7 +348,7 @@ def test_baichuan_parallel_generation(model_name, world_size):
     config = baichuan_config_to_gpt2_config(
         AutoConfig.from_pretrained(model_name, trust_remote_code=True)
     )
-    config.use_flash_attn = False
+    config.use_flash_attn = True
     config.fused_bias_fc = True
     config.fused_mlp = False  # We don't have fused GatedMLP yet
     config.fused_dropout_add_ln = False
@@ -393,7 +395,6 @@ def test_baichuan_parallel_generation(model_name, world_size):
         max_length=max_length,
         tensor_parallel=world_size,
         vocab_size=config.vocab_size,
-        fused_ft_kernel=True,
         # teacher_outputs=out_hf.sequences,
         return_dict_in_generate=True,
         output_scores=True,
@@ -411,7 +412,6 @@ def test_baichuan_parallel_generation(model_name, world_size):
         max_length=max_length,
         tensor_parallel=world_size,
         vocab_size=config.vocab_size,
-        fused_ft_kernel=True,
        cg=True,
         # teacher_outputs=out_hf.sequences,
         return_dict_in_generate=True,
@@ -458,6 +458,6 @@ def test_baichuan_parallel_generation(model_name, world_size):
     hf_error = (logits_hf - logits_ref).abs().max().item()
     print(f"HF fp16 logits max diff: {hf_error}")
     print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
     assert (logits - logits_ref).abs().max().item() < 2 * hf_error
     print(f"Logits CG max diff: {(logits_cg - logits_ref).abs().max().item()}")
-    assert (logits - logits_ref).abs().max().item() < 2 * hf_error
+    assert torch.equal(logits_cg, logits)
tests/models/test_gpt.py
@@ -135,7 +135,7 @@ def test_gpt2_optimized(model_name):
 
 
 @pytest.mark.parametrize("fused_ft_kernel", [False, True])
-# @pytest.mark.parametrize('fused_ft_kernel', [False])
+# @pytest.mark.parametrize('fused_ft_kernel', [True])
 @pytest.mark.parametrize("optimized", [False, True])
 # @pytest.mark.parametrize('optimized', [True])
 @pytest.mark.parametrize("rotary", [False, True])
@@ -209,7 +209,7 @@ def test_gpt2_generation(model_name, rotary, optimized, fused_ft_kernel):
     )
     print(out.sequences)
     print(tokenizer.batch_decode(out.sequences.tolist()))
-    if fused_ft_kernel or config.use_flash_attn:
+    if fused_ft_kernel or getattr(config, "use_flash_attn", False):
         out_cg = model.generate(
             input_ids=input_ids,
             max_length=max_length,
@@ -220,6 +220,7 @@ def test_gpt2_generation(model_name, rotary, optimized, fused_ft_kernel):
             enable_timing=True,
         )
         print(out_cg.sequences)
+        assert torch.equal(torch.stack(out.scores, dim=1), torch.stack(out_cg.scores, dim=1))
 
     if not rotary:
         out_hf = model_hf.generate(
@@ -282,6 +283,7 @@ def get_logits(model, input_ids, max_length, teacher_outputs=None, **kwargs):
 @pytest.mark.parametrize("rotary", [None, "interleaved", "block"])
 # @pytest.mark.parametrize('rotary', [None])
+@pytest.mark.parametrize("fused_ft_kernel", [False, True])
 # @pytest.mark.parametrize("fused_ft_kernel", [False])
 @pytest.mark.parametrize("model_name", ["gpt2"])
 def test_gpt2_generation_cg(model_name, fused_ft_kernel, rotary, seqlen, maxlen):
     """Check that decoding with CUDA graph is the same as decoding without CUDA graph."""