Commit 78e974b2 (unverified), authored Jan 17, 2025 by Xiaoyu Zhang, committed by GitHub on Jan 16, 2025
Parent: bc6915e3

[kernel] MiniMax-Text-01 decode lightning_attn with triton (#2920)

Showing 1 changed file with 492 additions and 0 deletions

benchmark/kernels/minmax-text-01-lighting_attention/benchmark_lighting_attention_decode.py (new file, +492 / -0)
import itertools
import math
import os
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import triton
import triton.language as tl
from einops import rearrange

# NOTE: `do_eval` is referenced by `forward()` below but is not defined anywhere in this
# file; a module-level default is assumed here so the inference/decode path is reachable.
do_eval = os.environ.get("do_eval", "False") == "True"
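
# The Triton kernel below performs one lightning-attention decode step per (batch, head):
# it loads the running d x e state KV, applies the per-head decay ratio = exp(-s), adds the
# rank-1 update k^T v for the single new token, writes the state back in place, and emits
# the output o = q @ KV. In recurrence form:
#     KV_t = exp(-s) * KV_{t-1} + k_t^T v_t
#     o_t  = q_t @ KV_t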
@triton.jit
def _decode_kernel(
    Q,
    K,
    V,
    KV,
    Out,
    S,
    b: tl.constexpr,
    h: tl.constexpr,
    n: tl.constexpr,
    d: tl.constexpr,
    e: tl.constexpr,
):
    off_bh = tl.program_id(0)
    off_h = off_bh % h

    qk_offset = off_bh * n * d
    v_offset = off_bh * n * e
    o_offset = off_bh * n * e
    kv_offset = off_bh * d * e

    s = tl.load(S + off_h)
    ratio = tl.exp(-s)

    d_idx = tl.arange(0, d)
    e_idx = tl.arange(0, e)

    q = tl.load(Q + qk_offset + d_idx)
    k = tl.load(K + qk_offset + d_idx)
    v = tl.load(V + v_offset + e_idx)

    kv = tl.load(KV + kv_offset + d_idx[:, None] * e + e_idx[None, :])

    k_v_prod = k[:, None] * v[None, :]
    kv = ratio * kv + k_v_prod

    tl.store(
        KV + kv_offset + d_idx[:, None] * e + e_idx[None, :],
        kv.to(KV.dtype.element_ty),
    )

    o = tl.sum(q[:, None] * kv, axis=0)
    tl.store(Out + o_offset + e_idx, o.to(Out.dtype.element_ty))
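
# Launch layout: the wrapper below starts one Triton program per (batch, head) pair via
# grid = (b * h, 1), carries the KV state in float32, and slices the padding off the
# outputs before returning them.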

def lightning_attn_decode(q, k, v, kv, s):
    """Triton implementation of Lightning Attention decode operation"""
    b, h, n, d = q.shape
    e = v.shape[-1]
    assert n == 1, "Sequence length must be 1 in decode mode"

    # Pad dimensions to power of 2
    d_padded = next_power_of_2(d)
    e_padded = next_power_of_2(e)

    # Pad inputs
    q_padded = F.pad(q, (0, d_padded - d))
    k_padded = F.pad(k, (0, d_padded - d))
    v_padded = F.pad(v, (0, e_padded - e))
    kv_padded = F.pad(kv, (0, e_padded - e, 0, d_padded - d))

    # Ensure inputs are contiguous
    q_padded = q_padded.contiguous()
    k_padded = k_padded.contiguous()
    v_padded = v_padded.contiguous()
    kv_padded = kv_padded.contiguous().to(torch.float32)
    s = s.contiguous()

    # Create output tensor (padded)
    o_padded = torch.empty(b, h, n, e_padded, dtype=v.dtype, device=v.device)

    # Launch kernel
    grid = (b * h, 1)
    _decode_kernel[grid](
        q_padded,
        k_padded,
        v_padded,
        kv_padded,
        o_padded,
        s,
        b=b,
        h=h,
        n=n,
        d=d_padded,
        e=e_padded,
    )

    # Remove padding
    o = o_padded[..., :e]
    kv_out = kv_padded[..., :d, :e]

    return o, kv_out

def next_power_of_2(n):
    return 2 ** (int(math.ceil(math.log(n, 2))))
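
# Padding note: tl.arange(0, d) in the kernel requires `d` to be a power of two, which is
# why lightning_attn_decode() pads both head dimensions with next_power_of_2 before the
# launch (e.g. head_dim 96 is padded to 128) and strips the padding again afterwards.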

class MiniMaxText01LightningAttention(nn.Module):
    def __init__(self, config=None, layer_idx: Optional[int] = None, **kwargs):
        super().__init__()
        if config is None:
            config = type("Config", (), kwargs)

        bias = False
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads)

        self.out_proj = nn.Linear(
            self.head_dim * self.num_heads, self.hidden_size, bias=bias
        )
        self.act = get_activation_fn(config.hidden_act)
        self.norm = MiniMaxText01RMSNorm(self.head_dim * self.num_heads)

        self.qkv_proj = nn.Linear(
            self.hidden_size, 3 * self.head_dim * self.num_heads, bias=bias
        )
        self.output_gate = nn.Linear(
            self.hidden_size, self.head_dim * self.num_heads, bias=bias
        )

        # for inference only
        self.offset = 0
        self.layer_idx = layer_idx

    def forward(
        self,
        hidden_states,
        attn_mask: Optional[torch.Tensor] = None,  # (b, h, n, m)
        output_attentions: bool = False,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        use_cache: bool = False,
        slope_rate: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        if (not self.training) and (not do_eval):
            return self.inference(
                hidden_states,
                attn_mask,
                output_attentions,
                past_key_value,
                use_cache,
                slope_rate,
            )

    def inference(
        self,
        x,
        attn_mask: Optional[torch.Tensor] = None,  # (b, n)
        output_attentions: bool = False,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        use_cache: bool = False,
        slope_rate: Optional[torch.Tensor] = None,  # (h, 1, 1)
    ):
        # x: b n d
        b, n, d = x.shape

        # linear map
        qkv = self.act(self.qkv_proj(x))
        new_shape = qkv.size()[:-1] + (self.num_heads, -1)
        qkv = qkv.view(*new_shape)
        q, k, v = torch.split(qkv, [self.head_dim] * 3, dim=3)
        q = q.transpose(1, 2)  # [b, n, h, d] -> [b, h, n, d]
        k = k.transpose(1, 2)  # [b, n, h, d] -> [b, h, n, d]
        v = v.transpose(1, 2)  # [b, n, h, d] -> [b, h, n, e]

        self.offset += 1
        ratio = torch.exp(-slope_rate)  # [h, 1, 1]

        # decode mode
        kv = past_key_value  # [b, h, d, e]
        output = []
        for i in range(n):
            # kv: [b, h, d, e]
            # ratio: [h, 1, 1]
            # k: [b, h, n, d]
            # v: [b, h, n, e]
            # k[:, :, i : i + 1]: [b, h, 1, d]
            # v[:, :, i : i + 1]: [b, h, 1, e]
            # ratio * kv: [b, h, d, e]
            # [b, h, d, e] + [b, h, d, e] -> [b, h, d, e]
            kv = ratio * kv + torch.einsum(
                "... n d, ... n e -> ... d e",
                k[:, :, i : i + 1],
                v[:, :, i : i + 1],
            )
            # q[:, :, i : i + 1]: [b, h, 1, d]
            # kv.to(q.dtype): [b, h, d, e]
            # [b, h, 1, d] * [b, h, d, e] -> [b, h, 1, e]
            qkv = torch.einsum(
                "... n e, ... e d -> ... n d", q[:, :, i : i + 1], kv.to(q.dtype)
            )
            output.append(qkv)
        output = torch.concat(output, dim=-2)

        # reshape
        output = rearrange(output, "b h n d -> b n (h d)")
        # normalize
        output = self.norm(output)
        # gate
        output = F.sigmoid(self.output_gate(x)) * output
        # outproj
        output = self.out_proj(output)

        attn_weights = None

        return output, attn_weights, kv
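
# The per-token loop in inference() is the reference recurrence that the Triton kernel
# replaces: for each decoded token it computes kv = exp(-slope_rate) * kv + k_i^T v_i
# followed by q_i @ kv, then applies RMSNorm, the sigmoid output gate, and out_proj.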

def get_activation_fn(activation):
    if activation == "gelu":
        return F.gelu
    elif activation == "relu":
        return F.relu
    elif activation == "elu":
        return F.elu
    elif activation == "sigmoid":
        return F.sigmoid
    elif activation == "exp":

        def f(x):
            with torch.no_grad():
                x_max = torch.max(x, dim=-1, keepdims=True).values
                y = torch.exp(x - x_max)
            return y

        return f
    elif activation == "leak":
        return F.leaky_relu
    elif activation == "1+elu":

        def f(x):
            return 1 + F.elu(x)

        return f
    elif activation == "2+elu":

        def f(x):
            return 2 + F.elu(x)

        return f
    elif activation == "silu" or activation == "swish":
        return F.silu
    elif activation == "sine":
        return torch.sin
    else:
        return lambda x: x
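
# get_activation_fn() maps the config string to a callable; the configs below use "gelu"
# and "silu", and any unrecognized name falls back to the identity function.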

class MiniMaxText01RMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        MiniMaxText01RMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)
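
# MiniMaxText01RMSNorm computes y = weight * x / sqrt(mean(x**2, dim=-1) + eps), doing the
# reduction in float32 and casting the result back to the input dtype.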

def test_lightning_attention_implementations(model_params):
    torch.manual_seed(42)

    batch_size = 64
    seq_len = 1
    dtype = torch.bfloat16
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    hidden_states = torch.randn(
        batch_size, seq_len, model_params["hidden_size"], dtype=dtype, device=device
    )
    attention_mask = torch.ones(batch_size, seq_len, dtype=dtype, device=device)
    slope_rate = _build_slope_tensor(model_params["num_attention_heads"]).to(device)

    model_attn = MiniMaxText01LightningAttention(**model_params).to(dtype).to(device)
    model_attn.eval()

    d = model_params["head_dim"]
    past_kv = torch.randn(
        batch_size,
        model_params["num_attention_heads"],
        d,
        d,
        dtype=dtype,
        device=device,
    )

    with torch.no_grad():
        model_output, _, new_kv = model_attn.inference(
            hidden_states,
            attn_mask=attention_mask,
            slope_rate=slope_rate,
            past_key_value=past_kv,
        )

        qkv = model_attn.act(model_attn.qkv_proj(hidden_states))
        new_shape = qkv.size()[:-1] + (model_attn.num_heads, -1)
        qkv = qkv.view(*new_shape)
        q, k, v = torch.split(qkv, [model_attn.head_dim] * 3, dim=-1)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        triton_output, triton_new_kv = lightning_attn_decode(q, k, v, past_kv, slope_rate)
        triton_output = triton_output.transpose(1, 2).contiguous()
        triton_output = triton_output.view(batch_size, seq_len, -1)
        triton_output = model_attn.norm(triton_output)
        triton_output = torch.sigmoid(model_attn.output_gate(hidden_states)) * triton_output
        triton_output = model_attn.out_proj(triton_output)

    torch.testing.assert_close(
        model_output,
        triton_output,
        rtol=1e-3,
        atol=1e-2,
        msg="Lightning attention implementations produce different output results",
    )
    torch.testing.assert_close(
        new_kv,
        triton_new_kv,
        rtol=1e-3,
        atol=1e-2,
        msg="Lightning attention implementations produce different kv results",
    )
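
# The correctness check reproduces the qkv projection / split / transpose done inside
# inference(), feeds the same past_kv state to lightning_attn_decode(), and compares both
# the final output and the updated KV state with bfloat16-friendly tolerances
# (rtol=1e-3, atol=1e-2).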

def _build_slope_tensor(n_attention_heads: int):
    def get_slopes(n):
        def get_slopes_power_of_2(n):
            start = 2 ** (-(2 ** -(math.log2(n) - 3)))
            ratio = start
            return [start * ratio**i for i in range(n)]

        if math.log2(n).is_integer():
            return get_slopes_power_of_2(n)
        else:
            closest_power_of_2 = 2 ** math.floor(math.log2(n))
            return (
                get_slopes_power_of_2(closest_power_of_2)
                + get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]
            )

    slopes = torch.tensor(get_slopes(n_attention_heads)).reshape(
        n_attention_heads, 1, 1
    )

    return slopes

def get_benchmark():
    batch_size_range = [2**i for i in range(0, 12)]  # max 2048
    seq_length_range = [1]  # decode mode sequence length is fixed to 1
    configs = list(itertools.product(batch_size_range, seq_length_range))

    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=["batch_size", "seq_len"],
            x_vals=[list(_) for _ in configs],
            line_arg="provider",
            line_vals=["Original", "Triton"],
            line_names=[
                "Original PyTorch Implementation",
                "Triton Implementation",
            ],
            styles=[("blue", "-"), ("green", "-")],
            ylabel="us",
            plot_name="lightning-attention-decode-performance",
            args={},
        )
    )
    def benchmark(batch_size, seq_len, provider):
        dtype = torch.bfloat16
        device = torch.device("cuda")

        params = {
            "hidden_size": 6144,
            "num_attention_heads": 64,
            "head_dim": 96,
            "hidden_act": "gelu",
        }

        hidden_states = torch.randn(
            batch_size, seq_len, params["hidden_size"], dtype=dtype, device=device
        )
        attention_mask = torch.ones(batch_size, seq_len, dtype=dtype, device=device)
        slope_rate = _build_slope_tensor(params["num_attention_heads"]).to(device)
        model_attn = MiniMaxText01LightningAttention(**params).to(dtype).to(device)
        model_attn.eval()

        d = params["head_dim"]
        past_kv = torch.randn(
            batch_size,
            params["num_attention_heads"],
            d,
            d,
            dtype=dtype,
            device=device,
        )

        quantiles = [0.5, 0.2, 0.8]
        if provider == "Original":
            ms, min_ms, max_ms = triton.testing.do_bench(
                lambda: model_attn.inference(
                    hidden_states,
                    attn_mask=attention_mask,
                    slope_rate=slope_rate,
                    past_key_value=past_kv,
                ),
                quantiles=quantiles,
            )
        else:

            def run_triton():
                qkv = model_attn.act(model_attn.qkv_proj(hidden_states))
                new_shape = qkv.size()[:-1] + (model_attn.num_heads, -1)
                qkv = qkv.view(*new_shape)
                q, k, v = torch.split(qkv, [model_attn.head_dim] * 3, dim=-1)
                q = q.transpose(1, 2)
                k = k.transpose(1, 2)
                v = v.transpose(1, 2)

                output, new_kv = lightning_attn_decode(q, k, v, past_kv, slope_rate)
                output = output.transpose(1, 2).contiguous()
                output = output.view(batch_size, seq_len, -1)
                output = model_attn.norm(output)
                output = torch.sigmoid(model_attn.output_gate(hidden_states)) * output
                return model_attn.out_proj(output)

            ms, min_ms, max_ms = triton.testing.do_bench(
                run_triton,
                quantiles=quantiles,
            )

        return 1000 * ms, 1000 * max_ms, 1000 * min_ms

    return benchmark
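
# triton.testing.do_bench reports milliseconds at the 0.5 / 0.2 / 0.8 quantiles requested
# above; benchmark() multiplies them by 1000 so the reported numbers (and the plot's
# "us" ylabel) are in microseconds.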

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--save_path",
        type=str,
        default="./configs/benchmark_ops/lightning_attention_decode/",
        help="Path to save lightning attention decode benchmark results",
    )
    args = parser.parse_args()

    params = {
        "hidden_size": 6144,
        "num_attention_heads": 64,
        "head_dim": 96,
        "hidden_act": "silu",
    }

    # Run correctness test first
    # Adapted from https://huggingface.co/MiniMaxAI/MiniMax-Text-01/blob/main/config.json
    test_lightning_attention_implementations(params)

    # Run performance benchmark
    benchmark = get_benchmark()
    benchmark.run(print_data=True, save_path=args.save_path)
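
# Example invocation (assuming a CUDA-capable machine; the path below is the flag's default):
#   python benchmark_lighting_attention_decode.py --save_path ./configs/benchmark_ops/lightning_attention_decode/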