sglang commit cd7e32e2 (unverified)
Authored Apr 11, 2025 by fzyzcjy; committed via GitHub on Apr 11, 2025

Optimize attention in llama4 (#5127)

Parent: 88799448
Showing 1 changed file with 18 additions and 12 deletions.

python/sglang/srt/models/llama4.py (+18, -12)
@@ -240,9 +240,13 @@ class Llama4Attention(nn.Module):
     def _get_attn_scale(self, positions: torch.Tensor) -> torch.Tensor:
         floor = torch.floor((positions + 1.0) / self.floor_scale)
         attn_scale = torch.log(floor + 1.0) * self.attn_scale + 1.0
         return attn_scale.unsqueeze(-1)
 
+    @torch.compile(dynamic=True, backend=get_compiler_backend())
+    def _mul_attn_scale(self, positions, q):
+        attn_scale = self._get_attn_scale(positions)
+        return (q * attn_scale).to(q.dtype)
+
     def forward(
         self,
         positions: torch.Tensor,
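For context: `_get_attn_scale` implements the temperature-tuning schedule scale(p) = 1 + attn_scale * log(floor((p + 1) / floor_scale) + 1). The floor term is 0 for every position below floor_scale, so the scale is exactly 1 there and short contexts are untouched, while it grows logarithmically at long range. A quick standalone check; the constants are illustrative, not Llama4's config values:

import torch

floor_scale, attn_scale = 8192.0, 0.1  # illustrative; real values come from the model config

def get_attn_scale(positions: torch.Tensor) -> torch.Tensor:
    floor = torch.floor((positions + 1.0) / floor_scale)
    return torch.log(floor + 1.0) * attn_scale + 1.0

print(get_attn_scale(torch.tensor([0.0, 100.0, 8190.0])))  # tensor([1., 1., 1.])
print(get_attn_scale(torch.tensor([8192.0, 65536.0])))     # roughly [1.069, 1.220]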
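The hunk above is the first half of the optimization: `_mul_attn_scale` moves the scale computation and the multiply under `@torch.compile(dynamic=True, ...)`, so the floor/log/mul/add chain and the final `q * scale` can fuse into a single elementwise kernel instead of several separate launches, with `dynamic=True` avoiding recompiles as the token count varies. A minimal sketch of the same pattern, using the default inductor backend as a stand-in for sglang's `get_compiler_backend()`:

import torch

FLOOR_SCALE, ATTN_SCALE = 8192.0, 0.1  # illustrative constants

@torch.compile(dynamic=True)  # stand-in for backend=get_compiler_backend()
def mul_attn_scale(positions: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
    floor = torch.floor((positions + 1.0) / FLOOR_SCALE)
    scale = torch.log(floor + 1.0) * ATTN_SCALE + 1.0
    return (q * scale.unsqueeze(-1)).to(q.dtype)

q = torch.randn(4, 64)
positions = torch.arange(4, dtype=torch.float32)
print(mul_attn_scale(positions, q).shape)  # torch.Size([4, 64])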
@@ -250,27 +254,29 @@ class Llama4Attention(nn.Module):
         forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
-        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        qk, v = qkv.split([self.q_size + self.kv_size, self.kv_size], dim=-1)
 
         if self.rotary_emb is not None:
-            q, k = self.rotary_emb(positions, q, k)
+            q_view, k_view = qk.split([self.q_size, self.kv_size], dim=-1)
+            q_out_unused, k_out_unused = self.rotary_emb(positions, q_view, k_view)
+            assert (q_out_unused is q_view) and (k_out_unused is k_view)
+            del q_view, k_view, q_out_unused, k_out_unused
 
         if self.qk_norm is not None:
             # TODO: support float
-            q = q.reshape(-1, self.head_dim).contiguous().bfloat16()
-            k = k.reshape(-1, self.head_dim).contiguous().bfloat16()
-            q = self.qk_norm(q).to(q.dtype)
-            k = self.qk_norm(k).to(k.dtype)
-            q = q.reshape(-1, self.q_size)
-            k = k.reshape(-1, self.kv_size)
+            # TODO there are still 2 redundant direct_copy_kernel_cuda for this `reshape` and (in attn backend) q.contiguous(), maybe we can fuse them later
+            qk = qk.reshape(-1, self.head_dim).contiguous().bfloat16()
+            qk = self.qk_norm(qk).to(torch.bfloat16)
+            qk = qk.reshape(-1, self.q_size + self.kv_size)
+
+        q, k = qk.split([self.q_size, self.kv_size], dim=-1)
 
         # We are applying temperature tuning (https://arxiv.org/abs/2501.19399) to NoPE layers, where
         # the inference-time temperature tuning function is customized to not affect short context
         # while working at very long context
         # https://arxiv.org/abs/2501.19399
         if self.attn_temperature_tuning and not self.use_rope:
-            attn_scale = self._get_attn_scale(positions)
-            q = (q * attn_scale).to(q.dtype)
+            q = self._mul_attn_scale(positions=positions, q=q)
 
         attn_output = self.attn(q, k, v, forward_batch)
         output, _ = self.o_proj(attn_output)
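The key idea in this hunk is splitting `qkv` into a fused `qk` plus `v` instead of three tensors. A last-dim `split` of a contiguous tensor returns views, so the rotary embedding, which here writes its results in place (exactly what the new `assert` verifies), rotates `q` and `k` directly inside `qk`; the QK-norm can then run once over the fused tensor. A small self-contained sketch of the view behaviour this relies on, with an in-place `mul_` standing in for the rotary op:

import torch

q_size, kv_size = 8, 4
qkv = torch.randn(2, q_size + 2 * kv_size)

qk, v = qkv.split([q_size + kv_size, kv_size], dim=-1)
q_view, k_view = qk.split([q_size, kv_size], dim=-1)  # views into qk, no copy

# Stand-in for a rotary embedding that mutates its inputs in place and
# returns the tensors it was given, the property the commit's assert checks.
q_out, k_out = q_view.mul_(2.0), k_view.mul_(2.0)
assert (q_out is q_view) and (k_out is k_view)

# The parent tensor already holds the rotated values, so no concatenation
# or extra copy is needed before the fused QK-norm pass.
assert torch.equal(qk[:, :q_size], q_view) and torch.equal(qk[:, q_size:], k_view)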
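With `q` and `k` fused, the QK-norm drops from two kernel invocations to one: the norm acts on `head_dim`-sized vectors, so reshaping the fused tensor to `(-1, head_dim)` normalizes every query head and key head in one call, and a final reshape-and-split recovers `q` and `k`. A sketch of the equivalence, assuming an RMSNorm-style per-head norm (`torch.nn.RMSNorm` stands in for the model's `qk_norm`; sizes are illustrative):

import torch

head_dim, num_q_heads, num_kv_heads, tokens = 16, 4, 2, 3
q_size, kv_size = num_q_heads * head_dim, num_kv_heads * head_dim

norm = torch.nn.RMSNorm(head_dim, dtype=torch.bfloat16)
qk = torch.randn(tokens, q_size + kv_size)

# Old path: two separate norm calls over q and k.
q, k = qk.split([q_size, kv_size], dim=-1)
q_ref = norm(q.reshape(-1, head_dim).bfloat16()).reshape(tokens, q_size)
k_ref = norm(k.reshape(-1, head_dim).bfloat16()).reshape(tokens, kv_size)

# New path: one norm call over the fused tensor, then split afterwards.
qk_out = norm(qk.reshape(-1, head_dim).contiguous().bfloat16())
q_new, k_new = qk_out.reshape(tokens, q_size + kv_size).split([q_size, kv_size], dim=-1)

torch.testing.assert_close(q_new, q_ref)
torch.testing.assert_close(k_new, k_ref)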