sglang · Commits · 50f1b6d6

Commit 50f1b6d6 (Unverified)
Authored Jun 23, 2025 by Ke Bao; committed via GitHub on Jun 22, 2025
Remove copy after bmm (#7441)
Parent: 5962e70d

Showing 1 changed file with 18 additions and 4 deletions:
python/sglang/srt/models/deepseek_v2.py (+18, -4)
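In brief: every branch of DeepseekV2AttentionMLA now leaves attn_bmm_output already transposed and flattened to (tokens, num_local_heads * v_head_dim), and the default bmm path preallocates that buffer and writes into a transposed view of it via torch.bmm(..., out=...), so the separate transpose-and-flatten copy that used to follow the bmm is gone. A sketch of the pattern follows the diff.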
python/sglang/srt/models/deepseek_v2.py @ 50f1b6d6

@@ -1084,13 +1084,16 @@ class DeepseekV2AttentionMLA(nn.Module):
                 masked_m,
                 expected_m,
             )
-            attn_bmm_output = attn_bmm_output[:, :expected_m, :]
+            attn_bmm_output = (
+                attn_bmm_output[:, :expected_m, :].transpose(0, 1).flatten(1, 2)
+            )
         elif _is_hip:
             # TODO(haishaw): add bmm_fp8 to ROCm
             attn_bmm_output = torch.bmm(
                 attn_output.to(torch.bfloat16).transpose(0, 1),
                 self.w_vc.to(torch.bfloat16) * self.w_scale,
             )
+            attn_bmm_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)
         elif self.w_vc.dtype == torch.float8_e4m3fn:
             attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
                 attn_output.transpose(0, 1),
@@ -1103,10 +1106,21 @@ class DeepseekV2AttentionMLA(nn.Module):
                 self.w_scale,
                 torch.bfloat16,
             )
+            attn_bmm_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)
         else:
-            attn_bmm_output = torch.bmm(attn_output.transpose(0, 1), self.w_vc)
-            attn_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)
-            output, _ = self.o_proj(attn_output)
+            attn_bmm_output = torch.empty(
+                (attn_output.shape[0], self.num_local_heads * self.v_head_dim),
+                dtype=attn_output.dtype,
+                device=attn_output.device,
+            )
+            torch.bmm(
+                attn_output.transpose(0, 1),
+                self.w_vc,
+                out=attn_bmm_output.view(
+                    -1, self.num_local_heads, self.v_head_dim
+                ).transpose(0, 1),
+            )
+        output, _ = self.o_proj(attn_bmm_output)
         return output
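The default (unquantized) branch is where the copy actually disappears: torch.bmm produces a (heads, tokens, v_head_dim) result, while o_proj wants (tokens, heads * v_head_dim), and flatten(1, 2) on a transposed, hence non-contiguous, bmm result has to materialize a fresh tensor. Preallocating the final buffer and handing bmm a transposed view of it as out= lets the result land in its final layout directly. A minimal self-contained sketch of the pattern, with hypothetical toy shapes standing in for the model's real num_local_heads / v_head_dim:

    import torch

    # Toy stand-in shapes (hypothetical, not the model's real sizes).
    tokens, heads, d_in, d_out = 4, 8, 16, 32

    attn_output = torch.randn(tokens, heads, d_in)  # per-token, per-head attention output
    w_vc = torch.randn(heads, d_in, d_out)          # per-head value projection weights

    # Old pattern: bmm returns (heads, tokens, d_out); the transpose makes it
    # non-contiguous, so flatten(1, 2) must materialize a second tensor -- a copy.
    ref = torch.bmm(attn_output.transpose(0, 1), w_vc).transpose(0, 1).flatten(1, 2)

    # New pattern: preallocate the final (tokens, heads * d_out) buffer and let
    # bmm write straight into a transposed view of it, so no copy follows the bmm.
    attn_bmm_output = torch.empty(
        (tokens, heads * d_out),
        dtype=attn_output.dtype,
        device=attn_output.device,
    )
    torch.bmm(
        attn_output.transpose(0, 1),
        w_vc,
        out=attn_bmm_output.view(-1, heads, d_out).transpose(0, 1),
    )

    assert torch.allclose(ref, attn_bmm_output)  # both layouts hold the same values

Whether the copy is fully elided depends on the backend accepting a strided out= tensor without an internal round-trip; either way the two patterns are numerically equivalent, which the allclose check confirms.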