Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a4c4daf3
Unverified
Commit
a4c4daf3
authored
Dec 02, 2024
by
youkaichao
Committed by
GitHub
Dec 02, 2024
Browse files
[misc] use out argument for flash attention (#10822)
Signed-off-by:
youkaichao
<
youkaichao@gmail.com
>
parent
e95f275f
Changes
13
Show whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
144 additions
and
157 deletions
+144
-157
vllm/attention/backends/abstract.py
vllm/attention/backends/abstract.py
+1
-0
vllm/attention/backends/blocksparse_attn.py
vllm/attention/backends/blocksparse_attn.py
+2
-0
vllm/attention/backends/flash_attn.py
vllm/attention/backends/flash_attn.py
+19
-36
vllm/attention/backends/flashinfer.py
vllm/attention/backends/flashinfer.py
+4
-0
vllm/attention/backends/hpu_attn.py
vllm/attention/backends/hpu_attn.py
+1
-0
vllm/attention/backends/ipex_attn.py
vllm/attention/backends/ipex_attn.py
+1
-0
vllm/attention/backends/pallas.py
vllm/attention/backends/pallas.py
+1
-0
vllm/attention/backends/rocm_flash_attn.py
vllm/attention/backends/rocm_flash_attn.py
+1
-0
vllm/attention/backends/torch_sdpa.py
vllm/attention/backends/torch_sdpa.py
+1
-0
vllm/attention/backends/xformers.py
vllm/attention/backends/xformers.py
+1
-0
vllm/attention/layer.py
vllm/attention/layer.py
+72
-4
vllm/config.py
vllm/config.py
+1
-1
vllm/v1/attention/backends/flash_attn.py
vllm/v1/attention/backends/flash_attn.py
+39
-116
No files found.
vllm/attention/backends/abstract.py
View file @
a4c4daf3
...
...
@@ -247,5 +247,6 @@ class AttentionImpl(ABC, Generic[T]):
k_scale
:
float
=
1.0
,
v_scale
:
float
=
1.0
,
attn_type
:
str
=
AttentionType
.
DECODER
,
output
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
raise
NotImplementedError
vllm/attention/backends/blocksparse_attn.py
View file @
a4c4daf3
...
...
@@ -360,6 +360,7 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
k_scale
:
float
=
1.0
,
v_scale
:
float
=
1.0
,
attn_type
:
str
=
AttentionType
.
DECODER
,
output
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
"""Forward pass with FlashAttention and PagedAttention.
...
...
@@ -448,5 +449,6 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
blocksparse_head_sliding_step
=
self
.
head_sliding_step
,
)
assert
output
is
not
None
# Reshape the output tensor.
return
output
.
view
(
num_tokens
,
hidden_size
)
vllm/attention/backends/flash_attn.py
View file @
a4c4daf3
...
...
@@ -638,24 +638,27 @@ class FlashAttentionImpl(AttentionImpl):
k_scale
:
float
=
1.0
,
v_scale
:
float
=
1.0
,
attn_type
:
str
=
AttentionType
.
DECODER
,
output
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
"""Forward pass with FlashAttention.
Args:
query: shape = [num_tokens, num_heads * head_size]
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
query: shape = [num_tokens, num_heads, head_size]
key: shape = [num_tokens, num_kv_heads, head_size]
value: shape = [num_tokens, num_kv_heads, head_size]
output: shape = [num_tokens, num_heads, head_size]
kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size]
NOTE: kv_cache will be an empty tensor with shape [0]
for profiling run.
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]
NOTE: It in-place updates the output tensor.
"""
# NOTE(woosuk): FlashAttention does not support FP8 KV cache.
assert
k_scale
==
1.0
and
v_scale
==
1.0
,
(
"key/v_scale is not supported in FlashAttention."
)
assert
output
is
not
None
,
"Output tensor must be provided."
if
(
attn_type
==
AttentionType
.
ENCODER
and
(
not
attn_metadata
.
is_all_encoder_attn_metadata_set
)):
raise
AttributeError
(
"Encoder attention requires setting "
...
...
@@ -666,23 +669,12 @@ class FlashAttentionImpl(AttentionImpl):
"requires setting cross-attention "
"metadata attributes."
)
num_heads
:
int
=
self
.
num_heads
head_size
:
int
=
self
.
head_size
num_kv_heads
:
int
=
self
.
num_kv_heads
kv_cache_dtype
:
str
=
self
.
kv_cache_dtype
softmax_scale
:
float
=
self
.
scale
window_size
=
self
.
sliding_window
alibi_slopes
:
Optional
[
torch
.
Tensor
]
=
self
.
alibi_slopes
logits_soft_cap
:
Optional
[
float
]
=
self
.
logits_soft_cap
num_tokens
,
hidden_size
=
query
.
shape
# Reshape the query, key, and value tensors.
query
=
query
.
view
(
-
1
,
num_heads
,
head_size
)
if
(
key
is
not
None
)
and
(
value
is
not
None
):
key
=
key
.
view
(
-
1
,
num_kv_heads
,
head_size
)
value
=
value
.
view
(
-
1
,
num_kv_heads
,
head_size
)
if
kv_cache
.
numel
()
>
0
:
key_cache
=
kv_cache
[
0
]
value_cache
=
kv_cache
[
1
]
...
...
@@ -721,13 +713,13 @@ class FlashAttentionImpl(AttentionImpl):
num_decode_query_tokens
)
=
\
get_num_prefill_decode_query_kv_tokens
(
attn_metadata
,
attn_type
)
decode_query
=
query
[
num_prefill_query_tokens
:]
decode_output
=
output
[
num_prefill_query_tokens
:]
# QKV for prefill.
query
=
query
[:
num_prefill_query_tokens
]
prefill_output
=
output
[:
num_prefill_query_tokens
]
assert
query
.
shape
[
0
]
==
num_prefill_query_tokens
assert
decode_query
.
shape
[
0
]
==
num_decode_query_tokens
prefill_output
:
Optional
[
torch
.
Tensor
]
=
None
decode_output
:
Optional
[
torch
.
Tensor
]
=
None
if
prefill_meta
:
=
attn_metadata
.
prefill_metadata
:
# Prompt run.
if
(
kv_cache
.
numel
()
==
0
or
prefill_meta
.
block_tables
is
None
...
...
@@ -741,7 +733,7 @@ class FlashAttentionImpl(AttentionImpl):
key
=
key
[:
num_prefill_kv_tokens
]
value
=
value
[:
num_prefill_kv_tokens
]
prefill_output
=
flash_attn_varlen_func
(
flash_attn_varlen_func
(
q
=
query
,
k
=
key
,
v
=
value
,
...
...
@@ -754,6 +746,7 @@ class FlashAttentionImpl(AttentionImpl):
window_size
=
window_size
,
alibi_slopes
=
alibi_slopes
,
softcap
=
logits_soft_cap
,
out
=
prefill_output
,
)
else
:
# prefix-enabled attention
...
...
@@ -761,7 +754,7 @@ class FlashAttentionImpl(AttentionImpl):
"Only decoder-only models support prefix caching"
)
assert
prefill_meta
.
seq_lens
is
not
None
max_seq_len
=
max
(
prefill_meta
.
seq_lens
)
prefill_output
=
flash_attn_varlen_func
(
# noqa
flash_attn_varlen_func
(
# noqa
q
=
query
,
k
=
key_cache
,
v
=
value_cache
,
...
...
@@ -775,6 +768,7 @@ class FlashAttentionImpl(AttentionImpl):
alibi_slopes
=
alibi_slopes
,
block_table
=
prefill_meta
.
block_tables
,
softcap
=
logits_soft_cap
,
out
=
prefill_output
,
)
if
decode_meta
:
=
attn_metadata
.
decode_metadata
:
...
...
@@ -788,7 +782,7 @@ class FlashAttentionImpl(AttentionImpl):
assert
attn_type
==
AttentionType
.
DECODER
,
(
"Only decoder-only models support max_decode_query_len > 1"
)
decode_output
=
flash_attn_varlen_func
(
flash_attn_varlen_func
(
q
=
decode_query
,
k
=
key_cache
,
v
=
value_cache
,
...
...
@@ -802,6 +796,7 @@ class FlashAttentionImpl(AttentionImpl):
alibi_slopes
=
alibi_slopes
,
softcap
=
logits_soft_cap
,
block_table
=
decode_meta
.
block_tables
,
out
=
decode_output
,
)
else
:
# Use flash_attn_with_kvcache for normal decoding.
...
...
@@ -810,7 +805,7 @@ class FlashAttentionImpl(AttentionImpl):
_
,
block_tables_arg
,
)
=
get_seq_len_block_table_args
(
decode_meta
,
False
,
attn_type
)
decode_output
=
flash_attn_with_kvcache
(
flash_attn_with_kvcache
(
q
=
decode_query
.
unsqueeze
(
1
),
k_cache
=
key_cache
,
v_cache
=
value_cache
,
...
...
@@ -821,20 +816,8 @@ class FlashAttentionImpl(AttentionImpl):
window_size
=
window_size
,
alibi_slopes
=
alibi_slopes
,
softcap
=
logits_soft_cap
,
).
squeeze
(
1
)
if
prefill_output
is
None
:
assert
decode_output
is
not
None
return
decode_output
.
view
(
num_decode_query_tokens
,
hidden_size
)
if
decode_output
is
None
:
assert
prefill_output
is
not
None
return
prefill_output
.
view
(
num_prefill_query_tokens
,
hidden_size
)
assert
decode_meta
is
not
None
decode_output
=
decode_output
.
squeeze
(
1
)
output
=
torch
.
cat
([
prefill_output
,
decode_output
],
dim
=
0
)
return
output
.
view
(
num_tokens
,
hidden_size
)
out
=
decode_output
.
unsqueeze
(
1
),
)
return
output
...
...
vllm/attention/backends/flashinfer.py
View file @
a4c4daf3
...
...
@@ -774,7 +774,11 @@ class FlashInferImpl(AttentionImpl):
k_scale
:
float
=
1.0
,
v_scale
:
float
=
1.0
,
attn_type
:
str
=
AttentionType
.
DECODER
,
output
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
# TODO: directly write to output tensor
if
attn_type
!=
AttentionType
.
DECODER
:
raise
NotImplementedError
(
"Encoder self-attention and "
"encoder/decoder cross-attention "
...
...
vllm/attention/backends/hpu_attn.py
View file @
a4c4daf3
...
...
@@ -145,6 +145,7 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
k_scale
:
float
=
1.0
,
v_scale
:
float
=
1.0
,
attn_type
:
str
=
AttentionType
.
DECODER
,
output
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
"""Forward pass with xFormers and PagedAttention.
...
...
vllm/attention/backends/ipex_attn.py
View file @
a4c4daf3
...
...
@@ -173,6 +173,7 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
k_scale
:
float
=
1.0
,
v_scale
:
float
=
1.0
,
attn_type
:
str
=
AttentionType
.
DECODER
,
output
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
"""Forward pass with IPEX varlen_attention and PagedAttention.
...
...
vllm/attention/backends/pallas.py
View file @
a4c4daf3
...
...
@@ -151,6 +151,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
k_scale
:
float
=
1.0
,
v_scale
:
float
=
1.0
,
attn_type
:
str
=
AttentionType
.
DECODER
,
output
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
"""Forward pass with Pallas attention.
...
...
vllm/attention/backends/rocm_flash_attn.py
View file @
a4c4daf3
...
...
@@ -415,6 +415,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
k_scale
:
float
=
1.0
,
v_scale
:
float
=
1.0
,
attn_type
:
str
=
AttentionType
.
DECODER
,
output
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
"""Forward pass with FlashAttention and PagedAttention.
...
...
vllm/attention/backends/torch_sdpa.py
View file @
a4c4daf3
...
...
@@ -431,6 +431,7 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
k_scale
:
float
=
1.0
,
v_scale
:
float
=
1.0
,
attn_type
:
str
=
AttentionType
.
DECODER
,
output
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
"""Forward pass with torch SDPA and PagedAttention.
...
...
vllm/attention/backends/xformers.py
View file @
a4c4daf3
...
...
@@ -417,6 +417,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
k_scale
:
float
=
1.0
,
v_scale
:
float
=
1.0
,
attn_type
:
str
=
AttentionType
.
DECODER
,
output
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
"""Forward pass with xFormers and PagedAttention.
...
...
vllm/attention/layer.py
View file @
a4c4daf3
...
...
@@ -4,7 +4,6 @@ from typing import Any, Dict, List, Optional
import
torch
import
torch.nn
as
nn
import
vllm.envs
as
envs
from
vllm.attention
import
AttentionMetadata
,
AttentionType
from
vllm.attention.selector
import
backend_name_to_enum
,
get_attn_backend
from
vllm.config
import
CacheConfig
,
get_current_vllm_config
...
...
@@ -12,7 +11,7 @@ from vllm.forward_context import ForwardContext, get_forward_context
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.quantization.kv_cache
import
BaseKVCacheMethod
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
_Backend
,
current_platform
from
vllm.utils
import
direct_register_custom_op
...
...
@@ -97,14 +96,23 @@ class Attention(nn.Module):
self
.
impl
=
impl_cls
(
num_heads
,
head_size
,
scale
,
num_kv_heads
,
alibi_slopes
,
sliding_window
,
kv_cache_dtype
,
blocksparse_params
,
logits_soft_cap
)
self
.
num_heads
=
num_heads
self
.
head_size
=
head_size
self
.
num_kv_heads
=
num_kv_heads
self
.
backend
=
backend_name_to_enum
(
attn_backend
.
get_name
())
# For cuda-alike (CUDA and ROCM) and cpu platforms, we control how
# torch.compile works by registering the attention as one giant
# opaque custom op. For other platforms, we directly call them
# and let torch.compile handle them.
self
.
use_direct_call
=
envs
.
VLLM_USE_V1
or
not
(
current_platform
.
is_cuda_alike
()
or
current_platform
.
is_cpu
())
self
.
use_direct_call
=
not
current_platform
.
is_cuda_alike
(
)
and
not
current_platform
.
is_cpu
()
# For some attention backends, we allocate an output tensor before
# calling the custom op. When piecewise cudagraph is enabled, this
# makes sure the output tensor is allocated inside the cudagraph.
self
.
use_output
=
self
.
backend
==
_Backend
.
FLASH_ATTN
or
\
self
.
backend
==
_Backend
.
FLASH_ATTN_VLLM_V1
compilation_config
=
get_current_vllm_config
().
compilation_config
if
prefix
in
compilation_config
.
static_forward_context
:
raise
ValueError
(
f
"Duplicate layer name:
{
prefix
}
"
)
...
...
@@ -130,6 +138,22 @@ class Attention(nn.Module):
self
.
_k_scale
,
self
.
_v_scale
,
attn_type
=
attn_type
)
elif
self
.
use_output
:
output
=
torch
.
empty_like
(
query
)
hidden_size
=
query
.
size
(
-
1
)
# Reshape the query, key, and value tensors.
# NOTE(woosuk): We do this outside the custom op to minimize the
# CPU overheads from the non-CUDA-graph regions.
query
=
query
.
view
(
-
1
,
self
.
num_heads
,
self
.
head_size
)
output
=
output
.
view
(
-
1
,
self
.
num_heads
,
self
.
head_size
)
if
key
is
not
None
:
key
=
key
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
if
value
is
not
None
:
value
=
value
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
torch
.
ops
.
vllm
.
unified_attention_with_output
(
query
,
key
,
value
,
output
,
kv_cache
,
attn_type
,
self
.
layer_name
)
return
output
.
view
(
-
1
,
hidden_size
)
else
:
return
torch
.
ops
.
vllm
.
unified_attention
(
query
,
key
,
value
,
kv_cache
,
attn_type
,
...
...
@@ -183,3 +207,47 @@ direct_register_custom_op(
fake_impl
=
unified_attention_fake
,
dispatch_key
=
current_platform
.
dispatch_key
,
)
def
unified_attention_with_output
(
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
output
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_type
:
str
,
layer_name
:
str
,
)
->
None
:
forward_context
:
ForwardContext
=
get_forward_context
()
attn_metadata
=
forward_context
.
dynamic_forward_context
self
=
forward_context
.
static_forward_context
[
layer_name
]
self
.
impl
.
forward
(
query
,
key
,
value
,
kv_cache
,
attn_metadata
,
self
.
_k_scale
,
self
.
_v_scale
,
attn_type
=
attn_type
,
output
=
output
)
def
unified_attention_with_output_fake
(
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
output
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_type
:
str
,
layer_name
:
str
,
)
->
None
:
return
direct_register_custom_op
(
op_name
=
"unified_attention_with_output"
,
op_func
=
unified_attention_with_output
,
mutates_args
=
[
"kv_cache"
,
"output"
],
fake_impl
=
unified_attention_with_output_fake
,
dispatch_key
=
current_platform
.
dispatch_key
,
)
vllm/config.py
View file @
a4c4daf3
...
...
@@ -2238,7 +2238,7 @@ class CompilationConfig(BaseModel):
custom_ops
:
List
[
str
]
=
Field
(
default_factory
=
list
)
splitting_ops
:
List
[
str
]
=
Field
(
default_factory
=
lambda
:
[
"vllm.unified_attention"
,
"vllm.unified_
v1_flash_
attention"
,
"vllm.unified_attention
_with_output
"
,
])
use_inductor
:
bool
=
True
...
...
vllm/v1/attention/backends/flash_attn.py
View file @
a4c4daf3
...
...
@@ -6,8 +6,6 @@ import torch
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionMetadata
,
AttentionType
)
from
vllm.forward_context
import
get_forward_context
from
vllm.utils
import
direct_register_custom_op
from
vllm.vllm_flash_attn
import
flash_attn_varlen_func
...
...
@@ -113,13 +111,14 @@ class FlashAttentionImpl(AttentionImpl):
k_scale
:
float
=
1.0
,
v_scale
:
float
=
1.0
,
attn_type
:
AttentionType
=
AttentionType
.
DECODER
,
output
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
"""Forward pass with FlashAttention.
Args:
query: shape = [num_tokens, num_heads
*
head_size]
key: shape = [num_tokens, num_kv_heads
*
head_size]
value: shape = [num_tokens, num_kv_heads
*
head_size]
query: shape = [num_tokens, num_heads
,
head_size]
key: shape = [num_tokens, num_kv_heads
,
head_size]
value: shape = [num_tokens, num_kv_heads
,
head_size]
kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size]
attn_metadata: Metadata for attention.
Returns:
...
...
@@ -135,60 +134,10 @@ class FlashAttentionImpl(AttentionImpl):
assert
k_scale
==
1.0
and
v_scale
==
1.0
,
(
"key/v_scale is not supported in FlashAttention."
)
# Reshape the query, key, and value tensors.
# NOTE(woosuk): We do this outside the custom op to minimize the CPU
# overheads from the non-CUDA-graph regions.
query
=
query
.
view
(
-
1
,
self
.
num_heads
,
self
.
head_size
)
key
=
key
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
value
=
value
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
output
=
torch
.
empty_like
(
query
)
torch
.
ops
.
vllm
.
unified_v1_flash_attention
(
output
,
query
,
key
,
value
,
self
.
num_heads
,
self
.
head_size
,
self
.
num_kv_heads
,
kv_cache
,
self
.
kv_cache_dtype
,
k_scale
,
v_scale
,
self
.
scale
,
self
.
sliding_window
,
self
.
alibi_slopes
,
self
.
logits_soft_cap
,
)
return
output
.
view
(
-
1
,
self
.
num_heads
*
self
.
head_size
)
def
unified_v1_flash_attention
(
output
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
num_heads
:
int
,
head_size
:
int
,
num_kv_heads
:
int
,
kv_cache
:
torch
.
Tensor
,
kv_cache_dtype
:
str
,
k_scale
:
float
,
v_scale
:
float
,
softmax_scale
:
float
,
window_size
:
Optional
[
List
[
int
]]
=
None
,
alibi_slopes
:
Optional
[
torch
.
Tensor
]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
)
->
None
:
context
=
get_forward_context
()
current_metadata
=
context
.
dynamic_forward_context
if
current_metadata
is
None
:
if
attn_metadata
is
None
:
# Profiling run.
return
return
output
assert
current_metadata
is
not
None
assert
isinstance
(
current_metadata
,
FlashAttentionMetadata
)
attn_metadata
:
FlashAttentionMetadata
=
current_metadata
num_actual_tokens
=
attn_metadata
.
num_actual_tokens
# Reshape the input keys and values and store them in the cache.
...
...
@@ -200,7 +149,7 @@ def unified_v1_flash_attention(
key_cache
,
value_cache
,
attn_metadata
.
slot_mapping
,
kv_cache_dtype
,
self
.
kv_cache_dtype
,
k_scale
,
v_scale
,
)
...
...
@@ -215,38 +164,12 @@ def unified_v1_flash_attention(
max_seqlen_q
=
attn_metadata
.
max_query_len
,
cu_seqlens_k
=
attn_metadata
.
seq_start_loc
,
max_seqlen_k
=
attn_metadata
.
max_seq_len
,
softmax_scale
=
s
oftmax_
scale
,
softmax_scale
=
s
elf
.
scale
,
causal
=
True
,
alibi_slopes
=
alibi_slopes
,
window_size
=
window
_size
,
alibi_slopes
=
self
.
alibi_slopes
,
window_size
=
self
.
sliding_
window
,
block_table
=
attn_metadata
.
block_table
,
softcap
=
logits_soft_cap
,
softcap
=
self
.
logits_soft_cap
,
)
def
unified_v1_flash_attention_fake
(
output
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
num_heads
:
int
,
head_size
:
int
,
num_kv_heads
:
int
,
kv_cache
:
torch
.
Tensor
,
kv_cache_dtype
:
str
,
k_scale
:
float
,
v_scale
:
float
,
softmax_scale
:
float
,
window_size
:
Optional
[
List
[
int
]]
=
None
,
alibi_slopes
:
Optional
[
torch
.
Tensor
]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
)
->
None
:
return
direct_register_custom_op
(
op_name
=
"unified_v1_flash_attention"
,
op_func
=
unified_v1_flash_attention
,
mutates_args
=
[
"kv_cache"
,
"output"
],
fake_impl
=
unified_v1_flash_attention_fake
,
)
return
output
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment