Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
0f9a6e3d
Unverified
Commit
0f9a6e3d
authored
May 09, 2024
by
DefTruth
Committed by
GitHub
May 08, 2024
Browse files
[Bugfix][Kernel] allow non-power-of-2 for prefix prefill with alibi (#4573)
parent
f6a59309
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
267 additions
and
17 deletions
+267
-17
tests/kernels/test_prefix_prefill.py
tests/kernels/test_prefix_prefill.py
+242
-1
vllm/attention/ops/prefix_prefill.py
vllm/attention/ops/prefix_prefill.py
+25
-16
No files found.
tests/kernels/test_prefix_prefill.py
View file @
0f9a6e3d
import
math
import
random
import
random
import
time
import
time
...
@@ -6,11 +7,12 @@ import torch
...
@@ -6,11 +7,12 @@ import torch
from
xformers
import
ops
as
xops
from
xformers
import
ops
as
xops
from
xformers.ops.fmha.attn_bias
import
BlockDiagonalCausalFromBottomRightMask
from
xformers.ops.fmha.attn_bias
import
BlockDiagonalCausalFromBottomRightMask
from
vllm.attention.backends.xformers
import
_make_alibi_bias
from
vllm.attention.ops.prefix_prefill
import
context_attention_fwd
from
vllm.attention.ops.prefix_prefill
import
context_attention_fwd
NUM_HEADS
=
[
64
]
NUM_HEADS
=
[
64
]
NUM_QUERIES_PER_KV
=
[
1
,
8
,
64
]
NUM_QUERIES_PER_KV
=
[
1
,
8
,
64
]
HEAD_SIZES
=
[
128
,
96
]
HEAD_SIZES
=
[
128
,
96
,
24
]
DTYPES
=
[
torch
.
float16
]
DTYPES
=
[
torch
.
float16
]
CUDA_DEVICES
=
[
CUDA_DEVICES
=
[
f
"cuda:
{
i
}
"
for
i
in
range
(
1
if
torch
.
cuda
.
device_count
()
==
1
else
2
)
f
"cuda:
{
i
}
"
for
i
in
range
(
1
if
torch
.
cuda
.
device_count
()
==
1
else
2
)
...
@@ -207,3 +209,242 @@ def test_contexted_kv_attention(
...
@@ -207,3 +209,242 @@ def test_contexted_kv_attention(
print
(
f
"xformers Time:
{
(
end_time
-
start_time
)
*
1000
:.
2
f
}
ms"
)
print
(
f
"xformers Time:
{
(
end_time
-
start_time
)
*
1000
:.
2
f
}
ms"
)
output_ref
=
output_ref
.
reshape
(
output
.
shape
)
output_ref
=
output_ref
.
reshape
(
output
.
shape
)
assert
torch
.
allclose
(
output_ref
,
output
,
atol
=
1e-6
,
rtol
=
0
)
assert
torch
.
allclose
(
output_ref
,
output
,
atol
=
1e-6
,
rtol
=
0
)
@
pytest
.
mark
.
parametrize
(
"num_heads"
,
NUM_HEADS
)
@
pytest
.
mark
.
parametrize
(
"num_queries_per_kv"
,
NUM_QUERIES_PER_KV
)
@
pytest
.
mark
.
parametrize
(
"head_size"
,
HEAD_SIZES
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
torch
.
inference_mode
()
def
test_contexted_kv_attention_alibi
(
num_heads
:
int
,
num_queries_per_kv
:
int
,
head_size
:
int
,
dtype
:
torch
.
dtype
,
device
:
str
,
)
->
None
:
random
.
seed
(
0
)
torch
.
manual_seed
(
0
)
if
torch
.
cuda
.
is_available
():
torch
.
cuda
.
manual_seed
(
0
)
torch
.
set_default_device
(
device
)
# Need this, otherwise when we capture the graph the process
# for GPU 1 would run on both GPU0 and GPU1 and things would hang
#
# see also similar issue: https://github.com/Dao-AILab/flash-attention/issues/523
torch
.
cuda
.
set_device
(
device
)
def
_get_alibi_slopes
(
total_num_heads
:
int
)
->
torch
.
Tensor
:
# Fork from: vllm/vllm/model_executor/models/bloom.py#L44
closest_power_of_2
=
2
**
math
.
floor
(
math
.
log2
(
total_num_heads
))
base
=
torch
.
tensor
(
2
**
(
-
(
2
**-
(
math
.
log2
(
closest_power_of_2
)
-
3
))),
dtype
=
torch
.
float32
,
)
powers
=
torch
.
arange
(
1
,
1
+
closest_power_of_2
,
dtype
=
torch
.
int32
)
slopes
=
torch
.
pow
(
base
,
powers
)
if
closest_power_of_2
!=
total_num_heads
:
extra_base
=
torch
.
tensor
(
2
**
(
-
(
2
**-
(
math
.
log2
(
2
*
closest_power_of_2
)
-
3
))),
dtype
=
torch
.
float32
,
)
num_remaining_heads
=
min
(
closest_power_of_2
,
total_num_heads
-
closest_power_of_2
)
extra_powers
=
torch
.
arange
(
start
=
1
,
end
=
1
+
2
*
num_remaining_heads
,
step
=
2
,
dtype
=
torch
.
int32
)
slopes
=
torch
.
cat
(
[
slopes
,
torch
.
pow
(
extra_base
,
extra_powers
)],
dim
=
0
)
return
slopes
alibi_slopes
=
_get_alibi_slopes
(
num_heads
).
to
(
device
)
MAX_SEQ_LEN
=
1024
MAX_CTX_LEN
=
1024
BS
=
10
cache_size
=
640
block_size
=
32
max_block_per_request
=
64
query_lens
=
[
random
.
randint
(
16
,
MAX_SEQ_LEN
)
for
_
in
range
(
BS
)]
ctx_lens
=
[
random
.
randint
(
16
,
MAX_CTX_LEN
)
for
_
in
range
(
BS
)]
seq_lens
=
[
a
+
b
for
a
,
b
in
zip
(
query_lens
,
ctx_lens
)]
num_kv_heads
=
num_heads
//
num_queries_per_kv
num_tokens
=
sum
(
query_lens
)
query
=
torch
.
empty
(
num_tokens
,
num_heads
,
head_size
,
dtype
=
dtype
)
query
.
uniform_
(
-
1e-3
,
1e-3
)
output
=
torch
.
empty
(
num_tokens
,
num_heads
,
head_size
,
dtype
=
dtype
)
kv
=
torch
.
empty
(
sum
(
seq_lens
),
2
,
num_kv_heads
,
head_size
,
dtype
=
dtype
)
kv
.
uniform_
(
-
1e-3
,
1e-3
)
key
,
value
=
kv
.
unbind
(
dim
=
1
)
k_cache
=
torch
.
zeros
(
cache_size
,
block_size
,
num_kv_heads
,
head_size
,
dtype
=
dtype
)
v_cache
=
torch
.
zeros
(
cache_size
,
block_size
,
num_kv_heads
,
head_size
,
dtype
=
dtype
)
k
=
torch
.
zeros
(
sum
(
query_lens
),
num_kv_heads
,
head_size
,
dtype
=
dtype
)
v
=
torch
.
zeros
(
sum
(
query_lens
),
num_kv_heads
,
head_size
,
dtype
=
dtype
)
values
=
torch
.
arange
(
0
,
cache_size
,
dtype
=
torch
.
long
)
values
=
values
[
torch
.
randperm
(
cache_size
)]
block_table
=
values
[:
BS
*
max_block_per_request
].
view
(
BS
,
max_block_per_request
)
b_seq_len
=
torch
.
tensor
(
seq_lens
,
dtype
=
torch
.
long
)
b_ctx_len
=
torch
.
tensor
(
ctx_lens
,
dtype
=
torch
.
long
)
b_start_loc
=
torch
.
cumsum
(
torch
.
tensor
([
0
]
+
query_lens
[:
-
1
],
dtype
=
torch
.
long
),
dim
=
0
)
max_input_len
=
MAX_SEQ_LEN
# copy kv to cache
b_seq_start_loc
=
torch
.
cumsum
(
torch
.
tensor
([
0
]
+
seq_lens
[:
-
1
],
dtype
=
torch
.
long
),
dim
=
0
)
for
i
in
range
(
BS
):
for
j
in
range
(
query_lens
[
i
]):
k
[
b_start_loc
[
i
]
+
j
].
copy_
(
key
[
b_seq_start_loc
[
i
]
+
b_ctx_len
[
i
]
+
j
])
v
[
b_start_loc
[
i
]
+
j
].
copy_
(
value
[
b_seq_start_loc
[
i
]
+
b_ctx_len
[
i
]
+
j
])
cur_ctx
=
0
block_id
=
0
while
cur_ctx
<
b_ctx_len
[
i
]:
start_loc
=
b_seq_start_loc
[
i
]
+
cur_ctx
if
cur_ctx
+
block_size
>
b_ctx_len
[
i
]:
end_loc
=
b_seq_start_loc
[
i
]
+
b_ctx_len
[
i
]
else
:
end_loc
=
start_loc
+
block_size
start_slot
=
block_table
[
i
,
block_id
]
*
block_size
end_slot
=
start_slot
+
end_loc
-
start_loc
k_cache
.
view
(
-
1
,
num_kv_heads
,
head_size
)[
start_slot
:
end_slot
].
copy_
(
key
[
start_loc
:
end_loc
])
v_cache
.
view
(
-
1
,
num_kv_heads
,
head_size
)[
start_slot
:
end_slot
].
copy_
(
value
[
start_loc
:
end_loc
])
cur_ctx
+=
block_size
block_id
+=
1
# transpose K_cache[num_blocks, block_size, num_kv_heads, head_size]
# to K_cache[num_blocks, num_kv_heads, head_size/8, block_size, 8]
k_cache
=
k_cache
.
view
(
-
1
,
block_size
,
num_kv_heads
,
head_size
//
8
,
8
).
permute
(
0
,
2
,
3
,
1
,
4
).
contiguous
()
# transpose V_cache[num_blocks, block_size, num_kv_heads, head_size]
# to V_cache[num_blocks, num_kv_heads, head_size, block_size]
v_cache
=
v_cache
.
view
(
-
1
,
block_size
,
num_kv_heads
,
head_size
).
permute
(
0
,
2
,
3
,
1
).
contiguous
()
# Warm up the Triton kernel by calling it once before actually measuring
# generation time
context_attention_fwd
(
query
,
k
,
v
,
output
,
k_cache
,
v_cache
,
block_table
,
b_start_loc
,
b_seq_len
,
b_ctx_len
,
max_input_len
,
alibi_slopes
=
alibi_slopes
)
torch
.
cuda
.
synchronize
()
start_time
=
time
.
time
()
context_attention_fwd
(
query
,
k
,
v
,
output
,
k_cache
,
v_cache
,
block_table
,
b_start_loc
,
b_seq_len
,
b_ctx_len
,
max_input_len
,
alibi_slopes
=
alibi_slopes
)
torch
.
cuda
.
synchronize
()
end_time
=
time
.
time
()
print
(
f
"triton Time:
{
(
end_time
-
start_time
)
*
1000
:.
2
f
}
ms"
)
scale
=
float
(
1.0
/
(
head_size
**
0.5
))
# NOTE(DefTruth): In order to reuse _make_alibi_bias function,
# we have to pad query tensor before MQA/GQA expanding.
if
query
.
shape
[
0
]
!=
key
.
shape
[
0
]:
query_pad
=
torch
.
empty
(
sum
(
seq_lens
),
num_heads
,
head_size
,
dtype
=
dtype
)
query_pad
.
uniform_
(
-
1e-3
,
1e-3
)
seq_start
=
0
query_start
=
0
for
i
,
(
query_len
,
seq_len
)
in
enumerate
(
zip
(
query_lens
,
seq_lens
)):
seq_end
=
seq_start
+
seq_len
query_end
=
query_start
+
query_len
query_pad
[
seq_start
:
seq_end
,
...]
=
torch
.
cat
([
torch
.
zeros
(
seq_len
-
query_len
,
num_heads
,
head_size
,
dtype
=
dtype
),
query
[
query_start
:
query_end
,
...]
],
dim
=
0
)
seq_start
+=
seq_len
query_start
+=
query_len
query
=
query_pad
if
num_kv_heads
!=
num_heads
:
# As of Nov 2023, xformers only supports MHA. For MQA/GQA,
# project the key and value tensors to the desired number of
# heads.
#
# see also: vllm/model_executor/layers/attention.py
query
=
query
.
view
(
query
.
shape
[
0
],
num_kv_heads
,
num_queries_per_kv
,
query
.
shape
[
-
1
])
key
=
key
[:,
:,
None
,
:].
expand
(
key
.
shape
[
0
],
num_kv_heads
,
num_queries_per_kv
,
key
.
shape
[
-
1
])
value
=
value
[:,
:,
None
,
:].
expand
(
value
.
shape
[
0
],
num_kv_heads
,
num_queries_per_kv
,
value
.
shape
[
-
1
])
query
=
query
.
unsqueeze
(
0
)
key
=
key
.
unsqueeze
(
0
)
value
=
value
.
unsqueeze
(
0
)
attn_bias
=
_make_alibi_bias
(
alibi_slopes
,
num_kv_heads
,
dtype
,
seq_lens
)
output_ref
=
torch
.
empty_like
(
output
)
seq_start
=
0
query_start
=
0
start_time
=
time
.
time
()
# Attention with alibi slopes.
# FIXME(DefTruth): Because xformers does not support dynamic sequence
# lengths with custom attention bias, we process each prompt one by
# one. This is inefficient, especially when we have many short prompts.
# modified from: vllm/attention/backends/xformers.py#L343
for
i
,
(
query_len
,
seq_len
)
in
enumerate
(
zip
(
query_lens
,
seq_lens
)):
seq_end
=
seq_start
+
seq_len
query_end
=
query_start
+
query_len
out
=
xops
.
memory_efficient_attention_forward
(
query
[:,
seq_start
:
seq_end
],
key
[:,
seq_start
:
seq_end
],
value
[:,
seq_start
:
seq_end
],
attn_bias
=
attn_bias
[
i
],
p
=
0.0
,
scale
=
scale
)
out
=
out
.
view_as
(
query
[:,
seq_start
:
seq_end
]).
view
(
seq_len
,
num_heads
,
head_size
)
output_ref
[
query_start
:
query_end
,
...].
copy_
(
out
[
seq_len
-
query_len
:,
...])
seq_start
+=
seq_len
query_start
+=
query_len
torch
.
cuda
.
synchronize
()
end_time
=
time
.
time
()
print
(
f
"xformers Time:
{
(
end_time
-
start_time
)
*
1000
:.
2
f
}
ms"
)
assert
torch
.
allclose
(
output_ref
,
output
,
atol
=
1e-6
,
rtol
=
0
)
vllm/attention/ops/prefix_prefill.py
View file @
0f9a6e3d
...
@@ -472,7 +472,8 @@ if triton.__version__ >= "2.1.0":
...
@@ -472,7 +472,8 @@ if triton.__version__ >= "2.1.0":
stride_v_cache_bl
,
stride_v_cache_bl
,
num_queries_per_kv
:
int
,
num_queries_per_kv
:
int
,
BLOCK_M
:
tl
.
constexpr
,
BLOCK_M
:
tl
.
constexpr
,
BLOCK_DMODEL
:
tl
.
constexpr
,
BLOCK_DMODEL
:
tl
.
constexpr
,
# head size
BLOCK_DMODEL_PADDED
:
tl
.
constexpr
,
# head size padded to a power of 2
BLOCK_N
:
tl
.
constexpr
,
BLOCK_N
:
tl
.
constexpr
,
):
):
# attn_bias[]
# attn_bias[]
...
@@ -493,21 +494,24 @@ if triton.__version__ >= "2.1.0":
...
@@ -493,21 +494,24 @@ if triton.__version__ >= "2.1.0":
# initialize offsets
# initialize offsets
offs_n
=
tl
.
arange
(
0
,
BLOCK_N
)
offs_n
=
tl
.
arange
(
0
,
BLOCK_N
)
offs_d
=
tl
.
arange
(
0
,
BLOCK_DMODEL
)
offs_d
=
tl
.
arange
(
0
,
BLOCK_DMODEL
_PADDED
)
offs_m
=
start_m
*
BLOCK_M
+
tl
.
arange
(
0
,
BLOCK_M
)
offs_m
=
start_m
*
BLOCK_M
+
tl
.
arange
(
0
,
BLOCK_M
)
off_q
=
(
off_q
=
(
(
cur_batch_in_all_start_index
+
offs_m
[:,
None
])
*
stride_qbs
+
(
cur_batch_in_all_start_index
+
offs_m
[:,
None
])
*
stride_qbs
+
cur_head
*
stride_qh
+
offs_d
[
None
,
:]
*
stride_qd
)
cur_head
*
stride_qh
+
offs_d
[
None
,
:]
*
stride_qd
)
q
=
tl
.
load
(
dim_mask
=
tl
.
where
(
Q
+
off_q
,
tl
.
arange
(
0
,
BLOCK_DMODEL_PADDED
)
<
BLOCK_DMODEL
,
1
,
0
).
to
(
tl
.
int1
)
mask
=
offs_m
[:,
None
]
<
cur_batch_seq_len
-
cur_batch_ctx_len
,
q
=
tl
.
load
(
Q
+
off_q
,
mask
=
dim_mask
[
None
,
:]
&
(
offs_m
[:,
None
]
<
cur_batch_seq_len
-
cur_batch_ctx_len
),
other
=
0.0
)
other
=
0.0
)
# # initialize pointer to m and l
# # initialize pointer to m and l
m_i
=
tl
.
zeros
([
BLOCK_M
],
dtype
=
tl
.
float32
)
-
float
(
"inf"
)
m_i
=
tl
.
zeros
([
BLOCK_M
],
dtype
=
tl
.
float32
)
-
float
(
"inf"
)
l_i
=
tl
.
zeros
([
BLOCK_M
],
dtype
=
tl
.
float32
)
l_i
=
tl
.
zeros
([
BLOCK_M
],
dtype
=
tl
.
float32
)
acc
=
tl
.
zeros
([
BLOCK_M
,
BLOCK_DMODEL
],
dtype
=
tl
.
float32
)
acc
=
tl
.
zeros
([
BLOCK_M
,
BLOCK_DMODEL
_PADDED
],
dtype
=
tl
.
float32
)
alibi_slope
=
tl
.
load
(
Alibi_slopes
+
cur_head
)
alibi_slope
=
tl
.
load
(
Alibi_slopes
+
cur_head
)
alibi_start_q
=
tl
.
arange
(
alibi_start_q
=
tl
.
arange
(
...
@@ -532,8 +536,9 @@ if triton.__version__ >= "2.1.0":
...
@@ -532,8 +536,9 @@ if triton.__version__ >= "2.1.0":
offs_d
[
None
,
:]
*
stride_v_cache_d
+
offs_d
[
None
,
:]
*
stride_v_cache_d
+
(
start_n
+
offs_n
[:,
None
])
%
block_size
*
stride_v_cache_bl
)
(
start_n
+
offs_n
[:,
None
])
%
block_size
*
stride_v_cache_bl
)
k
=
tl
.
load
(
K_cache
+
off_k
,
k
=
tl
.
load
(
K_cache
+
off_k
,
mask
=
(
start_n
+
offs_n
[
None
,
:])
<
cur_batch_ctx_len
,
mask
=
dim_mask
[:,
None
]
&
other
=
0.0
)
((
start_n
+
offs_n
[
None
,
:])
<
cur_batch_ctx_len
),
other
=
0.0
)
# [D,N]
qk
=
tl
.
zeros
([
BLOCK_M
,
BLOCK_N
],
dtype
=
tl
.
float32
)
qk
=
tl
.
zeros
([
BLOCK_M
,
BLOCK_N
],
dtype
=
tl
.
float32
)
qk
+=
tl
.
dot
(
q
,
k
)
qk
+=
tl
.
dot
(
q
,
k
)
...
@@ -567,7 +572,8 @@ if triton.__version__ >= "2.1.0":
...
@@ -567,7 +572,8 @@ if triton.__version__ >= "2.1.0":
acc
=
acc
*
acc_scale
[:,
None
]
acc
=
acc
*
acc_scale
[:,
None
]
# update acc
# update acc
v
=
tl
.
load
(
V_cache
+
off_v
,
v
=
tl
.
load
(
V_cache
+
off_v
,
mask
=
(
start_n
+
offs_n
[:,
None
])
<
cur_batch_ctx_len
,
mask
=
dim_mask
[
None
,
:]
&
((
start_n
+
offs_n
[:,
None
])
<
cur_batch_ctx_len
),
other
=
0.0
)
other
=
0.0
)
p
=
p
.
to
(
v
.
dtype
)
p
=
p
.
to
(
v
.
dtype
)
...
@@ -600,8 +606,9 @@ if triton.__version__ >= "2.1.0":
...
@@ -600,8 +606,9 @@ if triton.__version__ >= "2.1.0":
# -- compute qk ----
# -- compute qk ----
k
=
tl
.
load
(
k_ptrs
+
k
=
tl
.
load
(
k_ptrs
+
(
cur_batch_in_all_start_index
+
start_n
)
*
stride_kbs
,
(
cur_batch_in_all_start_index
+
start_n
)
*
stride_kbs
,
mask
=
(
start_n
+
offs_n
[
None
,
:])
<
mask
=
dim_mask
[:,
None
]
&
cur_batch_seq_len
-
cur_batch_ctx_len
,
((
start_n
+
offs_n
[
None
,
:])
<
cur_batch_seq_len
-
cur_batch_ctx_len
),
other
=
0.0
)
other
=
0.0
)
qk
=
tl
.
zeros
([
BLOCK_M
,
BLOCK_N
],
dtype
=
tl
.
float32
)
qk
=
tl
.
zeros
([
BLOCK_M
,
BLOCK_N
],
dtype
=
tl
.
float32
)
...
@@ -637,8 +644,9 @@ if triton.__version__ >= "2.1.0":
...
@@ -637,8 +644,9 @@ if triton.__version__ >= "2.1.0":
# update acc
# update acc
v
=
tl
.
load
(
v_ptrs
+
v
=
tl
.
load
(
v_ptrs
+
(
cur_batch_in_all_start_index
+
start_n
)
*
stride_vbs
,
(
cur_batch_in_all_start_index
+
start_n
)
*
stride_vbs
,
mask
=
(
start_n
+
offs_n
[:,
None
])
<
mask
=
dim_mask
[
None
,
:]
&
cur_batch_seq_len
-
cur_batch_ctx_len
,
((
start_n
+
offs_n
[:,
None
])
<
cur_batch_seq_len
-
cur_batch_ctx_len
),
other
=
0.0
)
other
=
0.0
)
p
=
p
.
to
(
v
.
dtype
)
p
=
p
.
to
(
v
.
dtype
)
...
@@ -656,7 +664,8 @@ if triton.__version__ >= "2.1.0":
...
@@ -656,7 +664,8 @@ if triton.__version__ >= "2.1.0":
out_ptrs
=
Out
+
off_o
out_ptrs
=
Out
+
off_o
tl
.
store
(
out_ptrs
,
tl
.
store
(
out_ptrs
,
acc
,
acc
,
mask
=
offs_m
[:,
None
]
<
cur_batch_seq_len
-
cur_batch_ctx_len
)
mask
=
dim_mask
[
None
,
:]
&
(
offs_m
[:,
None
]
<
cur_batch_seq_len
-
cur_batch_ctx_len
))
return
return
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
...
@@ -690,7 +699,6 @@ if triton.__version__ >= "2.1.0":
...
@@ -690,7 +699,6 @@ if triton.__version__ >= "2.1.0":
num_warps
=
8
if
Lk
<=
64
else
8
num_warps
=
8
if
Lk
<=
64
else
8
if
alibi_slopes
is
not
None
:
if
alibi_slopes
is
not
None
:
assert
Lk
==
Lk_padded
_fwd_kernel_alibi
[
grid
](
_fwd_kernel_alibi
[
grid
](
q
,
q
,
k
,
k
,
...
@@ -735,6 +743,7 @@ if triton.__version__ >= "2.1.0":
...
@@ -735,6 +743,7 @@ if triton.__version__ >= "2.1.0":
num_queries_per_kv
=
num_queries_per_kv
,
num_queries_per_kv
=
num_queries_per_kv
,
BLOCK_M
=
BLOCK
,
BLOCK_M
=
BLOCK
,
BLOCK_DMODEL
=
Lk
,
BLOCK_DMODEL
=
Lk
,
BLOCK_DMODEL_PADDED
=
Lk_padded
,
BLOCK_N
=
BLOCK
,
BLOCK_N
=
BLOCK
,
num_warps
=
num_warps
,
num_warps
=
num_warps
,
num_stages
=
1
,
num_stages
=
1
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment