norm / vllm · Commits

Commit c9d5b6d4 (Unverified)
Replace FlashAttention with xformers (#70)
Authored May 05, 2023 by Woosuk Kwon; committed by GitHub on May 05, 2023.
Parent: 189ae231
Showing 13 changed files with 87 additions and 131 deletions.
README.md                             +1   -5
cacheflow/master/server.py            +1   -1
cacheflow/models/attention.py         +16  -38
cacheflow/models/input_metadata.py    +9   -12
cacheflow/models/llama.py             +2   -2
cacheflow/models/memory_analyzer.py   +6   -6
cacheflow/models/opt.py               +2   -2
cacheflow/worker/worker.py            +0   -8
tests/kernels/activation.py           +1   -1
tests/kernels/attention.py            +35  -44
tests/kernels/cache.py                +10  -9
tests/kernels/layernorm.py            +3   -2
tests/kernels/pos_encoding.py         +1   -1
README.md
@@ -3,11 +3,7 @@
 ## Installation
 
 ```bash
-pip install psutil numpy ray torch
-pip install git+https://github.com/huggingface/transformers  # Required for LLaMA.
-pip install sentencepiece  # Required for LlamaTokenizer.
-pip install ninja  # To parallelize the compilation of flash-attn.
-pip install flash-attn  # This may take up to 10 mins.
+pip install ninja psutil numpy sentencepiece ray torch transformers xformers
 pip install -e .
 ```
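Not part of the commit, but a quick way to sanity-check the new dependency set: the sketch below (assuming a CUDA GPU and the packages from the updated install line) imports xformers and runs its memory-efficient attention once. Tensor shapes are arbitrary.

```python
# Sanity-check sketch (not from the diff): confirm torch and xformers import
# and that memory-efficient attention runs on the GPU.
import torch
from xformers import ops as xops

# [batch, seq_len, num_heads, head_size]; sizes are arbitrary for the check.
q = torch.randn(1, 8, 4, 64, dtype=torch.half, device='cuda')
k = torch.randn(1, 8, 4, 64, dtype=torch.half, device='cuda')
v = torch.randn(1, 8, 4, 64, dtype=torch.half, device='cuda')

out = xops.memory_efficient_attention(q, k, v)  # no mask, dropout defaults to 0.0
print(out.shape)  # torch.Size([1, 8, 4, 64])
```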
cacheflow/master/server.py
@@ -213,7 +213,7 @@ def add_server_arguments(parser: argparse.ArgumentParser):
     parser.add_argument('--use-np-cache', action='store_true',
                         help='save a numpy copy of model weights for faster loading')
     parser.add_argument('--use-dummy-weights', action='store_true',
                         help='use dummy values for model weights')
-    # NOTE(woosuk): FlashAttention does not support float32.
+    # TODO(woosuk): Support FP32 for debugging.
     parser.add_argument('--dtype', type=str, default='default',
                         choices=['default', 'half', 'bfloat16'],
                         help=('data type for model weights and activations. '
                               'The "default" option will use FP16 precision '
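For context, the `--dtype` flag above accepts only 'default', 'half', and 'bfloat16'. The sketch below is a hypothetical illustration (not code from the repository) of how such a string choice could be resolved into a torch dtype, given the help text stating that "default" uses FP16; the real server may resolve it differently (for example, from the model config).

```python
# Hypothetical illustration (not from the diff): mapping the '--dtype' choices
# to torch dtypes. The actual resolution logic in the server may differ.
import argparse
import torch

_STR_TO_DTYPE = {
    'default': torch.half,      # per the help text, 'default' uses FP16 precision
    'half': torch.half,
    'bfloat16': torch.bfloat16,
}

parser = argparse.ArgumentParser()
parser.add_argument('--dtype', type=str, default='default',
                    choices=['default', 'half', 'bfloat16'])
args = parser.parse_args(['--dtype', 'bfloat16'])
print(_STR_TO_DTYPE[args.dtype])  # torch.bfloat16
```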
cacheflow/models/attention.py
 from typing import Optional
 
-from flash_attn.flash_attn_interface import _flash_attn_forward
 import torch
 import torch.nn as nn
+from xformers import ops as xops
 
 from cacheflow import attention_ops
 from cacheflow import cache_ops
@@ -15,6 +15,7 @@ class GPTCacheFlowAttention(nn.Module):
     def __init__(self, scale: float) -> None:
         super().__init__()
         self.scale = float(scale)
+        self.attn_op = xops.fmha.cutlass.FwOp()
 
     def multi_query_kv_attention(
         self,
@@ -22,32 +23,21 @@ class GPTCacheFlowAttention(nn.Module):
         query: torch.Tensor,                   # [num_prompt_tokens, num_heads, head_size]
         key: torch.Tensor,                     # [num_prompt_tokens, num_heads, head_size]
         value: torch.Tensor,                   # [num_prompt_tokens, num_heads, head_size]
-        cumulative_prompt_lens: torch.Tensor,  # [num_prompts + 1]
-        max_prompt_len: int,
+        attn_bias: xops.AttentionBias,
     ) -> None:
-        if query.dtype == torch.float:
-            raise ValueError('The float data type is not supported by '
-                             'FlashAttention. Use the half data type instead.')
-        head_size = query.shape[-1]
-        if head_size > 128:
-            raise ValueError('FlashAttention does not support head_size > 128.')
-        # Directly call FlashAttention's internal function to avoid allocating
-        # a new tensor for the output.
-        _flash_attn_forward(
-            query,
-            key,
-            value,
-            output,
-            cumulative_prompt_lens,
-            cumulative_prompt_lens,
-            max_prompt_len,
-            max_prompt_len,
-            dropout_p=0.0,
-            softmax_scale=self.scale,
-            causal=True,
-            return_softmax=False,
+        # TODO(woosuk): The unsqueeze op may incur some CPU overhead. Optimize.
+        out = xops.memory_efficient_attention_forward(
+            query.unsqueeze(0),
+            key.unsqueeze(0),
+            value.unsqueeze(0),
+            attn_bias=attn_bias,
+            p=0.0,
+            scale=self.scale,
+            op=self.attn_op,
         )
+        # TODO(woosuk): Unnecessary copy. Optimize.
+        output.copy_(out.squeeze(0))
         return output
 
     def single_query_cached_kv_attention(
         self,
@@ -109,8 +99,7 @@ class GPTCacheFlowAttention(nn.Module):
                 query[:num_prompt_tokens],
                 key[:num_prompt_tokens],
                 value[:num_prompt_tokens],
-                input_metadata.cumulative_prompt_lens,
-                input_metadata.max_prompt_len,
+                input_metadata.attn_bias,
             )
 
         # Wait until the cache op is done.
@@ -143,13 +132,6 @@ class GPTCacheFlowAttention(nn.Module):
         return output.view(-1, num_heads * head_size)
 
 
-class OPTCacheFlowAttention(GPTCacheFlowAttention):
-    """OPT uses the same attention mechanism as GPT."""
-
-    def __init__(self, scale: float) -> None:
-        super().__init__(scale)
-
-
 class GPTNeoXCacheFlowAttention(GPTCacheFlowAttention):
     """Attention with GPT-NeoX style rotary embedding."""
@@ -207,7 +189,3 @@ class GPTNeoXCacheFlowAttention(GPTCacheFlowAttention):
             input_metadata,
             cache_event,
         )
-
-
-class LlamaCacheFlowAttention(GPTNeoXCacheFlowAttention):
-    """LLaMA uses the GPT-NeoX style rotary embedding."""
cacheflow/models/input_metadata.py
 from typing import List, Dict, Tuple
 
 import torch
+from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
 
 from cacheflow.sampling_params import SamplingParams
@@ -12,7 +13,6 @@ class InputMetadata:
         seq_groups: List[Tuple[List[int], SamplingParams]],
         seq_logprobs: Dict[int, float],  # Seq id -> cumulative logprobs.
         prompt_lens: List[int],
-        cumulative_prompt_lens: torch.Tensor,
         slot_mapping: torch.Tensor,
         context_lens: torch.Tensor,
         max_context_len: int,
@@ -21,15 +21,14 @@ class InputMetadata:
         self.seq_groups = seq_groups
         self.seq_logprobs = seq_logprobs
         self.prompt_lens = prompt_lens
-        self.cumulative_prompt_lens = cumulative_prompt_lens
         self.slot_mapping = slot_mapping
        self.context_lens = context_lens
         self.max_context_len = max_context_len
         self.block_tables = block_tables
 
+        self.attn_bias = BlockDiagonalCausalMask.from_seqlens(prompt_lens)
         self.num_prompts = len(prompt_lens)
         self.num_prompt_tokens = sum(prompt_lens)
-        self.max_prompt_len = max(prompt_lens) if prompt_lens else 0
         self.num_generation_tokens = context_lens.shape[0]
         self.num_valid_tokens = slot_mapping.shape[0]
         if block_tables.numel() > 0:
@@ -41,15 +40,13 @@ class InputMetadata:
 
     def __repr__(self) -> str:
         return (f'InputMetadata('
-                f'num_prompts={self.num_prompts}, '
-                f'num_prompt_tokens={self.num_prompt_tokens}, '
-                f'max_prompt_len={self.max_prompt_len}, '
-                f'num_generation_tokens={self.num_generation_tokens}, '
                 f'num_valid_tokens={self.num_valid_tokens}, '
-                f'max_num_blocks_per_seq={self.max_num_blocks_per_seq}, '
-                f'max_context_len={self.max_context_len}), '
+                f'num_prompt_tokens={self.num_prompt_tokens}, '
+                f'num_prompts={self.num_prompts}, '
                 f'prompt_lens={self.prompt_lens}, '
-                f'cumulative_prompt_lens={self.cumulative_prompt_lens}, '
-                f'slot_mapping={self.slot_mapping}, '
+                f'num_generation_tokens={self.num_generation_tokens}, '
                 f'context_lens={self.context_lens}, '
-                f'block_tables={self.block_tables})')
+                f'max_context_len={self.max_context_len}), '
+                f'max_num_blocks_per_seq={self.max_num_blocks_per_seq}, '
+                f'block_tables={self.block_tables}), '
+                f'slot_mapping={self.slot_mapping}')
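For intuition about the `attn_bias` object that `InputMetadata` now carries, and that `multi_query_kv_attention` consumes: the sketch below is illustrative only, not repository code. It builds a `BlockDiagonalCausalMask` for two packed prompts, runs the same xformers forward call adopted by this commit, and also constructs by hand the dense mask the bias represents: block-diagonal across prompts, causal within each prompt. Prompt lengths and head sizes are made up.

```python
# Illustrative sketch (not from the diff): what BlockDiagonalCausalMask encodes
# and how it feeds the xformers forward call used for the prompt phase.
import torch
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask

prompt_lens = [2, 3]                       # two prompts packed back to back
num_heads, head_size = 4, 64               # arbitrary example sizes
attn_bias = BlockDiagonalCausalMask.from_seqlens(prompt_lens)  # what InputMetadata stores

qkv = torch.randn(sum(prompt_lens), 3, num_heads, head_size,
                  dtype=torch.half, device='cuda')
query, key, value = qkv.unbind(dim=1)      # each: [num_prompt_tokens, num_heads, head_size]
out = xops.memory_efficient_attention_forward(
    query.unsqueeze(0),                    # xformers expects a leading batch dim
    key.unsqueeze(0),
    value.unsqueeze(0),
    attn_bias=attn_bias,
    p=0.0,
    scale=head_size ** -0.5,               # the attention module also pins op=xops.fmha.cutlass.FwOp()
).squeeze(0)

# Dense equivalent of the bias, built by hand: 0 where attention is allowed,
# -inf where masked. Tokens of one prompt never attend to the other prompt.
total = sum(prompt_lens)
dense = torch.full((total, total), float('-inf'))
start = 0
for n in prompt_lens:
    block = torch.zeros(n, n).masked_fill(torch.ones(n, n).triu(1).bool(), float('-inf'))
    dense[start:start + n, start:start + n] = block
    start += n
print(out.shape, dense)
```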
cacheflow/models/llama.py
@@ -7,7 +7,7 @@ from transformers import LlamaConfig
 from cacheflow.models import InputMetadata
 from cacheflow.models.activation import SiluAndMul
-from cacheflow.models.attention import LlamaCacheFlowAttention
+from cacheflow.models.attention import GPTNeoXCacheFlowAttention
 from cacheflow.models.layernorm import RMSNorm
 from cacheflow.models.sample import Sampler
 from cacheflow.models.utils import (hf_model_weights_iterator,
@@ -79,7 +79,7 @@ class LlamaAttention(nn.Module):
             input_is_parallel=True,
             perform_initialization=False,
         )
-        self.attn = LlamaCacheFlowAttention(self.scaling, self.head_dim)
+        self.attn = GPTNeoXCacheFlowAttention(self.scaling, self.head_dim)
 
     def forward(
         self,
cacheflow/models/memory_analyzer.py
@@ -202,8 +202,8 @@ class OPTMemoryAnalyzer(CacheFlowMemoryAnalyzer):
         # estimating
         # 1) the maximum activation tensor size during inference
         # 2) the residual tensor size during inference
-        # Here, we assume that FlashAttention is used and
-        # thus the attention maps are never materialized in GPU DRAM.
+        # Here, we assume that we use memory-efficient attention which
+        # does not materialize the attention maps in GPU DRAM.
         residual = max_num_batched_tokens * self.hidden_size
         qkv = 3 * (max_num_batched_tokens * self.hidden_size) // self.tensor_parallel_size
         ffn = max_num_batched_tokens * self.ffn_size // self.tensor_parallel_size
@@ -277,8 +277,8 @@ class LlamaMemoryAnalyzer(CacheFlowMemoryAnalyzer):
         # estimating
         # 1) the maximum activation tensor size during inference
         # 2) the residual tensor size during inference
-        # Here, we assume that FlashAttention is used and
-        # thus the attention maps are never materialized in GPU DRAM.
+        # Here, we assume that we use memory-efficient attention which
+        # does not materialize the attention maps in GPU DRAM.
         residual = max_num_batched_tokens * self.hidden_size
         qkv = 3 * (max_num_batched_tokens * self.hidden_size) // self.tensor_parallel_size
         ffn = 2 * (max_num_batched_tokens * self.ffn_size) // self.tensor_parallel_size
@@ -353,8 +353,8 @@ class GPTNeoXMemoryAnalyzer(CacheFlowMemoryAnalyzer):
         # estimating
         # 1) the maximum activation tensor size during inference
         # 2) the residual tensor size during inference
-        # Here, we assume that FlashAttention is used and
-        # thus the attention maps are never materialized in GPU DRAM.
+        # Here, we assume that we use memory-efficient attention which
+        # does not materialize the attention maps in GPU DRAM.
         residual = max_num_batched_tokens * self.hidden_size
         qkv = 3 * (max_num_batched_tokens * self.hidden_size) // self.tensor_parallel_size
         ffn = 2 * (max_num_batched_tokens * self.ffn_size) // self.tensor_parallel_size
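The three hunks above only reword the comment: the arithmetic beneath them already assumes attention maps are never materialized, so no seq_len × seq_len term appears. Below is a small worked example of those per-term estimates, using hypothetical model sizes; how the analyzer combines the terms is not shown in these hunks, so the sketch just prints each term.

```python
# Worked example (not from the diff): the activation terms computed above, for
# hypothetical model sizes. Note there is no seq_len * seq_len attention-map
# term, which is exactly what the reworded comment points out.
max_num_batched_tokens = 2560
hidden_size = 4096            # hypothetical
ffn_size = 11008              # hypothetical
tensor_parallel_size = 1
bytes_per_elem = 2            # FP16/BF16

residual = max_num_batched_tokens * hidden_size
qkv = 3 * (max_num_batched_tokens * hidden_size) // tensor_parallel_size
ffn = 2 * (max_num_batched_tokens * ffn_size) // tensor_parallel_size  # LLaMA/GPT-NeoX variant

for name, elems in [('residual', residual), ('qkv', qkv), ('ffn', ffn)]:
    print(f'{name}: {elems * bytes_per_elem / 1024**2:.1f} MiB')
```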
cacheflow/models/opt.py
@@ -6,7 +6,7 @@ from torch import nn
 from transformers import OPTConfig
 
 from cacheflow.models import InputMetadata
-from cacheflow.models.attention import OPTCacheFlowAttention
+from cacheflow.models.attention import GPTCacheFlowAttention
 from cacheflow.models.sample import Sampler
 from cacheflow.models.utils import (hf_model_weights_iterator,
                                     load_tensor_parallel_weights)
@@ -55,7 +55,7 @@ class OPTAttention(nn.Module):
         self.out_proj = RowParallelLinear(embed_dim, embed_dim, bias=bias,
                                           input_is_parallel=True,
                                           perform_initialization=False)
-        self.attn = OPTCacheFlowAttention(scale=self.scaling)
+        self.attn = GPTCacheFlowAttention(scale=self.scaling)
 
     def forward(
         self,
cacheflow/worker/worker.py
@@ -136,11 +136,6 @@ class Worker:
                 slot = block_number * self.block_size + block_offset
                 slot_mapping.append(slot)
 
-        cumulative_prompt_lens: List[int] = [0]
-        for prompt_len in prompt_lens:
-            cumulative_prompt_lens.append(
-                cumulative_prompt_lens[-1] + prompt_len)
-
         # Add generation tokens.
         max_context_len = 0
         max_num_blocks_per_seq = 0
@@ -196,14 +191,11 @@ class Worker:
             for block_table in generation_block_tables]
         block_tables_tensor = torch.tensor(
             padded_block_tables, dtype=torch.int, device='cuda')
-        cumulative_prompt_lens_tensor = torch.tensor(
-            cumulative_prompt_lens, dtype=torch.int, device='cuda')
 
         input_metadata = InputMetadata(
             seq_groups=seq_groups,
             seq_logprobs=seq_logprobs,
             prompt_lens=prompt_lens,
-            cumulative_prompt_lens=cumulative_prompt_lens_tensor,
             slot_mapping=slot_mapping_tensor,
             context_lens=context_lens_tensor,
             max_context_len=max_context_len,
tests/kernels/activation.py
@@ -23,7 +23,7 @@ def test_silu_and_mul(
 
 if __name__ == '__main__':
-    for dtype in [torch.half, torch.float]:
+    for dtype in [torch.half, torch.bfloat16, torch.float]:
         for num_tokens in [7, 83, 2048]:
             for d in [512, 4096, 13824]:
                 print(f'Testing dtype={dtype}, num_tokens={num_tokens}, d={d}')
tests/kernels/attention.py
 import random
 from typing import List, Optional
 
-from flash_attn.flash_attn_interface import _flash_attn_forward
 import torch
+from xformers import ops as xops
+from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
 
 from cacheflow import attention_ops
@@ -81,8 +82,10 @@ def ref_multi_query_kv_attention(
         end_idx = cu_seq_lens[i + 1]
         seq_len = end_idx - start_idx
 
-        # Create attention mask
-        attn_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1) * -1e5
+        # Create attention mask.
+        attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype),
+                               diagonal=1)
+        attn_mask = attn_mask * torch.finfo(dtype).min
         attn_mask = attn_mask.to(dtype=dtype, device='cuda')
 
         ref_output = ref_masked_attention(
@@ -160,21 +163,20 @@ def test_single_query_cached_kv_attention(
     num_blocks: int,
     dtype: torch.dtype,
 ) -> None:
-    qkv = torch.randn(
+    qkv = torch.empty(
         num_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
+    qkv.uniform_(-1e-3, 1e-3)
     query, _, _ = qkv.unbind(dim=1)
 
     x = 16 // torch.tensor([], dtype=dtype).element_size()
     key_block_shape = (num_heads, head_size // x, block_size, x)
-    key_cache = torch.randn(
+    key_cache = torch.empty(
         size=(num_blocks, *key_block_shape), dtype=dtype, device='cuda')
+    key_cache.uniform_(-1e-3, 1e-3)
     value_block_shape = (num_heads, head_size, block_size)
-    value_cache = torch.randn(
+    value_cache = torch.empty(
         size=(num_blocks, *value_block_shape), dtype=dtype, device='cuda')
-    # Adjust the range of the values to reduce precision errors.
-    query = query / (head_size ** 0.5)
-    key_cache = key_cache / (head_size ** 0.5)
-    value_cache = value_cache / (head_size ** 0.5)
+    value_cache.uniform_(-1e-3, 1e-3)
 
     context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_tokens)]
     max_context_len = max(context_lens)
@@ -228,39 +230,30 @@ def test_multi_query_kv_attention(
     dtype: torch.dtype,
 ) -> None:
     seq_lens = random.sample(range(1, MAX_SEQ_LEN), num_seqs)
-    max_seq_len = max(seq_lens)
     num_tokens = sum(seq_lens)
-    cu_seq_lens = [0]
-    for seq_len in seq_lens:
-        cu_seq_lens.append(cu_seq_lens[-1] + seq_len)
-    cu_seq_lens = torch.tensor(
-        cu_seq_lens, dtype=torch.int, device='cuda')
 
     scale = float(1.0 / (head_size ** 0.5))
-    qkv = torch.randn(
+    qkv = torch.empty(
         num_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
-    # Adjust the range of the values to reduce precision errors.
-    qkv = qkv / (head_size ** 0.5)
+    qkv.uniform_(-1e-3, 1e-3)
     query, key, value = qkv.unbind(dim=1)
-    output = torch.empty(
-        num_tokens, num_heads, head_size, dtype=dtype, device='cuda')
-    _flash_attn_forward(
-        query,
-        key,
-        value,
-        output,
-        cu_seq_lens,
-        cu_seq_lens,
-        max_seq_len,
-        max_seq_len,
-        dropout_p=0.0,
-        softmax_scale=scale,
-        causal=True,
-        return_softmax=False,
+
+    attn_op = xops.fmha.cutlass.FwOp()
+    attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens)
+    output = xops.memory_efficient_attention_forward(
+        query.unsqueeze(0),
+        key.unsqueeze(0),
+        value.unsqueeze(0),
+        attn_bias=attn_bias,
+        p=0.0,
+        scale=scale,
+        op=attn_op,
     )
+    output = output.squeeze(0)
 
-    cu_seq_lens = cu_seq_lens.cpu().tolist()
+    cu_seq_lens = [0]
+    for seq_len in seq_lens:
+        cu_seq_lens.append(cu_seq_lens[-1] + seq_len)
     ref_output = ref_multi_query_kv_attention(
         cu_seq_lens,
         query,
@@ -277,8 +270,8 @@ def test_attention(seed: int) -> None:
     # the test fails due to the precision issue. Re-run the test if it fails.
    torch.random.manual_seed(seed)
     torch.cuda.manual_seed(seed)
-    for dtype in [torch.half, torch.float]:
-        for block_size in [8, 16, 32]:
+    for dtype in [torch.half, torch.bfloat16]:
+        for block_size in [8, 16, 32, 64]:
             for head_size in [32, 64, 80, 96, 128, 160, 192, 256]:
                 print(f'Testing single_query_cached_kv_attention with '
                       f'dtype={dtype}, block_size={block_size}, '
@@ -292,14 +285,12 @@ def test_attention(seed: int) -> None:
                     dtype=dtype,
                 )
 
-    # NOTE(woosuk): FlashAttention does not support FP32.
-    for dtype in [torch.half]:
-        # NOTE(woosuk): FlashAttention does not support head_size > 128.
-        for head_size in [64, 80, 96, 128]:
+    for dtype in [torch.half, torch.bfloat16]:
+        for head_size in [32, 64, 80, 96, 128, 160, 192, 256]:
             print(f'Testing multi_query_kv_attention with dtype={dtype}, '
                   f'head_size={head_size}')
             test_multi_query_kv_attention(
-                num_seqs=11,
+                num_seqs=5,
                 num_heads=3,
                 head_size=head_size,
                 dtype=dtype,
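One detail in the reference-attention change above is worth calling out: now that bfloat16 is tested, the mask scales `torch.finfo(dtype).min` instead of a hard-coded -1e5. The sketch below (illustrative only, not repository code) shows why; tensor sizes are toy values.

```python
# Illustrative sketch (not from the diff): why the reference mask uses
# torch.finfo(dtype).min rather than -1e5 when half/bfloat16 are in play.
import torch

for dtype in [torch.half, torch.bfloat16, torch.float]:
    print(dtype, torch.finfo(dtype).min)
# float16's most negative finite value is -65504, so a float32 -1e5 cast to
# half overflows to -inf; scaling by finfo(dtype).min stays representable.

seq_len, dtype = 4, torch.bfloat16
attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype), diagonal=1)
attn_mask = attn_mask * torch.finfo(dtype).min   # large negative only above the diagonal
print(attn_mask)
```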
tests/kernels/cache.py
@@ -142,15 +142,16 @@ def test_gather_cached_kv(
 
 @torch.inference_mode()
 def test_cache() -> None:
-    test_copy_blocks(
-        num_mappings=23, num_layers=7, num_heads=17, head_size=16,
-        block_size=8, num_blocks=1024, dtype=torch.half)
-    test_reshape_and_cache(
-        num_tokens=3, num_heads=2, head_size=16, block_size=8,
-        num_blocks=2, dtype=torch.half)
-    test_gather_cached_kv(
-        num_tokens=3, num_heads=2, head_size=16, block_size=8,
-        num_blocks=2, dtype=torch.half)
+    for dtype in [torch.half, torch.bfloat16, torch.float]:
+        test_copy_blocks(
+            num_mappings=23, num_layers=7, num_heads=17, head_size=16,
+            block_size=8, num_blocks=1024, dtype=dtype)
+        test_reshape_and_cache(
+            num_tokens=3, num_heads=2, head_size=16, block_size=8,
+            num_blocks=2, dtype=dtype)
+        test_gather_cached_kv(
+            num_tokens=3, num_heads=2, head_size=16, block_size=8,
+            num_blocks=2, dtype=dtype)
 
 
 if __name__ == '__main__':
tests/kernels/layernorm.py
@@ -8,7 +8,8 @@ class RefRMSNorm(nn.Module):
     def __init__(self, hidden_size, eps=1e-6):
         super().__init__()
-        weight = torch.randn(hidden_size) / (hidden_size ** 0.5)
+        weight = torch.empty(hidden_size)
+        weight.uniform_(-1e-3, 1e-3)
         self.weight = nn.Parameter(weight)
         self.variance_epsilon = eps
@@ -41,7 +42,7 @@ def test_rms_norm(
 
 if __name__ == '__main__':
-    for dtype in [torch.half, torch.float]:
+    for dtype in [torch.half, torch.bfloat16, torch.float]:
         for num_tokens in [7, 128, 2048]:
             for hidden_size in [13, 64, 1024, 5120]:
                 print(f'Testing RMS kernel with dtype={dtype}, num_tokens='
tests/kernels/pos_encoding.py
@@ -129,7 +129,7 @@ def test_rotary_embedding_neox(
 
 if __name__ == '__main__':
-    for dtype in [torch.half, torch.float]:
+    for dtype in [torch.half, torch.bfloat16, torch.float]:
         for head_size in [32, 64, 80, 96, 128, 160, 192, 256]:
             print(f'Running tests for head_size={head_size} and dtype={dtype}')
             test_rotary_embedding_neox(