norm / vllm — Commits

Commit c9d5b6d4 (unverified), parent 189ae231
Replace FlashAttention with xformers (#70)
Authored May 05, 2023 by Woosuk Kwon; committed via GitHub on May 05, 2023
Showing 13 changed files with 87 additions and 131 deletions (+87 -131):

- README.md (+1 / -5)
- cacheflow/master/server.py (+1 / -1)
- cacheflow/models/attention.py (+16 / -38)
- cacheflow/models/input_metadata.py (+9 / -12)
- cacheflow/models/llama.py (+2 / -2)
- cacheflow/models/memory_analyzer.py (+6 / -6)
- cacheflow/models/opt.py (+2 / -2)
- cacheflow/worker/worker.py (+0 / -8)
- tests/kernels/activation.py (+1 / -1)
- tests/kernels/attention.py (+35 / -44)
- tests/kernels/cache.py (+10 / -9)
- tests/kernels/layernorm.py (+3 / -2)
- tests/kernels/pos_encoding.py (+1 / -1)
README.md (+1 / -5)

@@ -3,11 +3,7 @@
 ## Installation

 ```bash
-pip install psutil numpy ray torch
-pip install git+https://github.com/huggingface/transformers  # Required for LLaMA.
-pip install sentencepiece  # Required for LlamaTokenizer.
-pip install ninja  # To parallelize the compilation of flash-attn.
-pip install flash-attn  # This may take up to 10 mins.
+pip install ninja psutil numpy sentencepiece ray torch transformers xformers
 pip install -e .
 ```
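Note: the consolidated `pip install` line above pulls a prebuilt xformers package and drops the separate flash-attn source build ("This may take up to 10 mins." in the old instructions). A quick sanity check that the swapped-in dependency is importable (a hedged sketch, not part of the diff; it assumes the install above succeeded):

```python
# Verify that the xformers ops used elsewhere in this commit are available.
import torch
from xformers import ops as xops

print(torch.__version__)
print(xops.memory_efficient_attention_forward)  # should resolve to a callable
```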
cacheflow/master/server.py (+1 / -1)

@@ -213,7 +213,7 @@ def add_server_arguments(parser: argparse.ArgumentParser):
     parser.add_argument('--use-np-cache', action='store_true',
                         help='save a numpy copy of model weights for faster loading')
     parser.add_argument('--use-dummy-weights', action='store_true',
                         help='use dummy values for model weights')
-    # NOTE(woosuk): FlashAttention does not support float32.
+    # TODO(woosuk): Support FP32 for debugging.
     parser.add_argument('--dtype', type=str, default='default', choices=['default', 'half', 'bfloat16'],
                         help=('data type for model weights and activations. '
                               'The "default" option will use FP16 precision '
cacheflow/models/attention.py (+16 / -38)

 from typing import Optional

-from flash_attn.flash_attn_interface import _flash_attn_forward
 import torch
 import torch.nn as nn
+from xformers import ops as xops

 from cacheflow import attention_ops
 from cacheflow import cache_ops

@@ -15,6 +15,7 @@ class GPTCacheFlowAttention(nn.Module):
     def __init__(self, scale: float) -> None:
         super().__init__()
         self.scale = float(scale)
+        self.attn_op = xops.fmha.cutlass.FwOp()

     def multi_query_kv_attention(
         self,

@@ -22,32 +23,21 @@ class GPTCacheFlowAttention(nn.Module):
         query: torch.Tensor,                   # [num_prompt_tokens, num_heads, head_size]
         key: torch.Tensor,                     # [num_prompt_tokens, num_heads, head_size]
         value: torch.Tensor,                   # [num_prompt_tokens, num_heads, head_size]
-        cumulative_prompt_lens: torch.Tensor,  # [num_prompts + 1]
-        max_prompt_len: int,
+        attn_bias: xops.AttentionBias,
     ) -> None:
-        if query.dtype == torch.float:
-            raise ValueError('The float data type is not supported by '
-                             'FlashAttention. Use the half data type instead.')
-        head_size = query.shape[-1]
-        if head_size > 128:
-            raise ValueError('FlashAttention does not support head_size > 128.')
-
-        # Directly call FlashAttention's internal function to avoid allocating
-        # a new tensor for the output.
-        _flash_attn_forward(
-            query,
-            key,
-            value,
-            output,
-            cumulative_prompt_lens,
-            cumulative_prompt_lens,
-            max_prompt_len,
-            max_prompt_len,
-            dropout_p=0.0,
-            softmax_scale=self.scale,
-            causal=True,
-            return_softmax=False,
-        )
+        # TODO(woosuk): The unsqueeze op may incur some CPU overhead. Optimize.
+        out = xops.memory_efficient_attention_forward(
+            query.unsqueeze(0),
+            key.unsqueeze(0),
+            value.unsqueeze(0),
+            attn_bias=attn_bias,
+            p=0.0,
+            scale=self.scale,
+            op=self.attn_op,
+        )
+        # TODO(woosuk): Unnecessary copy. Optimize.
+        output.copy_(out.squeeze(0))
+        return output

@@ -109,8 +99,7 @@ class GPTCacheFlowAttention(nn.Module):
             query[:num_prompt_tokens],
             key[:num_prompt_tokens],
             value[:num_prompt_tokens],
-            input_metadata.cumulative_prompt_lens,
-            input_metadata.max_prompt_len,
+            input_metadata.attn_bias,
         )

         # Wait until the cache op is done.

@@ -143,13 +132,6 @@ class GPTCacheFlowAttention(nn.Module):
         return output.view(-1, num_heads * head_size)


-class OPTCacheFlowAttention(GPTCacheFlowAttention):
-    """OPT uses the same attention mechanism as GPT."""
-
-    def __init__(self, scale: float) -> None:
-        super().__init__(scale)
-
-
 class GPTNeoXCacheFlowAttention(GPTCacheFlowAttention):
     """Attention with GPT-NeoX style rotary embedding."""

@@ -207,7 +189,3 @@ class GPTNeoXCacheFlowAttention(GPTCacheFlowAttention):
             input_metadata,
             cache_event,
         )
-
-
-class LlamaCacheFlowAttention(GPTNeoXCacheFlowAttention):
-    """LLaMA uses the GPT-NeoX style rotary embedding."""
cacheflow/models/input_metadata.py (+9 / -12)

 from typing import List, Dict, Tuple

 import torch
+from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask

 from cacheflow.sampling_params import SamplingParams

@@ -12,7 +13,6 @@ class InputMetadata:
         seq_groups: List[Tuple[List[int], SamplingParams]],
         seq_logprobs: Dict[int, float],  # Seq id -> cumulative logprobs.
         prompt_lens: List[int],
-        cumulative_prompt_lens: torch.Tensor,
         slot_mapping: torch.Tensor,
         context_lens: torch.Tensor,
         max_context_len: int,

@@ -21,15 +21,14 @@ class InputMetadata:
         self.seq_groups = seq_groups
         self.seq_logprobs = seq_logprobs
         self.prompt_lens = prompt_lens
-        self.cumulative_prompt_lens = cumulative_prompt_lens
         self.slot_mapping = slot_mapping
         self.context_lens = context_lens
         self.max_context_len = max_context_len
         self.block_tables = block_tables

+        self.attn_bias = BlockDiagonalCausalMask.from_seqlens(prompt_lens)
         self.num_prompts = len(prompt_lens)
         self.num_prompt_tokens = sum(prompt_lens)
-        self.max_prompt_len = max(prompt_lens) if prompt_lens else 0
         self.num_generation_tokens = context_lens.shape[0]
         self.num_valid_tokens = slot_mapping.shape[0]
         if block_tables.numel() > 0:

@@ -41,15 +40,13 @@ class InputMetadata:
     def __repr__(self) -> str:
         return (f'InputMetadata('
-                f'num_prompts={self.num_prompts}, '
-                f'num_prompt_tokens={self.num_prompt_tokens}, '
-                f'max_prompt_len={self.max_prompt_len}, '
-                f'num_generation_tokens={self.num_generation_tokens}, '
                 f'num_valid_tokens={self.num_valid_tokens}, '
-                f'max_num_blocks_per_seq={self.max_num_blocks_per_seq}, '
-                f'max_context_len={self.max_context_len}), '
+                f'num_prompt_tokens={self.num_prompt_tokens}, '
+                f'num_prompts={self.num_prompts}, '
                 f'prompt_lens={self.prompt_lens}, '
-                f'cumulative_prompt_lens={self.cumulative_prompt_lens}, '
-                f'slot_mapping={self.slot_mapping}, '
+                f'num_generation_tokens={self.num_generation_tokens}, '
                 f'context_lens={self.context_lens}, '
-                f'block_tables={self.block_tables})')
+                f'max_context_len={self.max_context_len}), '
+                f'max_num_blocks_per_seq={self.max_num_blocks_per_seq}, '
+                f'block_tables={self.block_tables}), '
+                f'slot_mapping={self.slot_mapping}')
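Note: `InputMetadata` now builds the attention bias once from `prompt_lens`, so callers no longer pass a `cumulative_prompt_lens` tensor. The sketch below (with illustrative lengths) shows the old bookkeeping next to the new field; the prefix-sum loop is the code removed from `cacheflow/worker/worker.py` later in this commit:

```python
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask

prompt_lens = [3, 5, 2]

# Old path: explicit prefix sums handed to FlashAttention's varlen kernel.
cumulative_prompt_lens = [0]
for prompt_len in prompt_lens:
    cumulative_prompt_lens.append(cumulative_prompt_lens[-1] + prompt_len)
print(cumulative_prompt_lens)  # [0, 3, 8, 10]

# New path: the same per-prompt boundaries are captured by the bias object
# that InputMetadata stores as self.attn_bias.
attn_bias = BlockDiagonalCausalMask.from_seqlens(prompt_lens)
```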
cacheflow/models/llama.py (+2 / -2)

@@ -7,7 +7,7 @@ from transformers import LlamaConfig
 from cacheflow.models import InputMetadata
 from cacheflow.models.activation import SiluAndMul
-from cacheflow.models.attention import LlamaCacheFlowAttention
+from cacheflow.models.attention import GPTNeoXCacheFlowAttention
 from cacheflow.models.layernorm import RMSNorm
 from cacheflow.models.sample import Sampler
 from cacheflow.models.utils import (hf_model_weights_iterator,

@@ -79,7 +79,7 @@ class LlamaAttention(nn.Module):
             input_is_parallel=True,
             perform_initialization=False,
         )
-        self.attn = LlamaCacheFlowAttention(self.scaling, self.head_dim)
+        self.attn = GPTNeoXCacheFlowAttention(self.scaling, self.head_dim)

     def forward(
         self,
cacheflow/models/memory_analyzer.py (+6 / -6)

@@ -202,8 +202,8 @@ class OPTMemoryAnalyzer(CacheFlowMemoryAnalyzer):
         # estimating
         # 1) the maximum activation tensor size during inference
         # 2) the residual tensor size during inference
-        # Here, we assume that FlashAttention is used and
-        # thus the attention maps are never materialized in GPU DRAM.
+        # Here, we assume that we use memory-efficient attention which
+        # does not materialize the attention maps in GPU DRAM.
         residual = max_num_batched_tokens * self.hidden_size
         qkv = 3 * (max_num_batched_tokens * self.hidden_size) // self.tensor_parallel_size
         ffn = max_num_batched_tokens * self.ffn_size // self.tensor_parallel_size

@@ -277,8 +277,8 @@ class LlamaMemoryAnalyzer(CacheFlowMemoryAnalyzer):
         # estimating
         # 1) the maximum activation tensor size during inference
         # 2) the residual tensor size during inference
-        # Here, we assume that FlashAttention is used and
-        # thus the attention maps are never materialized in GPU DRAM.
+        # Here, we assume that we use memory-efficient attention which
+        # does not materialize the attention maps in GPU DRAM.
         residual = max_num_batched_tokens * self.hidden_size
         qkv = 3 * (max_num_batched_tokens * self.hidden_size) // self.tensor_parallel_size
         ffn = 2 * (max_num_batched_tokens * self.ffn_size) // self.tensor_parallel_size

@@ -353,8 +353,8 @@ class GPTNeoXMemoryAnalyzer(CacheFlowMemoryAnalyzer):
         # estimating
         # 1) the maximum activation tensor size during inference
         # 2) the residual tensor size during inference
-        # Here, we assume that FlashAttention is used and
-        # thus the attention maps are never materialized in GPU DRAM.
+        # Here, we assume that we use memory-efficient attention which
+        # does not materialize the attention maps in GPU DRAM.
         residual = max_num_batched_tokens * self.hidden_size
         qkv = 3 * (max_num_batched_tokens * self.hidden_size) // self.tensor_parallel_size
         ffn = 2 * (max_num_batched_tokens * self.ffn_size) // self.tensor_parallel_size
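Note: the comment rewrite above only restates the assumption behind the peak-memory estimate: with memory-efficient attention (as with FlashAttention before it), no num_tokens x num_tokens attention map is ever resident in GPU DRAM, so the estimate is built from the residual, QKV, and FFN activations alone. A back-of-the-envelope version of those formulas follows (sizes are illustrative and not taken from the diff; how the analyzer actually combines the three terms is not shown in this hunk, so summing them here is just for intuition):

```python
# Illustrative OPT-style sizes, not taken from the diff.
max_num_batched_tokens = 2560
hidden_size = 5120
ffn_size = 4 * hidden_size
tensor_parallel_size = 1

residual = max_num_batched_tokens * hidden_size
qkv = 3 * (max_num_batched_tokens * hidden_size) // tensor_parallel_size
ffn = max_num_batched_tokens * ffn_size // tensor_parallel_size

# No num_tokens x num_tokens term: the attention maps are never materialized.
total_elements = residual + qkv + ffn
print(f'~{total_elements * 2 / 1024**3:.2f} GiB at 2 bytes per element')
```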
cacheflow/models/opt.py (+2 / -2)

@@ -6,7 +6,7 @@ from torch import nn
 from transformers import OPTConfig

 from cacheflow.models import InputMetadata
-from cacheflow.models.attention import OPTCacheFlowAttention
+from cacheflow.models.attention import GPTCacheFlowAttention
 from cacheflow.models.sample import Sampler
 from cacheflow.models.utils import (hf_model_weights_iterator,
                                     load_tensor_parallel_weights)

@@ -55,7 +55,7 @@ class OPTAttention(nn.Module):
         self.out_proj = RowParallelLinear(embed_dim, embed_dim, bias=bias,
                                           input_is_parallel=True,
                                           perform_initialization=False)
-        self.attn = OPTCacheFlowAttention(scale=self.scaling)
+        self.attn = GPTCacheFlowAttention(scale=self.scaling)

     def forward(
         self,
cacheflow/worker/worker.py (+0 / -8)

@@ -136,11 +136,6 @@ class Worker:
                 slot = block_number * self.block_size + block_offset
                 slot_mapping.append(slot)

-        cumulative_prompt_lens: List[int] = [0]
-        for prompt_len in prompt_lens:
-            cumulative_prompt_lens.append(
-                cumulative_prompt_lens[-1] + prompt_len)
-
         # Add generation tokens.
         max_context_len = 0
         max_num_blocks_per_seq = 0

@@ -196,14 +191,11 @@ class Worker:
             for block_table in generation_block_tables]
         block_tables_tensor = torch.tensor(
             padded_block_tables, dtype=torch.int, device='cuda')
-        cumulative_prompt_lens_tensor = torch.tensor(
-            cumulative_prompt_lens, dtype=torch.int, device='cuda')

         input_metadata = InputMetadata(
             seq_groups=seq_groups,
             seq_logprobs=seq_logprobs,
             prompt_lens=prompt_lens,
-            cumulative_prompt_lens=cumulative_prompt_lens_tensor,
             slot_mapping=slot_mapping_tensor,
             context_lens=context_lens_tensor,
             max_context_len=max_context_len,
tests/kernels/activation.py (+1 / -1)

@@ -23,7 +23,7 @@ def test_silu_and_mul(

 if __name__ == '__main__':
-    for dtype in [torch.half, torch.float]:
+    for dtype in [torch.half, torch.bfloat16, torch.float]:
         for num_tokens in [7, 83, 2048]:
             for d in [512, 4096, 13824]:
                 print(f'Testing dtype={dtype}, num_tokens={num_tokens}, d={d}')
tests/kernels/attention.py (+35 / -44)

 import random
 from typing import List, Optional

-from flash_attn.flash_attn_interface import _flash_attn_forward
 import torch
+from xformers import ops as xops
+from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask

 from cacheflow import attention_ops

@@ -81,8 +82,10 @@ def ref_multi_query_kv_attention(
         end_idx = cu_seq_lens[i + 1]
         seq_len = end_idx - start_idx

-        # Create attention mask
-        attn_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1) * -1e5
+        # Create attention mask.
+        attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype),
+                               diagonal=1)
+        attn_mask = attn_mask * torch.finfo(dtype).min
         attn_mask = attn_mask.to(dtype=dtype, device='cuda')

         ref_output = ref_masked_attention(

@@ -160,21 +163,20 @@ def test_single_query_cached_kv_attention(
     num_blocks: int,
     dtype: torch.dtype,
 ) -> None:
-    qkv = torch.randn(
+    qkv = torch.empty(
         num_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
+    qkv.uniform_(-1e-3, 1e-3)
     query, _, _ = qkv.unbind(dim=1)

     x = 16 // torch.tensor([], dtype=dtype).element_size()
     key_block_shape = (num_heads, head_size // x, block_size, x)
-    key_cache = torch.randn(
+    key_cache = torch.empty(
         size=(num_blocks, *key_block_shape), dtype=dtype, device='cuda')
+    key_cache.uniform_(-1e-3, 1e-3)
     value_block_shape = (num_heads, head_size, block_size)
-    value_cache = torch.randn(
+    value_cache = torch.empty(
         size=(num_blocks, *value_block_shape), dtype=dtype, device='cuda')
-
-    # Adjust the range of the values to reduce precision errors.
-    query = query / (head_size ** 0.5)
-    key_cache = key_cache / (head_size ** 0.5)
-    value_cache = value_cache / (head_size ** 0.5)
+    value_cache.uniform_(-1e-3, 1e-3)

     context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_tokens)]
     max_context_len = max(context_lens)

@@ -228,39 +230,30 @@ def test_multi_query_kv_attention(
     dtype: torch.dtype,
 ) -> None:
     seq_lens = random.sample(range(1, MAX_SEQ_LEN), num_seqs)
-    max_seq_len = max(seq_lens)
     num_tokens = sum(seq_lens)
-    cu_seq_lens = [0]
-    for seq_len in seq_lens:
-        cu_seq_lens.append(cu_seq_lens[-1] + seq_len)
-    cu_seq_lens = torch.tensor(cu_seq_lens, dtype=torch.int, device='cuda')

     scale = float(1.0 / (head_size ** 0.5))
-    qkv = torch.randn(
+    qkv = torch.empty(
         num_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
-    # Adjust the range of the values to reduce precision errors.
-    qkv = qkv / (head_size ** 0.5)
+    qkv.uniform_(-1e-3, 1e-3)
     query, key, value = qkv.unbind(dim=1)

-    output = torch.empty(
-        num_tokens, num_heads, head_size, dtype=dtype, device='cuda')
-    _flash_attn_forward(
-        query,
-        key,
-        value,
-        output,
-        cu_seq_lens,
-        cu_seq_lens,
-        max_seq_len,
-        max_seq_len,
-        dropout_p=0.0,
-        softmax_scale=scale,
-        causal=True,
-        return_softmax=False,
-    )
+    attn_op = xops.fmha.cutlass.FwOp()
+    attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens)
+    output = xops.memory_efficient_attention_forward(
+        query.unsqueeze(0),
+        key.unsqueeze(0),
+        value.unsqueeze(0),
+        attn_bias=attn_bias,
+        p=0.0,
+        scale=scale,
+        op=attn_op,
+    )
+    output = output.squeeze(0)

-    cu_seq_lens = cu_seq_lens.cpu().tolist()
+    cu_seq_lens = [0]
+    for seq_len in seq_lens:
+        cu_seq_lens.append(cu_seq_lens[-1] + seq_len)
     ref_output = ref_multi_query_kv_attention(
         cu_seq_lens,
         query,

@@ -277,8 +270,8 @@ def test_attention(seed: int) -> None:
     # the test fails due to the precision issue. Re-run the test if it fails.
     torch.random.manual_seed(seed)
     torch.cuda.manual_seed(seed)
-    for dtype in [torch.half, torch.float]:
-        for block_size in [8, 16, 32]:
+    for dtype in [torch.half, torch.bfloat16]:
+        for block_size in [8, 16, 32, 64]:
             for head_size in [32, 64, 80, 96, 128, 160, 192, 256]:
                 print(f'Testing single_query_cached_kv_attention with '
                       f'dtype={dtype}, block_size={block_size}, '

@@ -292,14 +285,12 @@ def test_attention(seed: int) -> None:
                 dtype=dtype,
             )

-    # NOTE(woosuk): FlashAttention does not support FP32.
-    for dtype in [torch.half]:
-        # NOTE(woosuk): FlashAttention does not support head_size > 128.
-        for head_size in [64, 80, 96, 128]:
+    for dtype in [torch.half, torch.bfloat16]:
+        for head_size in [32, 64, 80, 96, 128, 160, 192, 256]:
             print(f'Testing multi_query_kv_attention with dtype={dtype}, '
                   f'head_size={head_size}')
             test_multi_query_kv_attention(
-                num_seqs=11,
+                num_seqs=5,
                 num_heads=3,
                 head_size=head_size,
                 dtype=dtype,
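Note: besides switching the kernel under test, the reference implementation now builds its causal mask in the target dtype and fills masked positions with `torch.finfo(dtype).min` instead of the hard-coded `-1e5`, and test inputs are drawn from `uniform_(-1e-3, 1e-3)` rather than rescaled `randn`. A small aside on why the `finfo` change matters (not part of the diff): `-1e5` lies outside float16's finite range (whose largest magnitude is 65504) and overflows to `-inf` when cast, while `finfo` picks a valid most-negative value per dtype, including the newly tested bfloat16:

```python
import torch

# Most negative finite value for each dtype used in these tests.
for dtype in [torch.half, torch.bfloat16, torch.float]:
    print(dtype, torch.finfo(dtype).min)

print(torch.tensor(-1e5, dtype=torch.half))  # overflows to -inf
```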
tests/kernels/cache.py (+10 / -9)

@@ -142,15 +142,16 @@ def test_gather_cached_kv(

 @torch.inference_mode()
 def test_cache() -> None:
-    test_copy_blocks(
-        num_mappings=23, num_layers=7, num_heads=17, head_size=16,
-        block_size=8, num_blocks=1024, dtype=torch.half)
-    test_reshape_and_cache(
-        num_tokens=3, num_heads=2, head_size=16, block_size=8, num_blocks=2,
-        dtype=torch.half)
-    test_gather_cached_kv(
-        num_tokens=3, num_heads=2, head_size=16, block_size=8, num_blocks=2,
-        dtype=torch.half)
+    for dtype in [torch.half, torch.bfloat16, torch.float]:
+        test_copy_blocks(
+            num_mappings=23, num_layers=7, num_heads=17, head_size=16,
+            block_size=8, num_blocks=1024, dtype=dtype)
+        test_reshape_and_cache(
+            num_tokens=3, num_heads=2, head_size=16, block_size=8, num_blocks=2,
+            dtype=dtype)
+        test_gather_cached_kv(
+            num_tokens=3, num_heads=2, head_size=16, block_size=8, num_blocks=2,
+            dtype=dtype)

 if __name__ == '__main__':
tests/kernels/layernorm.py (+3 / -2)

@@ -8,7 +8,8 @@ class RefRMSNorm(nn.Module):
     def __init__(self, hidden_size, eps=1e-6):
         super().__init__()
-        weight = torch.randn(hidden_size) / (hidden_size ** 0.5)
+        weight = torch.empty(hidden_size)
+        weight.uniform_(-1e-3, 1e-3)
         self.weight = nn.Parameter(weight)
         self.variance_epsilon = eps

@@ -41,7 +42,7 @@ def test_rms_norm(

 if __name__ == '__main__':
-    for dtype in [torch.half, torch.float]:
+    for dtype in [torch.half, torch.bfloat16, torch.float]:
         for num_tokens in [7, 128, 2048]:
             for hidden_size in [13, 64, 1024, 5120]:
                 print(f'Testing RMS kernel with dtype={dtype}, num_tokens='
tests/kernels/pos_encoding.py (+1 / -1)

@@ -129,7 +129,7 @@ def test_rotary_embedding_neox(

 if __name__ == '__main__':
-    for dtype in [torch.half, torch.float]:
+    for dtype in [torch.half, torch.bfloat16, torch.float]:
         for head_size in [32, 64, 80, 96, 128, 160, 192, 256]:
             print(f'Running tests for head_size={head_size} and dtype={dtype}')
             test_rotary_embedding_neox(