change / sglang · Commits · ab317936

Unverified commit ab317936, authored Jan 16, 2025 by Xiaoyu Zhang, committed by GitHub on Jan 16, 2025

[kernel] MiniMax-Text-01 prefill lightning_attn with triton (#2911)

Parent: b7f3fec1
Showing 1 changed file with 601 additions and 0 deletions.

benchmark/kernels/minmax-text-01-lighting_attention/benchmark_lighting_attention_prefill.py (new file, 0 → 100644, +601 −0)
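The new file is a self-contained benchmark. It first cross-checks the MiniMax-Text-01 model's block-recurrent PyTorch prefill against a Triton lightning-attention kernel adapted from OpenNLPLab on identical inputs (torch.testing.assert_close with rtol=1e-3, atol=1e-2), then times both implementations across batch sizes 1-64 and sequence lengths 256-4096 with triton.testing.do_bench. Assuming a CUDA GPU with triton and einops installed, it should run directly, e.g.: python benchmark/kernels/minmax-text-01-lighting_attention/benchmark_lighting_attention_prefill.py --save_path ./configs/benchmark_ops/lightning_attention_prefill/
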
import itertools
import logging
import math
import os
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import triton
import triton.language as tl
from einops import rearrange

# Module logger used by get_activation_fn's debug messages below (editor's
# addition: the upstream modeling file defines its own logger, which this
# excerpt referenced without defining).
logger = logging.getLogger(__name__)


# Adapted from https://github.com/OpenNLPLab/lightning-attention/blob/main/lightning_attn/ops/triton/lightning_attn2.py
@triton.jit
def _fwd_kernel(
    Q,
    K,
    V,
    Out,
    S,  # log lambda
    b: tl.constexpr,
    h: tl.constexpr,
    n: tl.constexpr,
    d: tl.constexpr,
    e: tl.constexpr,
    BLOCK: tl.constexpr,
    NUM_BLOCK: tl.constexpr,
    BLOCK_MODEL: tl.constexpr,
):
    ##### get offset
    off_bh = tl.program_id(0)
    off_h = off_bh % h
    off_e = tl.program_id(1)
    qk_offset = off_bh * n * d
    v_offset = off_bh * n * e
    o_offset = off_bh * n * e
    # channel offset
    e_offset = off_e * BLOCK_MODEL

    ##### get block ptr
    Q_block_ptr = Q + qk_offset + tl.arange(0, d)[None, :]
    K_trans_block_ptr = K + qk_offset + tl.arange(0, d)[:, None]
    V_block_ptr = V + v_offset + e_offset + tl.arange(0, BLOCK_MODEL)[None, :]
    O_block_ptr = Out + o_offset + e_offset + tl.arange(0, BLOCK_MODEL)[None, :]
    S_block_ptr = S + off_h
    ##### init diag decay(Lambda); q, k decay; kv
    s = tl.load(S_block_ptr)
    # q, k decay
    off_block = tl.arange(0, BLOCK)
    # Not a bug: this differs slightly from Algorithm 1 in the paper, but is mathematically equivalent
    q_decay = tl.exp(-s.to(tl.float32) * off_block[:, None])
    k_trans_decay = tl.exp(-s.to(tl.float32) * (BLOCK - off_block[None, :]))
    block_decay = tl.exp(-s.to(tl.float32) * BLOCK)
    # diag decay
    index = off_block[:, None] - off_block[None, :]
    s_index = s * index
    s_index = tl.where(index >= 0, -s_index, float("-inf"))
    diag_decay = tl.exp(s_index)
    kv = tl.zeros([d, BLOCK_MODEL], dtype=tl.float32)
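
    # Editor's note on the recurrence computed below: writing a for the row
    # offset and c for the column offset within a BLOCK of size B, each
    # iteration produces
    #   O  = (Q K^T * diag_decay) V + (Q * exp(-s a)) KV
    # and then advances the running key-value state with
    #   KV <- exp(-s B) KV + (K * exp(-s (B - c)))^T V
    # i.e. an exact decayed intra-block term plus a decayed summary of all
    # earlier blocks carried in KV.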

    ##### compute
    for i in range(NUM_BLOCK):
        # load
        q = tl.load(
            Q_block_ptr + off_block[:, None] * d,
            mask=off_block[:, None] < n,
            other=0.0,
        ).to(tl.float32)
        k_trans = tl.load(
            K_trans_block_ptr + off_block[None, :] * d,
            mask=off_block[None, :] < n,
            other=0.0,
        ).to(tl.float32)
        v = tl.load(
            V_block_ptr + off_block[:, None] * e,
            mask=off_block[:, None] < n,
            other=0.0,
        ).to(tl.float32)

        # compute
        qk = tl.dot(q, k_trans) * diag_decay
        o_intra = tl.dot(qk, v)
        o_inter = tl.dot(q, kv) * q_decay
        o = o_intra + o_inter

        # save and update
        tl.store(
            O_block_ptr + off_block[:, None] * e,
            o.to(O_block_ptr.dtype.element_ty),
            mask=off_block[:, None] < n,
        )
        kv = block_decay * kv + tl.dot(k_trans * k_trans_decay, v)
        off_block += BLOCK


def lightning_attn2(q, k, v, s):
    q = q.contiguous()
    k = k.contiguous()
    v = v.contiguous()
    s = s.contiguous()

    b, h, n, d = q.shape
    e = v.shape[-1]

    # Pad d to next power of 2
    d_padded = next_power_of_2(d)
    if d_padded != d:
        q_padded = F.pad(q, (0, d_padded - d))
        k_padded = F.pad(k, (0, d_padded - d))
    else:
        q_padded = q
        k_padded = k

    # Pad e to next power of 2
    e_padded = next_power_of_2(e)
    if e_padded != e:
        v_padded = F.pad(v, (0, e_padded - e))
    else:
        v_padded = v

    o_padded = torch.empty((b, h, n, e_padded), dtype=q.dtype, device=q.device)

    BLOCK = 64
    NUM_BLOCK = triton.cdiv(q.shape[2], BLOCK)
    # parallel over channel
    BLOCK_MODEL = min(triton.next_power_of_2(e_padded), 32)
    grid = (b * h, triton.cdiv(e_padded, BLOCK_MODEL))

    _fwd_kernel[grid](
        q_padded,
        k_padded,
        v_padded,
        o_padded,
        s,
        b,
        h,
        n,
        d_padded,
        e_padded,
        BLOCK=BLOCK,
        NUM_BLOCK=NUM_BLOCK,
        BLOCK_MODEL=BLOCK_MODEL,
    )

    # Remove padding from output
    if e_padded != e:
        o = o_padded[..., :e]
    else:
        o = o_padded

    return o
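

# Editor's sketch (not part of the original OpenNLPLab API): a naive O(n^2)
# pure-PyTorch reference for the same un-normalized decayed linear attention,
# handy for sanity-checking lightning_attn2 on short sequences. The helper name
# _naive_lightning_attn_ref is an illustrative addition.
def _naive_lightning_attn_ref(q, k, v, s):
    # q, k: (b, h, n, d); v: (b, h, n, e); s: (h, 1, 1) per-head decay rates
    b, h, n, d = q.shape
    idx = torch.arange(n, device=q.device)
    diff = (idx[:, None] - idx[None, :]).to(torch.float32)  # (n, n), i - j
    # causal decay mask: exp(-s * (i - j)) for j <= i, zero (via -inf) for j > i
    expnt = -s.to(torch.float32) * diff  # broadcasts to (h, n, n)
    expnt = torch.where(diff >= 0, expnt, torch.full_like(expnt, float("-inf")))
    decay = torch.exp(expnt)
    qk = torch.einsum("bhnd,bhmd->bhnm", q.float(), k.float()) * decay
    return torch.einsum("bhnm,bhme->bhne", qk, v.float()).to(q.dtype)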


def is_support(dim):
    return 16 % dim


def next_power_of_2(n):
    # smallest power of two >= n, e.g. next_power_of_2(96) == 128
    return 2 ** (int(math.ceil(math.log(n, 2))))


def lightning_attn_func(q, k, v, s):
    b, h, n, d = q.shape
    e = v.shape[-1]
    assert is_support(d) and is_support(e)

    # pad v's feature dim to power of 2
    e_pad = next_power_of_2(e)
    need_pad = e_pad != e
    if need_pad:
        v = F.pad(v, (0, e_pad - e))

    if d > 128:
        # split over head
        if 64 % d:
            m = 64
        elif 32 % d:
            m = 32
        elif 16 % d:
            m = 16
        arr = [m * i for i in range(d // m + 1)]
        if arr[-1] != d:
            arr.append(d)
        n = len(arr)
        o = 0
        for i in range(n - 1):
            start = arr[i]
            end = arr[i + 1]
            q1 = q[..., start:end]
            k1 = k[..., start:end]
            o += lightning_attn2(q1, k1, v, s)
    else:
        o = lightning_attn2(q, k, v, s)

    if need_pad:
        o = o[:, :, :, :e]

    return o


debug = eval(os.environ.get("debug", default="False"))
# do_eval is referenced by forward() below; reading it from the environment
# here mirrors debug (editor's assumption, following the upstream
# modeling_minimax_text_01.py, since this excerpt used it without defining it).
do_eval = eval(os.environ.get("do_eval", default="False"))

BLOCK = 256


# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->MiniMaxText01
class MiniMaxText01RMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        MiniMaxText01RMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)


# Copied from https://huggingface.co/MiniMaxAI/MiniMax-Text-01/blob/main/modeling_minimax_text_01.py
def get_activation_fn(activation):
    if debug:
        logger.info(f"activation: {activation}")
    if activation == "gelu":
        return F.gelu
    elif activation == "relu":
        return F.relu
    elif activation == "elu":
        return F.elu
    elif activation == "sigmoid":
        return F.sigmoid
    elif activation == "exp":

        def f(x):
            with torch.no_grad():
                x_max = torch.max(x, dim=-1, keepdims=True).values
            y = torch.exp(x - x_max)
            return y

        return f
    elif activation == "leak":
        return F.leaky_relu
    elif activation == "1+elu":

        def f(x):
            return 1 + F.elu(x)

        return f
    elif activation == "2+elu":

        def f(x):
            return 2 + F.elu(x)

        return f
    elif activation == "silu" or activation == "swish":
        return F.silu
    elif activation == "sine":
        return torch.sin
    else:
        logger.info(f"activation: does not support {activation}, use Identity!!!")
        return lambda x: x


# Copied from https://huggingface.co/MiniMaxAI/MiniMax-Text-01/blob/main/modeling_minimax_text_01.py
class MiniMaxText01LightningAttention(nn.Module):
    def __init__(self, config=None, layer_idx: Optional[int] = None, **kwargs):
        super().__init__()
        if config is None:
            config = type("Config", (), kwargs)
        bias = False
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads)

        self.out_proj = nn.Linear(
            self.head_dim * self.num_heads, self.hidden_size, bias=bias
        )
        self.act = get_activation_fn(config.hidden_act)
        self.norm = MiniMaxText01RMSNorm(self.head_dim * self.num_heads)

        self.qkv_proj = nn.Linear(
            self.hidden_size, 3 * self.head_dim * self.num_heads, bias=bias
        )
        self.output_gate = nn.Linear(
            self.hidden_size, self.head_dim * self.num_heads, bias=bias
        )

        # for inference only
        self.offset = 0
        self.layer_idx = layer_idx

    def forward(
        self,
        hidden_states,
        attn_mask: Optional[torch.Tensor] = None,  # (b, h, n, m)
        output_attentions: bool = False,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        use_cache: bool = False,
        slope_rate: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        if (not self.training) and (not do_eval):
            return self.inference(
                hidden_states,
                attn_mask,
                output_attentions,
                past_key_value,
                use_cache,
                slope_rate,
            )

    def inference(
        self,
        x,
        attn_mask: Optional[torch.Tensor] = None,  # (b, n)
        output_attentions: bool = False,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        use_cache: bool = False,
        slope_rate: Optional[torch.Tensor] = None,  # (h, 1, 1)
    ):
        # x: b n d
        b, n, d = x.shape
        # linear map
        qkv = self.act(self.qkv_proj(x))
        new_shape = qkv.size()[:-1] + (self.num_heads, -1)
        qkv = qkv.view(*new_shape)
        q, k, v = torch.split(qkv, [self.head_dim] * 3, dim=3)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        if past_key_value is None:
            self.offset = q.shape[-2]
        else:
            self.offset += 1

        # to align with metaseq
        ratio = torch.exp(-slope_rate)

        # only used on the first pass (prefill)
        if past_key_value is None:
            slope_rate = slope_rate.to(torch.float32)
            if attn_mask is not None:
                v = v.masked_fill(
                    (1 - attn_mask).unsqueeze(1).unsqueeze(-1).to(torch.bool), 0
                )
            NUM_BLOCK = (n + BLOCK - 1) // BLOCK
            b, h, n, d = q.shape
            e = v.shape[-1]
            # block-local decay factors
            array = torch.arange(BLOCK).to(q) + 1
            q_decay = torch.exp(-slope_rate * array.reshape(-1, 1))
            k_decay = torch.exp(-slope_rate * (BLOCK - array.reshape(-1, 1)))
            index = array[:, None] - array[None, :]
            s_index = slope_rate * index[None, None]
            s_index = torch.where(index >= 0, -s_index, float("-inf"))
            diag_decay = torch.exp(s_index)
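
            # Editor's note: the loop below mirrors _fwd_kernel above in pure
            # PyTorch: an exact decayed intra-block (diagonal) term plus an
            # inter-block term against the running kv state, which is advanced
            # by block_decay at each step.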
            kv = torch.zeros(b, h, d, e).to(torch.float32).to(q.device)
            output = torch.empty((b, h, n, e), dtype=q.dtype, device=q.device)
            for i in range(NUM_BLOCK):
                si = i * BLOCK
                ei = min(si + BLOCK, n)
                m = ei - si

                qi = q[:, :, si:ei].contiguous()
                ki = k[:, :, si:ei].contiguous()
                vi = v[:, :, si:ei].contiguous()
                qkv_none_diag = torch.matmul(qi * q_decay[:, :m], kv).to(torch.float32)

                # diag
                qk = (
                    torch.matmul(qi, ki.transpose(-1, -2)).to(torch.float32)
                    * diag_decay[:, :, :m, :m]
                )
                qkv_diag = torch.matmul(qk, vi.to(torch.float32))
                block_decay = torch.exp(-slope_rate * m)
                output[:, :, si:ei] = qkv_none_diag + qkv_diag
                kv = block_decay * kv + torch.matmul(
                    (ki * k_decay[:, -m:]).transpose(-1, -2).to(vi.dtype), vi
                )
        else:
            kv = past_key_value
            output = []
            for i in range(n):
                kv = ratio * kv + torch.einsum(
                    "... n d, ... n e -> ... d e",
                    k[:, :, i : i + 1],
                    v[:, :, i : i + 1],
                )
                qkv = torch.einsum(
                    "... n e, ... e d -> ... n d", q[:, :, i : i + 1], kv.to(q.dtype)
                )
                output.append(qkv)
            output = torch.concat(output, dim=-2)

        # reshape
        output = rearrange(output, "b h n d -> b n (h d)")
        # normalize
        output = self.norm(output)
        # gate
        output = F.sigmoid(self.output_gate(x)) * output
        # outproj
        output = self.out_proj(output)

        attn_weights = None

        return output, attn_weights, kv


def _build_slope_tensor(n_attention_heads: int):
    def get_slopes(n):
        def get_slopes_power_of_2(n):
            start = 2 ** (-(2 ** -(math.log2(n) - 3)))
            ratio = start
            return [start * ratio**i for i in range(n)]

        # In the paper, we only train models that have 2^a heads for some a.
        # This function has some good properties that only occur when the input
        # is a power of 2. To maintain that even when the number of heads is
        # not a power of 2, we use this workaround.
        if math.log2(n).is_integer():
            return get_slopes_power_of_2(n)
        else:
            closest_power_of_2 = 2 ** math.floor(math.log2(n))
            return (
                get_slopes_power_of_2(closest_power_of_2)
                + get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]
            )

    # h, 1, 1
    slopes = torch.tensor(get_slopes(n_attention_heads)).reshape(
        n_attention_heads, 1, 1
    )

    return slopes
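

# Editor's note: these are the ALiBi-style per-head slopes. For example,
# get_slopes(8) yields [1/2, 1/4, ..., 1/256]; for the 64-head configuration
# used in this benchmark, the slopes are 2 ** (-i / 8) for i = 1..64.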


def test_lightning_attention_implementations(model_params):
    torch.manual_seed(42)

    batch_size = 2
    seq_len = 1024
    dtype = torch.bfloat16
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    hidden_states = torch.randn(
        batch_size,
        seq_len,
        model_params["hidden_size"],
        dtype=dtype,
        device=device,
    )
    attention_mask = torch.ones(batch_size, seq_len, dtype=dtype, device=device)
    slope_rate = _build_slope_tensor(model_params["num_attention_heads"]).to(device)

    model_attn = MiniMaxText01LightningAttention(**model_params).to(dtype).to(device)
    model_attn.eval()

    with torch.no_grad():
        model_output, _, _ = model_attn.inference(
            hidden_states, attn_mask=attention_mask, slope_rate=slope_rate
        )

        qkv = model_attn.act(model_attn.qkv_proj(hidden_states))
        new_shape = qkv.size()[:-1] + (model_attn.num_heads, -1)
        qkv = qkv.view(*new_shape)
        q, k, v = torch.split(qkv, [model_attn.head_dim] * 3, dim=-1)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        lib_output = lightning_attn_func(q, k, v, slope_rate)
        lib_output = lib_output.transpose(1, 2).contiguous()
        lib_output = lib_output.view(batch_size, seq_len, -1)
        lib_output = model_attn.norm(lib_output)
        lib_output = torch.sigmoid(model_attn.output_gate(hidden_states)) * lib_output
        lib_output = model_attn.out_proj(lib_output)

    torch.testing.assert_close(
        model_output,
        lib_output,
        rtol=1e-3,
        atol=1e-2,
        msg="Lightning attention implementations produce different results",
    )


def get_benchmark():
    batch_size_range = [2**i for i in range(0, 7)]  # max 64
    seq_length_range = [256, 512, 1024, 2048, 4096]  # max 4096
    configs = list(itertools.product(batch_size_range, seq_length_range))

    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=["batch_size", "seq_len"],
            x_vals=[list(_) for _ in configs],
            line_arg="provider",
            line_vals=["MiniMax-Text-01", "OpenNLPLab"],
            line_names=[
                "MiniMax-Text-01 Model Implementation",
                "OpenNLPLab Library Implementation",
            ],
            styles=[("blue", "-"), ("green", "-")],
            ylabel="us",
            plot_name="lightning-attention-prefill-performance",
            args={},
        )
    )
    def benchmark(batch_size, seq_len, provider):
        dtype = torch.bfloat16
        device = torch.device("cuda")
        params = {
            "hidden_size": 6144,
            "num_attention_heads": 64,
            "head_dim": 96,
            "hidden_act": "gelu",
        }

        hidden_states = torch.randn(
            batch_size, seq_len, params["hidden_size"], dtype=dtype, device=device
        )
        attention_mask = torch.ones(batch_size, seq_len, dtype=dtype, device=device)
        slope_rate = _build_slope_tensor(params["num_attention_heads"]).to(device)
        model_attn = MiniMaxText01LightningAttention(**params).to(dtype).to(device)
        model_attn.eval()

        quantiles = [0.5, 0.2, 0.8]
        if provider == "MiniMax-Text-01":
            ms, min_ms, max_ms = triton.testing.do_bench(
                lambda: model_attn.inference(
                    hidden_states, attn_mask=attention_mask, slope_rate=slope_rate
                ),
                quantiles=quantiles,
            )
        else:

            def run_lib():
                qkv = model_attn.act(model_attn.qkv_proj(hidden_states))
                new_shape = qkv.size()[:-1] + (model_attn.num_heads, -1)
                qkv = qkv.view(*new_shape)
                q, k, v = torch.split(qkv, [model_attn.head_dim] * 3, dim=-1)
                q = q.transpose(1, 2)
                k = k.transpose(1, 2)
                v = v.transpose(1, 2)

                lib_output = lightning_attn_func(q, k, v, slope_rate)
                lib_output = lib_output.transpose(1, 2).contiguous()
                lib_output = lib_output.view(batch_size, seq_len, -1)
                lib_output = model_attn.norm(lib_output)
                lib_output = (
                    torch.sigmoid(model_attn.output_gate(hidden_states)) * lib_output
                )
                return model_attn.out_proj(lib_output)

            ms, min_ms, max_ms = triton.testing.do_bench(
                run_lib,
                quantiles=quantiles,
            )

        # do_bench reports milliseconds; scale by 1000 to match ylabel="us"
        return 1000 * ms, 1000 * max_ms, 1000 * min_ms

    return benchmark


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--save_path",
        type=str,
        default="./configs/benchmark_ops/lightning_attention_prefill/",
        help="Path to save lightning attention prefill benchmark results",
    )
    args = parser.parse_args()

    # Run correctness test first
    # Adapted from https://huggingface.co/MiniMaxAI/MiniMax-Text-01/blob/main/config.json
    params = {
        "hidden_size": 6144,
        "num_attention_heads": 64,
        "head_dim": 96,
        "hidden_act": "silu",
    }
    test_lightning_attention_implementations(params)

    # Run performance benchmark
    benchmark = get_benchmark()
    benchmark.run(print_data=True, save_path=args.save_path)
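
When run, the script prints a latency table (median with 20th/80th-percentile bounds, in microseconds) for each (batch_size, seq_len) pair, and triton's perf_report additionally saves the data and plot under the --save_path directory.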