sglang · commit c2f212d6 (unverified)
Authored Jan 18, 2025 by Xiaoyu Zhang; committed by GitHub, Jan 18, 2025
Parent: e2cdc8a5

optimize MiniMax-Text-01 lightning_attn_decode triton (#2966)

Showing 1 changed file with 47 additions and 22 deletions (+47 -22):
benchmark/kernels/minmax-text-01-lighting_attention/benchmark_lighting_attention_decode.py
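For reference, the kernel in this file implements one step of the lightning-attention decode recurrence: the running state kv is decayed by a per-head ratio, updated with the outer product of the current key and value, and the output is the query applied to that state. A minimal PyTorch sketch of the same update (not part of the commit; shapes assumed from the wrapper below, and the decay ratio taken as a given per-head tensor rather than derived from the slope tensor s):

    import torch

    def lightning_decode_reference(q, k, v, kv, ratio):
        # q, k: (b, h, 1, d); v: (b, h, 1, e); kv: (b, h, d, e); ratio: (h,)
        k_v_prod = k.transpose(-1, -2) @ v                # outer product k^T v -> (b, h, d, e)
        kv_new = ratio.view(1, -1, 1, 1) * kv + k_v_prod  # decay old state, add new contribution
        o = q @ kv_new                                    # (b, h, 1, e)
        return o, kv_new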
...
...
@@ -23,7 +23,10 @@ def _decode_kernel(
     h: tl.constexpr,
     n: tl.constexpr,
     d: tl.constexpr,
+    d_original: tl.constexpr,
     e: tl.constexpr,
+    e_original: tl.constexpr,
     BLOCK_SIZE: tl.constexpr = 32,
 ):
     off_bh = tl.program_id(0)
     off_h = off_bh % h
...
...
@@ -39,21 +42,38 @@ def _decode_kernel(
     d_idx = tl.arange(0, d)
     e_idx = tl.arange(0, e)

-    q = tl.load(Q + qk_offset + d_idx)
-    k = tl.load(K + qk_offset + d_idx)
-    v = tl.load(V + v_offset + e_idx)
-    kv = tl.load(KV + kv_offset + d_idx[:, None] * e + e_idx[None, :])
+    # Create masks for original dimensions
+    d_mask = d_idx < d_original
+    e_mask = e_idx < e_original
+
+    # Load with masking
+    q = tl.load(Q + qk_offset + d_idx, mask=d_mask, other=0.0)
+    k = tl.load(K + qk_offset + d_idx, mask=d_mask, other=0.0)
+    v = tl.load(V + v_offset + e_idx, mask=e_mask, other=0.0)
+
+    # Load KV with 2D masking
+    kv = tl.load(
+        KV + kv_offset + d_idx[:, None] * e + e_idx[None, :],
+        mask=(d_mask[:, None] & e_mask[None, :]),
+        other=0.0,
+    )

     # Compute outer product using element-wise operations
     k_v_prod = k[:, None] * v[None, :]
     kv = ratio * kv + k_v_prod

-    tl.store(
-        KV + kv_offset + d_idx[:, None] * e + e_idx[None, :], kv.to(KV.dtype.element_ty)
-    )
+    # Store KV with 2D masking
+    tl.store(
+        KV + kv_offset + d_idx[:, None] * e + e_idx[None, :],
+        kv.to(KV.dtype.element_ty),
+        mask=(d_mask[:, None] & e_mask[None, :]),
+    )

     # Compute matrix-vector multiplication using element-wise operations and reduction
     o = tl.sum(q[:, None] * kv, axis=0)
-    tl.store(Out + o_offset + e_idx, o.to(Out.dtype.element_ty))
+    # Store output with masking
+    tl.store(Out + o_offset + e_idx, o.to(Out.dtype.element_ty), mask=e_mask)


 def lightning_attn_decode(q, k, v, kv, s):
...
...
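The pattern this hunk introduces — sizing tl.arange to a power-of-two block and masking loads and stores back down to the real dimension — is the usual Triton way to handle head dimensions that are not powers of two. A self-contained sketch of the same pattern (illustrative only, not from this file; assumes triton is installed and a CUDA device is available):

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def _scale_kernel(X, Y, scale, d_original: tl.constexpr, d: tl.constexpr):
        row = tl.program_id(0)
        idx = tl.arange(0, d)              # d must be a power of two
        mask = idx < d_original            # only the real lanes are valid
        x = tl.load(X + row * d_original + idx, mask=mask, other=0.0)
        tl.store(Y + row * d_original + idx, x * scale, mask=mask)

    def scale_rows(x: torch.Tensor, scale: float) -> torch.Tensor:
        x = x.contiguous()
        rows, d_original = x.shape
        y = torch.empty_like(x)
        d = triton.next_power_of_2(d_original)
        _scale_kernel[(rows,)](x, y, scale, d_original=d_original, d=d)
        return y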
@@ -62,26 +82,27 @@ def lightning_attn_decode(q, k, v, kv, s):
     e = v.shape[-1]

     assert n == 1, "Sequence length must be 1 in decode mode"

-    # Pad dimensions to power of 2
+    # Get padded dimensions (power of 2)
     d_padded = next_power_of_2(d)
     e_padded = next_power_of_2(e)

-    # Pad inputs
-    q_padded = F.pad(q, (0, d_padded - d))
-    k_padded = F.pad(k, (0, d_padded - d))
-    v_padded = F.pad(v, (0, e_padded - e))
-    kv_padded = F.pad(kv, (0, e_padded - e, 0, d_padded - d))
-
-    # Ensure inputs are contiguous
-    q_padded = q_padded.contiguous()
-    k_padded = k_padded.contiguous()
-    v_padded = v_padded.contiguous()
-    kv_padded = kv_padded.contiguous().to(torch.float32)
-    s = s.contiguous()

     # Create output tensor (padded)
     o_padded = torch.empty(b, h, n, e_padded, dtype=v.dtype, device=v.device)

+    # Create padded tensors without actually padding the data
+    q_padded = torch.empty(b, h, n, d_padded, dtype=q.dtype, device=q.device)
+    k_padded = torch.empty(b, h, n, d_padded, dtype=k.dtype, device=k.device)
+    v_padded = torch.empty(b, h, n, e_padded, dtype=v.dtype, device=v.device)
+    kv_padded = torch.empty(
+        b, h, d_padded, e_padded, dtype=torch.float32, device=kv.device
+    )
+
+    # Copy data to padded tensors
+    q_padded[..., :d] = q
+    k_padded[..., :d] = k
+    v_padded[..., :e] = v
+    kv_padded[..., :d, :e] = kv

     # Launch kernel
     grid = (b * h, 1)
     _decode_kernel[grid](
...
...
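The wrapper change above drops four F.pad calls plus the follow-up .contiguous() copies in favor of one torch.empty allocation per tensor and a slice copy; the padded tail is left uninitialized, which is safe because the kernel now masks every load and store to the original dimensions. A rough micro-benchmark sketch of the two approaches (mine, with illustrative shapes; requires a CUDA device):

    import torch
    import torch.nn.functional as F

    def pad_with_fpad(x, d_padded):
        return F.pad(x, (0, d_padded - x.shape[-1])).contiguous()

    def pad_with_copy(x, d_padded):
        out = torch.empty(*x.shape[:-1], d_padded, dtype=x.dtype, device=x.device)
        out[..., : x.shape[-1]] = x    # tail stays uninitialized; the kernel never reads it
        return out

    if __name__ == "__main__":
        x = torch.randn(32, 64, 1, 96, device="cuda")
        for fn in (pad_with_fpad, pad_with_copy):
            torch.cuda.synchronize()
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            start.record()
            for _ in range(100):
                fn(x, 128)
            end.record()
            torch.cuda.synchronize()
            print(fn.__name__, start.elapsed_time(end) / 100, "ms per call")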
@@ -95,10 +116,12 @@ def lightning_attn_decode(q, k, v, kv, s):
         h=h,
         n=n,
         d=d_padded,
+        d_original=d,
         e=e_padded,
+        e_original=e,
     )

-    # Remove padding
+    # Get unpadded outputs
     o = o_padded[..., :e]
     kv_out = kv_padded[..., :d, :e]
...
...
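A hedged usage sketch of the wrapper after this change (not from the benchmark file; sizes are illustrative, the shapes follow the wrapper's assumptions, and the return value is assumed to be the unpadded output and updated kv state sliced at the end of the hunk above):

    import torch

    b, h, n, d, e = 2, 8, 1, 96, 96           # n must be 1 in decode mode
    q = torch.randn(b, h, n, d, device="cuda")
    k = torch.randn(b, h, n, d, device="cuda")
    v = torch.randn(b, h, n, e, device="cuda")
    kv = torch.zeros(b, h, d, e, device="cuda", dtype=torch.float32)  # running state
    s = _build_slope_tensor(h).to("cuda")     # one decay slope per head

    o, kv = lightning_attn_decode(q, k, v, kv, s)
    print(o.shape, kv.shape)                  # expected (b, h, 1, e) and (b, h, d, e)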
@@ -351,6 +374,8 @@ def test_lightning_attention_implementations(model_params):
         msg="Lightning attention implementations produce different kv results",
     )

+    print("✅ Two implementations match")
+

 def _build_slope_tensor(n_attention_heads: int):
     def get_slopes(n):
...
...
@@ -375,7 +400,7 @@ def _build_slope_tensor(n_attention_heads: int):

 def get_benchmark():
-    batch_size_range = [2**i for i in range(0, 12)]  # max 2048
+    batch_size_range = [i for i in range(1, 33)]  # max 32
     seq_length_range = [1]  # decode mode sequence length is fixed to 1
     configs = list(itertools.product(batch_size_range, seq_length_range))
...
...
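The last hunk narrows the benchmark sweep from powers of two up to 2048 down to every batch size from 1 to 32, presumably to focus on realistic decode batch sizes, while the sequence length stays fixed at 1. A small sketch of what the changed line enumerates:

    import itertools

    old_batch_sizes = [2**i for i in range(0, 12)]   # [1, 2, 4, ..., 2048]
    new_batch_sizes = list(range(1, 33))             # [1, 2, 3, ..., 32]
    seq_length_range = [1]                           # decode: sequence length is always 1

    configs = list(itertools.product(new_batch_sizes, seq_length_range))
    print(len(configs), configs[:3])                 # 32 configs: (1, 1), (2, 1), (3, 1)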