zhaoyu6 / sglang

Commit cd7e32e2 (unverified)
Authored Apr 11, 2025 by fzyzcjy; committed by GitHub on Apr 11, 2025

Optimize attention in llama4 (#5127)
Parent: 88799448
Showing 1 changed file with 18 additions and 12 deletions:

python/sglang/srt/models/llama4.py (+18, -12)
python/sglang/srt/models/llama4.py @ cd7e32e2
@@ -240,9 +240,13 @@ class Llama4Attention(nn.Module):
     def _get_attn_scale(self, positions: torch.Tensor) -> torch.Tensor:
         floor = torch.floor((positions + 1.0) / self.floor_scale)
         attn_scale = torch.log(floor + 1.0) * self.attn_scale + 1.0
         return attn_scale.unsqueeze(-1)
 
+    @torch.compile(dynamic=True, backend=get_compiler_backend())
+    def _mul_attn_scale(self, positions, q):
+        attn_scale = self._get_attn_scale(positions)
+        return (q * attn_scale).to(q.dtype)
+
     def forward(
         self,
         positions: torch.Tensor,
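For reference, `_get_attn_scale` above computes the inference-time temperature-tuning scale described in https://arxiv.org/abs/2501.19399, and the new `_mul_attn_scale` wraps the scale-and-multiply in `torch.compile` so the elementwise ops can be fused. A minimal standalone sketch of the same math (the `FLOOR_SCALE` and `ATTN_SCALE` values are illustrative placeholders rather than values from this commit, and the default compile backend is used instead of sglang's `get_compiler_backend()`):

import torch

# Illustrative hyperparameters; in the model they come from the Llama 4 config
# (self.floor_scale and self.attn_scale).
FLOOR_SCALE = 8192.0
ATTN_SCALE = 0.1


def get_attn_scale(positions: torch.Tensor) -> torch.Tensor:
    # Grows logarithmically with position: ~1.0 for short contexts,
    # noticeably larger only at very long context lengths.
    floor = torch.floor((positions + 1.0) / FLOOR_SCALE)
    return (torch.log(floor + 1.0) * ATTN_SCALE + 1.0).unsqueeze(-1)


@torch.compile(dynamic=True)  # dynamic shapes: batch/sequence sizes vary per step
def mul_attn_scale(positions: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
    return (q * get_attn_scale(positions)).to(q.dtype)


positions = torch.tensor([0.0, 10_000.0, 100_000.0])
q = torch.randn(3, 128)
print(get_attn_scale(positions).squeeze(-1))  # ~[1.00, 1.07, 1.26] with these values
print(mul_attn_scale(positions, q).shape)     # torch.Size([3, 128])

Compiling the whole multiply, rather than only the scale computation, lets the pointwise ops typically fuse into a single kernel under inductor instead of launching one per op.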
@@ -250,27 +254,29 @@ class Llama4Attention(nn.Module):
         forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
-        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        qk, v = qkv.split([self.q_size + self.kv_size, self.kv_size], dim=-1)
 
         if self.rotary_emb is not None:
-            q, k = self.rotary_emb(positions, q, k)
+            q_view, k_view = qk.split([self.q_size, self.kv_size], dim=-1)
+            q_out_unused, k_out_unused = self.rotary_emb(positions, q_view, k_view)
+            assert (q_out_unused is q_view) and (k_out_unused is k_view)
+            del q_view, k_view, q_out_unused, k_out_unused
 
         if self.qk_norm is not None:
             # TODO: support float
-            q = q.reshape(-1, self.head_dim).contiguous().bfloat16()
-            k = k.reshape(-1, self.head_dim).contiguous().bfloat16()
-            q = self.qk_norm(q).to(q.dtype)
-            k = self.qk_norm(k).to(k.dtype)
-            q = q.reshape(-1, self.q_size)
-            k = k.reshape(-1, self.kv_size)
+            # TODO there are still 2 redundant direct_copy_kernel_cuda for this `reshape` and (in attn backend) q.contiguous(), maybe we can fuse them later
+            qk = qk.reshape(-1, self.head_dim).contiguous().bfloat16()
+            qk = self.qk_norm(qk).to(torch.bfloat16)
+            qk = qk.reshape(-1, self.q_size + self.kv_size)
+
+        q, k = qk.split([self.q_size, self.kv_size], dim=-1)
 
         # We are applying temperature tuning (https://arxiv.org/abs/2501.19399) to NoPE layers, where
         # the inference-time temperature tuning function is customized to not affect short context
         # while working at very long context
         # https://arxiv.org/abs/2501.19399
         if self.attn_temperature_tuning and not self.use_rope:
-            attn_scale = self._get_attn_scale(positions)
-            q = (q * attn_scale).to(q.dtype)
+            q = self._mul_attn_scale(positions=positions, q=q)
 
         attn_output = self.attn(q, k, v, forward_batch)
         output, _ = self.o_proj(attn_output)
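The rotary-embedding change in the hunk above leans on `torch.Tensor.split` returning views of the same storage: a rotary implementation that writes its results into those views in place therefore updates the concatenated `qk` buffer directly, which is what the `assert (q_out_unused is q_view) and (k_out_unused is k_view)` line guards. A small sketch of that property (the `rotate_inplace_` helper is a hypothetical stand-in for an in-place rotary embedding, not sglang's `rotary_emb`):

import torch

q_size, kv_size = 8, 4
qk = torch.arange(q_size + kv_size, dtype=torch.float32).unsqueeze(0)

# split() returns views, not copies: mutating them mutates `qk` itself.
q_view, k_view = qk.split([q_size, kv_size], dim=-1)


def rotate_inplace_(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for a rotary embedding that writes in place
    # and hands the same tensor object back.
    x.mul_(-1.0)
    return x


q_out, k_out = rotate_inplace_(q_view), rotate_inplace_(k_view)
assert (q_out is q_view) and (k_out is k_view)  # same objects returned
assert torch.equal(qk, -torch.arange(q_size + kv_size, dtype=torch.float32).unsqueeze(0))
# The concatenated qk tensor already holds the rotated values, so q and k never
# need to be materialized as separate tensors at this point.

The `del` in the commit only drops the now-unneeded view names; the rotated data stays in `qk`.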
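Similarly, because every q and k head shares the same `head_dim`, the QK norm can run once over the concatenated `qk` tensor reshaped to `[-1, head_dim]` instead of twice over separate q and k tensors, as the replaced lines in the second hunk show. A sketch with illustrative sizes (`torch.nn.RMSNorm`, available in PyTorch 2.4+, stands in for the model's `qk_norm` module):

import torch

head_dim, num_q_heads, num_kv_heads, tokens = 128, 8, 2, 4
q_size, kv_size = num_q_heads * head_dim, num_kv_heads * head_dim

rms_norm = torch.nn.RMSNorm(head_dim)  # placeholder for self.qk_norm
qk = torch.randn(tokens, q_size + kv_size)

# One pass: normalize all q and k heads together.
qk_fused = rms_norm(qk.reshape(-1, head_dim)).reshape(tokens, q_size + kv_size)

# Two passes: the pre-change code path, normalizing q and k separately.
q, k = qk.split([q_size, kv_size], dim=-1)
q_ref = rms_norm(q.reshape(-1, head_dim)).reshape(tokens, q_size)
k_ref = rms_norm(k.reshape(-1, head_dim)).reshape(tokens, kv_size)

q_fused, k_fused = qk_fused.split([q_size, kv_size], dim=-1)
assert torch.allclose(q_fused, q_ref) and torch.allclose(k_fused, k_ref)

In the commit this halves the reshape/contiguous/bfloat16 copies and norm calls in that block, from one set for q and one for k down to a single set for the combined tensor.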