Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
1c18cce0
Commit
1c18cce0
authored
Aug 21, 2024
by
zhuwenwen
Browse files
fix tests and update the usage of fa
parent
b40f2ffc
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
29 additions
and
5 deletions
+29
-5
tests/kernels/test_attention.py
tests/kernels/test_attention.py
+2
-0
vllm/model_executor/layers/ops/rand.py
vllm/model_executor/layers/ops/rand.py
+9
-2
vllm/model_executor/layers/ops/sample.py
vllm/model_executor/layers/ops/sample.py
+9
-2
vllm/model_executor/model_loader/utils.py
vllm/model_executor/model_loader/utils.py
+5
-0
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+4
-1
No files found.
tests/kernels/test_attention.py
View file @
1c18cce0
...
@@ -364,6 +364,8 @@ def test_multi_query_kv_attention(
...
@@ -364,6 +364,8 @@ def test_multi_query_kv_attention(
attn_bias
=
attn_bias
,
attn_bias
=
attn_bias
,
p
=
0.0
,
p
=
0.0
,
scale
=
scale
,
scale
=
scale
,
op
=
xops
.
fmha
.
MemoryEfficientAttentionFlashAttentionOp
[
0
]
if
(
is_hip
())
else
None
,
)
)
output
=
output
.
squeeze
(
0
)
output
=
output
.
squeeze
(
0
)
...
...
vllm/model_executor/layers/ops/rand.py
View file @
1c18cce0
...
@@ -3,6 +3,7 @@ from typing import Optional, Union
...
@@ -3,6 +3,7 @@ from typing import Optional, Union
import
torch
import
torch
import
triton
import
triton
import
triton.language
as
tl
import
triton.language
as
tl
from
vllm.utils
import
is_hip
def
seeded_uniform
(
def
seeded_uniform
(
...
@@ -69,8 +70,14 @@ def seeded_uniform(
...
@@ -69,8 +70,14 @@ def seeded_uniform(
# Manual tuning. This seems to give best performance on A100 for
# Manual tuning. This seems to give best performance on A100 for
# simple kernels like this.
# simple kernels like this.
if
philox_block_size
>=
8192
:
if
philox_block_size
>=
8192
:
if
is_hip
():
num_warps
=
16
else
:
num_warps
=
32
num_warps
=
32
elif
philox_block_size
>=
4096
:
elif
philox_block_size
>=
4096
:
if
is_hip
():
num_warps
=
8
else
:
num_warps
=
16
num_warps
=
16
elif
philox_block_size
>=
2048
:
elif
philox_block_size
>=
2048
:
num_warps
=
8
num_warps
=
8
...
...
vllm/model_executor/layers/ops/sample.py
View file @
1c18cce0
...
@@ -6,6 +6,7 @@ import triton.language as tl
...
@@ -6,6 +6,7 @@ import triton.language as tl
from
vllm.model_executor.layers.ops.rand
import
seeded_uniform
from
vllm.model_executor.layers.ops.rand
import
seeded_uniform
from
vllm.triton_utils.sample
import
get_num_triton_sampler_splits
from
vllm.triton_utils.sample
import
get_num_triton_sampler_splits
from
vllm.utils
import
is_hip
_EPS
:
tl
.
constexpr
=
1e-6
_EPS
:
tl
.
constexpr
=
1e-6
...
@@ -266,8 +267,14 @@ def _sample(probs: torch.Tensor,
...
@@ -266,8 +267,14 @@ def _sample(probs: torch.Tensor,
# Manual tuning. This seems to give best performance on A100 for
# Manual tuning. This seems to give best performance on A100 for
# simple kernels like this.
# simple kernels like this.
if
block_size
>=
8192
:
if
block_size
>=
8192
:
if
is_hip
():
num_warps
=
16
else
:
num_warps
=
32
num_warps
=
32
elif
block_size
>=
4096
:
elif
block_size
>=
4096
:
if
is_hip
():
num_warps
=
8
else
:
num_warps
=
16
num_warps
=
16
elif
block_size
>=
2048
:
elif
block_size
>=
2048
:
num_warps
=
8
num_warps
=
8
...
...
vllm/model_executor/model_loader/utils.py
View file @
1c18cce0
...
@@ -23,6 +23,7 @@ def get_model_architecture(
...
@@ -23,6 +23,7 @@ def get_model_architecture(
model_config
:
ModelConfig
)
->
Tuple
[
Type
[
nn
.
Module
],
str
]:
model_config
:
ModelConfig
)
->
Tuple
[
Type
[
nn
.
Module
],
str
]:
architectures
=
getattr
(
model_config
.
hf_config
,
"architectures"
,
[])
architectures
=
getattr
(
model_config
.
hf_config
,
"architectures"
,
[])
support_nn_architectures
=
[
'LlamaForCausalLM'
,
'QWenLMHeadModel'
,
'Qwen2ForCausalLM'
,
'ChatGLMModel'
,
'BaichuanForCausalLM'
]
support_nn_architectures
=
[
'LlamaForCausalLM'
,
'QWenLMHeadModel'
,
'Qwen2ForCausalLM'
,
'ChatGLMModel'
,
'BaichuanForCausalLM'
]
use_triton_fa_architectures
=
[
'DeepseekV2ForCausalLM'
]
if
any
(
arch
in
architectures
for
arch
in
support_nn_architectures
):
if
any
(
arch
in
architectures
for
arch
in
support_nn_architectures
):
if
os
.
getenv
(
'LLAMA_NN'
)
!=
'0'
:
if
os
.
getenv
(
'LLAMA_NN'
)
!=
'0'
:
os
.
environ
[
'LLAMA_NN'
]
=
'1'
os
.
environ
[
'LLAMA_NN'
]
=
'1'
...
@@ -35,6 +36,10 @@ def get_model_architecture(
...
@@ -35,6 +36,10 @@ def get_model_architecture(
os
.
environ
[
'GEMM_PAD'
]
=
'0'
os
.
environ
[
'GEMM_PAD'
]
=
'0'
os
.
environ
[
'FA_PAD'
]
=
'0'
os
.
environ
[
'FA_PAD'
]
=
'0'
if
any
(
arch
in
architectures
for
arch
in
use_triton_fa_architectures
):
os
.
environ
[
'VLLM_USE_TRITON_FLASH_ATTN'
]
=
'1'
os
.
environ
[
'VLLM_USE_FLASH_ATTN_AUTO'
]
=
'0'
# Special handling for quantized Mixtral.
# Special handling for quantized Mixtral.
# FIXME(woosuk): This is a temporary hack.
# FIXME(woosuk): This is a temporary hack.
if
(
model_config
.
quantization
is
not
None
if
(
model_config
.
quantization
is
not
None
...
...
vllm/worker/model_runner.py
View file @
1c18cce0
...
@@ -903,7 +903,10 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
...
@@ -903,7 +903,10 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
import
vllm.envs
as
envs
import
vllm.envs
as
envs
if
envs
.
VLLM_USE_FLASH_ATTN_AUTO
:
if
envs
.
VLLM_USE_FLASH_ATTN_AUTO
:
for
group_id
in
range
(
1
):
for
group_id
in
range
(
1
):
if
max_num_batched_tokens
>=
8000
:
seq_len
=
8000
seq_len
=
8000
else
:
seq_len
=
max_num_batched_tokens
batch_size
+=
seq_len
batch_size
+=
seq_len
seq_data
,
dummy_multi_modal_data
=
INPUT_REGISTRY
\
seq_data
,
dummy_multi_modal_data
=
INPUT_REGISTRY
\
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment