Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
3a6764a4
Commit
3a6764a4
authored
Aug 20, 2024
by
zhuwenwen
Browse files
fix fa and triton tests
parent
2dbefd03
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
20 additions
and
4 deletions
+20
-4
tests/kernels/test_attention.py
tests/kernels/test_attention.py
+2
-0
vllm/model_executor/layers/ops/rand.py
vllm/model_executor/layers/ops/rand.py
+9
-2
vllm/model_executor/layers/ops/sample.py
vllm/model_executor/layers/ops/sample.py
+9
-2
No files found.
tests/kernels/test_attention.py
View file @
3a6764a4
...
...
@@ -359,6 +359,8 @@ def test_multi_query_kv_attention(
attn_bias
=
attn_bias
,
p
=
0.0
,
scale
=
scale
,
op
=
xops
.
fmha
.
MemoryEfficientAttentionFlashAttentionOp
[
0
]
if
(
is_hip
())
else
None
,
)
output
=
output
.
squeeze
(
0
)
...
...
vllm/model_executor/layers/ops/rand.py
View file @
3a6764a4
...
...
@@ -3,6 +3,7 @@ from typing import Optional, Union
import
torch
import
triton
import
triton.language
as
tl
from
vllm.utils
import
is_hip
def
seeded_uniform
(
...
...
@@ -69,8 +70,14 @@ def seeded_uniform(
# Manual tuning. This seems to give best performance on A100 for
# simple kernels like this.
if
philox_block_size
>=
8192
:
if
is_hip
():
num_warps
=
16
else
:
num_warps
=
32
elif
philox_block_size
>=
4096
:
if
is_hip
():
num_warps
=
8
else
:
num_warps
=
16
elif
philox_block_size
>=
2048
:
num_warps
=
8
...
...
vllm/model_executor/layers/ops/sample.py
View file @
3a6764a4
...
...
@@ -6,6 +6,7 @@ import triton
import
triton.language
as
tl
from
vllm.model_executor.layers.ops.rand
import
seeded_uniform
from
vllm.utils
import
is_hip
_EPS
=
1e-6
...
...
@@ -278,8 +279,14 @@ def _sample(probs: torch.Tensor,
# Manual tuning. This seems to give best performance on A100 for
# simple kernels like this.
if
block_size
>=
8192
:
if
is_hip
():
num_warps
=
16
else
:
num_warps
=
32
elif
block_size
>=
4096
:
if
is_hip
():
num_warps
=
8
else
:
num_warps
=
16
elif
block_size
>=
2048
:
num_warps
=
8
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment