Commit ac2dc35d (unverified)
Authored Jan 23, 2025 by Xiaoyu Zhang; committed by GitHub on Jan 23, 2025

support lightning_attention_decode in sgl-kernel for MiniMax-Text-01 (#3030)

Parent: 3e032c07
Changes: showing 8 changed files with 588 additions and 8 deletions (+588, -8)
benchmark/kernels/minmax-text-01-lightning_attention/benchmark_lightning_attention_decode.py (+69, -8)
sgl-kernel/benchmark/bench_lightning_attention_decode.py (+299, -0)
sgl-kernel/setup.py (+1, -0)
sgl-kernel/src/sgl-kernel/__init__.py (+2, -0)
sgl-kernel/src/sgl-kernel/csrc/lightning_attention_decode_kernel.cu (+119, -0)
sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu (+7, -0)
sgl-kernel/src/sgl-kernel/ops/__init__.py (+7, -0)
sgl-kernel/tests/test_lightning_attention_decode.py (+84, -0)
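For orientation before the per-file diffs: the naive PyTorch reference, the Triton kernel, and the new CUDA kernel in this commit all compute the same single-token decode update. In my notation (reconstructed from the naive reference below, not part of the commit), with per-head state KV in R^{d x e}, per-head decay slope s, and q_t, k_t in R^{1 x d}, v_t in R^{1 x e}:

\[
\mathrm{KV}_t = e^{-s}\,\mathrm{KV}_{t-1} + k_t^{\top} v_t,
\qquad
o_t = q_t\,\mathrm{KV}_t .
\]

The CUDA kernel keeps the KV state in fp32 and writes both o_t and KV_t into caller-provided output and new_kv tensors, which is presumably why the benchmark below drops the dtype argument when building past_kv.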
benchmark/kernels/minmax-text-01-lightning_attention/benchmark_lightning_attention_decode.py

@@ -9,6 +9,7 @@ import torch.nn.functional as F
 import triton
 import triton.language as tl
 from einops import rearrange
+from sgl_kernel import lightning_attention_decode as sgl_lightning_attention_decode
 
 
 @triton.jit
@@ -332,7 +333,6 @@ def test_lightning_attention_implementations(model_params):
         model_params["num_attention_heads"],
         d,
         d,
-        dtype=dtype,
         device=device,
     )
     with torch.no_grad():
@@ -350,7 +350,13 @@ def test_lightning_attention_implementations(model_params):
     q = q.transpose(1, 2)
     k = k.transpose(1, 2)
     v = v.transpose(1, 2)
+    q = q.contiguous()
+    k = k.contiguous()
+    v = v.contiguous()
+    past_kv = past_kv.contiguous()
+    slope_rate = slope_rate.contiguous()
 
+    # Test Triton implementation
     triton_output, triton_new_kv = lightning_attn_decode(q, k, v, past_kv, slope_rate)
     triton_output = triton_output.transpose(1, 2).contiguous()
     triton_output = triton_output.view(batch_size, seq_len, -1)
@@ -358,22 +364,50 @@ def test_lightning_attention_implementations(model_params):
     triton_output = torch.sigmoid(model_attn.output_gate(hidden_states)) * triton_output
     triton_output = model_attn.out_proj(triton_output)
 
+    # Test SGL implementation
+    sgl_output = torch.empty_like(v)
+    sgl_new_kv = torch.empty_like(past_kv)
+    sgl_lightning_attention_decode(q, k, v, past_kv, slope_rate, sgl_output, sgl_new_kv)
+    sgl_output = sgl_output.transpose(1, 2).contiguous()
+    sgl_output = sgl_output.view(batch_size, seq_len, -1)
+    sgl_output = model_attn.norm(sgl_output)
+    sgl_output = torch.sigmoid(model_attn.output_gate(hidden_states)) * sgl_output
+    sgl_output = model_attn.out_proj(sgl_output)
+
+    # Verify Triton implementation results
     torch.testing.assert_close(
         model_output,
         triton_output,
         rtol=1e-3,
         atol=1e-2,
-        msg="Lightning attention implementations produce different output results",
+        msg="Triton lightning attention implementation produces different output results",
     )
     torch.testing.assert_close(
         new_kv,
         triton_new_kv,
         rtol=1e-3,
         atol=1e-2,
-        msg="Lightning attention implementations produce different kv results",
+        msg="Triton lightning attention implementation produces different kv results",
     )
 
-    print("✅ Two implementations match")
+    # Verify SGL implementation results
+    torch.testing.assert_close(
+        model_output,
+        sgl_output,
+        rtol=1e-3,
+        atol=1e-2,
+        msg="SGL lightning attention implementation produces different output results",
+    )
+    torch.testing.assert_close(
+        new_kv,
+        sgl_new_kv,
+        rtol=1e-3,
+        atol=1e-2,
+        msg="SGL lightning attention implementation produces different kv results",
+    )
+
+    print("✅ All implementations match")
 
 
 def _build_slope_tensor(n_attention_heads: int):
@@ -408,12 +442,13 @@ def get_benchmark():
         x_names=["batch_size", "seq_len"],
         x_vals=[list(_) for _ in configs],
         line_arg="provider",
-        line_vals=["Original", "Triton"],
+        line_vals=["Original", "Triton", "SGL"],
         line_names=[
             "Original PyTorch Implementation",
             "Triton Implementation",
+            "SGL Implementation",
         ],
-        styles=[("blue", "-"), ("green", "-")],
+        styles=[("blue", "-"), ("green", "-"), ("red", "-")],
         ylabel="us",
         plot_name="lightning-attention-decode-performance",
         args={},
@@ -446,7 +481,6 @@ def get_benchmark():
             params["num_attention_heads"],
             d,
             d,
-            dtype=dtype,
             device=device,
         )
@@ -461,7 +495,7 @@ def get_benchmark():
                 ),
                 quantiles=quantiles,
             )
-        else:
+        elif provider == "Triton":
 
             def run_triton():
                 qkv = model_attn.act(model_attn.qkv_proj(hidden_states))
@@ -483,6 +517,33 @@ def get_benchmark():
                 run_triton,
                 quantiles=quantiles,
            )
+        else:  # SGL
+
+            def run_sgl():
+                qkv = model_attn.act(model_attn.qkv_proj(hidden_states))
+                new_shape = qkv.size()[:-1] + (model_attn.num_heads, -1)
+                qkv = qkv.view(*new_shape)
+                q, k, v = torch.split(qkv, [model_attn.head_dim] * 3, dim=-1)
+                q = q.transpose(1, 2).contiguous()
+                k = k.transpose(1, 2).contiguous()
+                v = v.transpose(1, 2).contiguous()
+
+                output = torch.empty_like(v)
+                new_kv = torch.empty_like(past_kv)
+                sgl_lightning_attention_decode(q, k, v, past_kv, slope_rate, output, new_kv)
+
+                output = output.transpose(1, 2).contiguous()
+                output = output.view(batch_size, seq_len, -1)
+                output = model_attn.norm(output)
+                output = torch.sigmoid(model_attn.output_gate(hidden_states)) * output
+                return model_attn.out_proj(output)
+
+            ms, min_ms, max_ms = triton.testing.do_bench(
+                run_sgl,
+                quantiles=quantiles,
+            )
 
     return 1000 * ms, 1000 * max_ms, 1000 * min_ms
sgl-kernel/benchmark/bench_lightning_attention_decode.py (new file, mode 100644)

import itertools
import math

import torch
import triton
import triton.language as tl
from sgl_kernel import lightning_attention_decode


def next_power_of_2(n):
    return 2 ** (int(math.ceil(math.log(n, 2))))


@triton.jit
def _decode_kernel(
    Q,
    K,
    V,
    KV,
    Out,
    S,
    b: tl.constexpr,
    h: tl.constexpr,
    n: tl.constexpr,
    d: tl.constexpr,
    d_original: tl.constexpr,
    e: tl.constexpr,
    e_original: tl.constexpr,
):
    off_bh = tl.program_id(0)
    off_h = off_bh % h

    qk_offset = off_bh * n * d
    v_offset = off_bh * n * e
    o_offset = off_bh * n * e
    kv_offset = off_bh * d * e

    s = tl.load(S + off_h)
    ratio = tl.exp(-s)

    d_idx = tl.arange(0, d)
    e_idx = tl.arange(0, e)

    # Create masks for original dimensions
    d_mask = d_idx < d_original
    e_mask = e_idx < e_original

    # Load with masking
    q = tl.load(Q + qk_offset + d_idx, mask=d_mask, other=0.0)
    k = tl.load(K + qk_offset + d_idx, mask=d_mask, other=0.0)
    v = tl.load(V + v_offset + e_idx, mask=e_mask, other=0.0)

    # Load KV with 2D masking
    kv = tl.load(
        KV + kv_offset + d_idx[:, None] * e + e_idx[None, :],
        mask=(d_mask[:, None] & e_mask[None, :]),
        other=0.0,
    )

    # Compute outer product using element-wise operations
    k_v_prod = k[:, None] * v[None, :]
    kv = ratio * kv + k_v_prod

    # Store KV with 2D masking
    tl.store(
        KV + kv_offset + d_idx[:, None] * e + e_idx[None, :],
        kv.to(KV.dtype.element_ty),
        mask=(d_mask[:, None] & e_mask[None, :]),
    )

    # Compute matrix-vector multiplication using element-wise operations and reduction
    o = tl.sum(q[:, None] * kv, axis=0)

    # Store output with masking
    tl.store(Out + o_offset + e_idx, o.to(Out.dtype.element_ty), mask=e_mask)


def triton_lightning_attn_decode(q, k, v, kv, s):
    """Triton implementation of Lightning Attention decode operation"""
    b, h, n, d = q.shape
    e = v.shape[-1]
    assert n == 1, "Sequence length must be 1 in decode mode"

    # Get padded dimensions (power of 2)
    d_padded = next_power_of_2(d)
    e_padded = next_power_of_2(e)

    # Create output tensor (padded)
    o_padded = torch.empty(b, h, n, e_padded, dtype=v.dtype, device=v.device)

    # Create padded tensors without actually padding the data
    q_padded = torch.empty(b, h, n, d_padded, dtype=q.dtype, device=q.device)
    k_padded = torch.empty(b, h, n, d_padded, dtype=k.dtype, device=k.device)
    v_padded = torch.empty(b, h, n, e_padded, dtype=v.dtype, device=v.device)
    kv_padded = torch.empty(
        b, h, d_padded, e_padded, dtype=torch.float32, device=kv.device
    )

    # Copy data to padded tensors
    q_padded[..., :d] = q
    k_padded[..., :d] = k
    v_padded[..., :e] = v
    kv_padded[..., :d, :e] = kv

    # Launch kernel
    grid = (b * h, 1)
    _decode_kernel[grid](
        q_padded,
        k_padded,
        v_padded,
        kv_padded,
        o_padded,
        s,
        b=b,
        h=h,
        n=n,
        d=d_padded,
        d_original=d,
        e=e_padded,
        e_original=e,
    )

    # Get unpadded outputs
    o = o_padded[..., :e]
    kv_out = kv_padded[..., :d, :e]

    return o, kv_out


def lightning_attention_decode_naive(q, k, v, past_kv, slope):
    """Naive implementation of lightning attention decode"""
    original_dtype = q.dtype
    ratio = torch.exp(-slope)  # [h, 1, 1]

    kv = past_kv
    b, h, n, d = q.shape

    output = []
    for i in range(n):
        kv = ratio * kv.to(torch.float32) + torch.einsum(
            "... n d, ... n e -> ... d e",
            k[:, :, i : i + 1],
            v[:, :, i : i + 1],
        )
        qkv = torch.einsum(
            "... n e, ... e d -> ... n d",
            q[:, :, i : i + 1].to(torch.float32),
            kv.to(torch.float32),
        )
        output.append(qkv)
    output = torch.concat(output, dim=-2)

    return output.to(original_dtype), kv


def lightning_attention_decode_kernel(q, k, v, past_kv, slope, output, new_kv):
    return lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv)


def calculate_diff(batch_size):
    dtype = torch.bfloat16
    device = torch.device("cuda")
    num_heads = 64
    head_dim = 96
    seq_len = 1

    q = torch.randn(
        batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype
    )
    k = torch.randn(
        batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype
    )
    v = torch.randn(
        batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype
    )
    past_kv = torch.randn(batch_size, num_heads, head_dim, head_dim, device=device)
    slope = torch.randn(num_heads, 1, 1, device=device)

    output_naive, new_kv_naive = lightning_attention_decode_naive(
        q.clone(), k.clone(), v.clone(), past_kv.clone(), slope.clone()
    )

    output_kernel = torch.empty_like(output_naive)
    new_kv_kernel = torch.empty_like(new_kv_naive)
    lightning_attention_decode_kernel(
        q.clone(),
        k.clone(),
        v.clone(),
        past_kv.clone(),
        slope.clone(),
        output_kernel,
        new_kv_kernel,
    )

    output_triton, new_kv_triton = triton_lightning_attn_decode(
        q.clone(), k.clone(), v.clone(), past_kv.clone(), slope.clone()
    )

    if (
        torch.allclose(output_naive, output_kernel, atol=1e-2, rtol=1e-2)
        and torch.allclose(output_naive, output_triton, atol=1e-2, rtol=1e-2)
        and torch.allclose(new_kv_naive, new_kv_kernel, atol=1e-2, rtol=1e-2)
        and torch.allclose(new_kv_naive, new_kv_triton, atol=1e-2, rtol=1e-2)
    ):
        print("✅ All implementations match")
    else:
        print("❌ Implementations differ")


batch_size_range = [i for i in range(1, 65)]  # 1 to 64
configs = [(bs,) for bs in batch_size_range]


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=[list(_) for _ in configs],
        line_arg="provider",
        line_vals=["naive", "kernel", "triton"],
        line_names=["PyTorch Naive", "SGL Kernel", "Triton"],
        styles=[("blue", "-"), ("red", "-"), ("green", "-")],
        ylabel="us",
        plot_name="lightning-attention-decode-performance",
        args={},
    )
)
def benchmark(batch_size, provider):
    dtype = torch.bfloat16
    device = torch.device("cuda")
    num_heads = 64
    head_dim = 96
    seq_len = 1

    q = torch.randn(
        batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype
    )
    k = torch.randn(
        batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype
    )
    v = torch.randn(
        batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype
    )
    past_kv = torch.randn(batch_size, num_heads, head_dim, head_dim, device=device)
    slope = torch.randn(num_heads, 1, 1, device=device)

    quantiles = [0.5, 0.2, 0.8]

    if provider == "naive":
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: lightning_attention_decode_naive(
                q.clone(), k.clone(), v.clone(), past_kv.clone(), slope.clone()
            ),
            quantiles=quantiles,
        )
    elif provider == "kernel":
        output = torch.empty(
            batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype
        )
        new_kv = torch.empty(batch_size, num_heads, head_dim, head_dim, device=device)
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: lightning_attention_decode_kernel(
                q.clone(),
                k.clone(),
                v.clone(),
                past_kv.clone(),
                slope.clone(),
                output,
                new_kv,
            ),
            quantiles=quantiles,
        )
    elif provider == "triton":
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: triton_lightning_attn_decode(
                q.clone(), k.clone(), v.clone(), past_kv.clone(), slope.clone()
            ),
            quantiles=quantiles,
        )

    return 1000 * ms, 1000 * max_ms, 1000 * min_ms


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--save_path",
        type=str,
        default="./configs/benchmark_ops/lightning_attention_decode_sgl/",
        help="Path to save lightning attention decode benchmark results",
    )
    args = parser.parse_args()

    # Run correctness test
    calculate_diff(batch_size=4)

    # Run performance benchmark
    benchmark.run(print_data=True)
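One detail of the Triton path above worth calling out: tl.arange requires power-of-two extents, so the wrapper pads d and e to the next power of two and copies Q/K/V and the KV state into padded buffers before launching _decode_kernel. A standalone check of the helper (a sketch, assuming the 96-dim heads this benchmark configures):

import math

def next_power_of_2(n):
    return 2 ** (int(math.ceil(math.log(n, 2))))

# head_dim = 96 in the benchmark, so the 96-wide Q/K/V rows and the
# 96x96 fp32 KV state are copied into 128-wide padded buffers.
print(next_power_of_2(96))  # 128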
sgl-kernel/setup.py

@@ -100,6 +100,7 @@ ext_modules = [
         "src/sgl-kernel/csrc/moe_align_kernel.cu",
         "src/sgl-kernel/csrc/int8_gemm_kernel.cu",
         "src/sgl-kernel/csrc/sampling_scaling_penalties.cu",
+        "src/sgl-kernel/csrc/lightning_attention_decode_kernel.cu",
         "src/sgl-kernel/csrc/sgl_kernel_ops.cu",
         "src/sgl-kernel/csrc/rotary_embedding.cu",
         "3rdparty/flashinfer/csrc/activation.cu",
sgl-kernel/src/sgl-kernel/__init__.py

@@ -10,6 +10,7 @@ from sgl_kernel.ops import (
     get_graph_buffer_ipc_meta,
     init_custom_reduce,
     int8_scaled_mm,
+    lightning_attention_decode,
     moe_align_block_size,
     register_graph_buffers,
     rmsnorm,
@@ -35,5 +36,6 @@ __all__ = [
     "rmsnorm",
     "rotary_embedding",
     "sampling_scaling_penalties",
+    "lightning_attention_decode",
     "silu_and_mul",
 ]
sgl-kernel/src/sgl-kernel/csrc/lightning_attention_decode_kernel.cu (new file, mode 100644)

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

#include "utils.h"

#define THREADS_PER_BLOCK 128

template <typename T>
__global__ void lightning_attention_decode_kernel(
    const T* __restrict__ q,            // [b, h, 1, d]
    const T* __restrict__ k,            // [b, h, 1, d]
    const T* __restrict__ v,            // [b, h, 1, e]
    const float* __restrict__ past_kv,  // [b, h, d, e]
    const float* __restrict__ slope,    // [h, 1, 1]
    T* __restrict__ output,             // [b, h, 1, e]
    float* __restrict__ new_kv,         // [b, h, d, e]
    const int batch_size, const int num_heads, const int qk_dim, const int v_dim) {
  extern __shared__ char smem[];
  T* q_shared = reinterpret_cast<T*>(smem);
  T* k_shared = reinterpret_cast<T*>(smem + qk_dim * sizeof(T));
  T* v_shared = reinterpret_cast<T*>(smem + 2 * qk_dim * sizeof(T));
  float* new_kv_shared = reinterpret_cast<float*>(smem + (2 * qk_dim + v_dim) * sizeof(T));
  T* output_shared =
      reinterpret_cast<T*>(smem + (2 * qk_dim + v_dim) * sizeof(T) + qk_dim * (v_dim + 1) * sizeof(float));

  const int32_t tid = threadIdx.x;
  const int32_t current_head = blockIdx.x;
  const int32_t b = current_head / num_heads;
  const int32_t h = current_head % num_heads;

  if (b >= batch_size) return;

  const int32_t qk_offset = b * num_heads * qk_dim + h * qk_dim;
  const int32_t v_offset = b * num_heads * v_dim + h * v_dim;
  const int32_t kv_offset = b * num_heads * qk_dim * v_dim + h * qk_dim * v_dim;

  for (int d = tid; d < qk_dim; d += blockDim.x) {
    q_shared[d] = q[qk_offset + d];
    k_shared[d] = k[qk_offset + d];
  }
  for (int e = tid; e < v_dim; e += blockDim.x) {
    v_shared[e] = v[v_offset + e];
  }

  __syncthreads();

  const float ratio = expf(-1.0f * slope[h]);

  for (int d = tid; d < qk_dim; d += blockDim.x) {
    T k_val = k_shared[d];
    for (int e = 0; e < v_dim; ++e) {
      int past_kv_idx = kv_offset + d * v_dim + e;
      T v_val = v_shared[e];
      float new_val = ratio * past_kv[past_kv_idx] + k_val * v_val;
      int shared_idx = d * (v_dim + 1) + e;
      new_kv_shared[shared_idx] = new_val;
    }
  }

  __syncthreads();

  for (int idx = tid; idx < qk_dim * v_dim; idx += blockDim.x) {
    int d = idx / v_dim;
    int e = idx % v_dim;
    int shared_idx = d * (v_dim + 1) + e;
    int global_idx = kv_offset + idx;
    new_kv[global_idx] = new_kv_shared[shared_idx];
  }

  __syncthreads();

  for (int e = tid; e < v_dim; e += blockDim.x) {
    float sum = 0.0f;
    for (int d = 0; d < qk_dim; ++d) {
      int shared_idx = d * (v_dim + 1) + e;
      sum += q_shared[d] * new_kv_shared[shared_idx];
    }
    output_shared[e] = static_cast<T>(sum);
  }

  __syncthreads();

  if (tid == 0) {
    for (int e = 0; e < v_dim; ++e) {
      output[v_offset + e] = output_shared[e];
    }
  }
}

void lightning_attention_decode(const torch::Tensor& q, const torch::Tensor& k, const torch::Tensor& v,
                                const torch::Tensor& past_kv, const torch::Tensor& slope, torch::Tensor output,
                                torch::Tensor new_kv) {
  TORCH_CHECK(q.is_contiguous(), "q must be contiguous");
  TORCH_CHECK(k.is_contiguous(), "k must be contiguous");
  TORCH_CHECK(v.is_contiguous(), "v must be contiguous");
  TORCH_CHECK(past_kv.is_contiguous(), "past_kv must be contiguous");

  auto batch_size = q.size(0);
  auto num_heads = q.size(1);
  auto qk_dim = q.size(3);
  auto v_dim = v.size(3);

  dim3 block(THREADS_PER_BLOCK);
  dim3 grid(batch_size * num_heads);

  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half, at::ScalarType::BFloat16, q.scalar_type(), "lightning_attention_decode_kernel", ([&] {
        size_t smem_size = (2 * qk_dim + 2 * v_dim) * sizeof(scalar_t) + qk_dim * (v_dim + 1) * sizeof(float);
        lightning_attention_decode_kernel<scalar_t><<<grid, block, smem_size, stream>>>(
            q.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), v.data_ptr<scalar_t>(), past_kv.data_ptr<float>(),
            slope.data_ptr<float>(), output.data_ptr<scalar_t>(), new_kv.data_ptr<float>(), batch_size, num_heads,
            qk_dim, v_dim);
      }));
}
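A note on the launch above: each (batch, head) pair gets one block of 128 threads, and the dynamic shared memory request is (2*qk_dim + 2*v_dim)*sizeof(T) + qk_dim*(v_dim + 1)*sizeof(float), with the fp32 KV tile padded by one column per row (stride v_dim + 1, presumably to reduce bank conflicts). A back-of-the-envelope check (my arithmetic, not part of the commit) for the 96/96 head dims used by the MiniMax-Text-01 benchmark with bf16 inputs:

# Hypothetical sanity check of the kernel's dynamic shared-memory request
# for qk_dim = v_dim = 96 and bf16 activations (sizeof(T) = 2).
qk_dim, v_dim = 96, 96
sizeof_T, sizeof_float = 2, 4

smem_size = (2 * qk_dim + 2 * v_dim) * sizeof_T + qk_dim * (v_dim + 1) * sizeof_float
print(smem_size)  # 38016 bytes (~37 KiB), within the default 48 KiB per-block limit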
sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu

@@ -26,6 +26,11 @@ torch::Tensor int8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& ma
                              const torch::Tensor& scales_b, const torch::Dtype& out_dtype,
                              const c10::optional<torch::Tensor>& bias);
 
+// lightning_attention_decode
+void lightning_attention_decode(const torch::Tensor& q, const torch::Tensor& k, const torch::Tensor& v,
+                                const torch::Tensor& past_kv, const torch::Tensor& slope, torch::Tensor output,
+                                torch::Tensor new_kv);
+
 // rotary embedding
 void rotary_embedding(torch::Tensor& positions, torch::Tensor& query, torch::Tensor& key, int64_t head_size,
                       torch::Tensor& cos_sin_cache, bool is_neox);
@@ -69,6 +74,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("sampling_scaling_penalties", &sampling_scaling_penalties, "Sampling scaling penalties (CUDA)");
   // int8_scaled_mm
   m.def("int8_scaled_mm", &int8_scaled_mm, "INT8 scaled matmul (CUDA)");
+  // lightning_attention_decode
+  m.def("lightning_attention_decode", &lightning_attention_decode, "Lightning Attention Decode (CUDA)");
   // rotary embedding
   m.def("rotary_embedding", &rotary_embedding, "Rotary Embedding (CUDA)");
   // rms norm
sgl-kernel/src/sgl-kernel/ops/__init__.py

@@ -14,6 +14,9 @@ from sgl_kernel.ops._kernels import (
 )
 from sgl_kernel.ops._kernels import init_custom_ar as _init_custom_ar
 from sgl_kernel.ops._kernels import int8_scaled_mm as _int8_scaled_mm
+from sgl_kernel.ops._kernels import (
+    lightning_attention_decode as _lightning_attention_decode,
+)
 from sgl_kernel.ops._kernels import moe_align_block_size as _moe_align_block_size
 from sgl_kernel.ops._kernels import register_graph_buffers as _register_graph_buffers
 from sgl_kernel.ops._kernels import rmsnorm as _rmsnorm
@@ -86,6 +89,10 @@ def int8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
     )
 
 
+def lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv):
+    _lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv)
+
+
 def rotary_embedding(positions, query, key, head_size, cos_sin_cache, is_neox):
     return _rotary_embedding(positions, query, key, head_size, cos_sin_cache, is_neox)
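The thin wrapper above is the Python entry point exercised by the benchmark and the test that follows. A minimal call sketch (shapes, dtypes, and the in-place output convention mirror the new test; assumes a CUDA device and an installed sgl-kernel build):

import torch
from sgl_kernel import lightning_attention_decode

b, h, d, e = 2, 8, 64, 64
q = torch.randn(b, h, 1, d, device="cuda", dtype=torch.bfloat16)  # one decode token per sequence
k = torch.randn(b, h, 1, d, device="cuda", dtype=torch.bfloat16)
v = torch.randn(b, h, 1, e, device="cuda", dtype=torch.bfloat16)
past_kv = torch.randn(b, h, d, e, device="cuda")  # fp32 running KV state
slope = torch.randn(h, 1, 1, device="cuda")       # per-head decay slopes

output = torch.empty_like(v)        # [b, h, 1, e], filled by the kernel
new_kv = torch.empty_like(past_kv)  # [b, h, d, e], filled by the kernel
lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv)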
sgl-kernel/tests/test_lightning_attention_decode.py (new file, mode 100644)

import pytest
import torch
from sgl_kernel import lightning_attention_decode


def naive_lightning_attention_decode(q, k, v, past_kv, slope):
    """Naive implementation of lightning attention decode"""
    original_dtype = q.dtype
    ratio = torch.exp(-slope)  # [h, 1, 1]

    kv = past_kv
    b, h, n, d = q.shape

    output = []
    for i in range(n):
        kv = ratio * kv.to(torch.float32) + torch.einsum(
            "... n d, ... n e -> ... d e",
            k[:, :, i : i + 1],
            v[:, :, i : i + 1],
        )
        qkv = torch.einsum(
            "... n e, ... e d -> ... n d",
            q[:, :, i : i + 1].to(torch.float32),
            kv.to(torch.float32),
        )
        output.append(qkv)
    output = torch.concat(output, dim=-2)

    return output.to(original_dtype), kv


configs = [
    # (batch_size, num_heads, dim, embed_dim)
    (1, 8, 64, 64),
    (2, 8, 64, 64),
    (1, 32, 32, 64),
    (2, 32, 32, 64),
    (4, 32, 64, 64),
    (4, 32, 64, 64),
    (16, 64, 96, 96),
    (64, 64, 96, 96),
]

dtypes = [torch.float32, torch.float16, torch.bfloat16]


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize("dtype", dtypes)
@pytest.mark.parametrize("batch_size,num_heads,dim,embed_dim", configs)
def test_lightning_attention_decode(dtype, batch_size, num_heads, dim, embed_dim):
    device = torch.device("cuda")

    q = torch.randn(batch_size, num_heads, 1, dim, device=device, dtype=dtype)
    k = torch.randn(batch_size, num_heads, 1, dim, device=device, dtype=dtype)
    v = torch.randn(batch_size, num_heads, 1, embed_dim, device=device, dtype=dtype)
    past_kv = torch.randn(batch_size, num_heads, dim, embed_dim, device=device)
    slope = torch.randn(num_heads, 1, 1, device=device)

    ref_output, ref_new_kv = naive_lightning_attention_decode(q, k, v, past_kv, slope)

    output = torch.empty_like(ref_output)
    new_kv = torch.empty_like(ref_new_kv)
    lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv)

    rtol = 1e-2
    atol = 1e-2
    torch.testing.assert_close(
        output,
        ref_output,
        rtol=rtol,
        atol=atol,
        msg=f"Output mismatch for batch_size={batch_size}, num_heads={num_heads}, "
        f"dim={dim}, embed_dim={embed_dim}, dtype={dtype}",
    )

    torch.testing.assert_close(
        new_kv,
        ref_new_kv,
        rtol=rtol,
        atol=atol,
        msg=f"New KV mismatch for batch_size={batch_size}, num_heads={num_heads}, "
        f"dim={dim}, embed_dim={embed_dim}, dtype={dtype}",
    )