Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
622b7ab9
Unverified
Commit
622b7ab9
authored
Oct 29, 2024
by
wangshuai09
Committed by
GitHub
Oct 29, 2024
Browse files
[Hardware] using current_platform.seed_everything (#9785)
Signed-off-by:
wangshuai09
<
391746016@qq.com
>
parent
09500f7d
Changes
27
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
34 additions
and
37 deletions
+34
-37
tests/kernels/test_prefix_prefill.py
tests/kernels/test_prefix_prefill.py
+4
-3
tests/lora/test_layers.py
tests/lora/test_layers.py
+2
-2
tests/lora/test_punica_sizes.py
tests/lora/test_punica_sizes.py
+5
-5
tests/lora/test_punica_variation.py
tests/lora/test_punica_variation.py
+6
-6
vllm/model_executor/utils.py
vllm/model_executor/utils.py
+1
-2
vllm/platforms/interface.py
vllm/platforms/interface.py
+14
-0
vllm/utils.py
vllm/utils.py
+2
-19
No files found.
tests/kernels/test_prefix_prefill.py
View file @
622b7ab9
...
@@ -9,7 +9,8 @@ from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask
...
@@ -9,7 +9,8 @@ from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask
from
vllm.attention.backends.xformers
import
_make_alibi_bias
from
vllm.attention.backends.xformers
import
_make_alibi_bias
from
vllm.attention.ops.prefix_prefill
import
context_attention_fwd
from
vllm.attention.ops.prefix_prefill
import
context_attention_fwd
from
vllm.utils
import
STR_DTYPE_TO_TORCH_DTYPE
,
seed_everything
from
vllm.platforms
import
current_platform
from
vllm.utils
import
STR_DTYPE_TO_TORCH_DTYPE
NUM_HEADS
=
[
64
]
NUM_HEADS
=
[
64
]
NUM_QUERIES_PER_KV
=
[
1
,
8
,
64
]
NUM_QUERIES_PER_KV
=
[
1
,
8
,
64
]
...
@@ -39,7 +40,7 @@ def test_contexted_kv_attention(
...
@@ -39,7 +40,7 @@ def test_contexted_kv_attention(
kv_cache_dtype
:
str
,
kv_cache_dtype
:
str
,
device
:
str
,
device
:
str
,
)
->
None
:
)
->
None
:
seed_everything
(
0
)
current_platform
.
seed_everything
(
0
)
torch
.
set_default_device
(
device
)
torch
.
set_default_device
(
device
)
# Need this, otherwise when we capture the graph the process
# Need this, otherwise when we capture the graph the process
...
@@ -234,7 +235,7 @@ def test_contexted_kv_attention_alibi(
...
@@ -234,7 +235,7 @@ def test_contexted_kv_attention_alibi(
kv_cache_dtype
:
str
,
kv_cache_dtype
:
str
,
device
:
str
,
device
:
str
,
)
->
None
:
)
->
None
:
seed_everything
(
0
)
current_platform
.
seed_everything
(
0
)
torch
.
set_default_device
(
device
)
torch
.
set_default_device
(
device
)
# Need this, otherwise when we capture the graph the process
# Need this, otherwise when we capture the graph the process
...
...
tests/lora/test_layers.py
View file @
622b7ab9
...
@@ -39,7 +39,7 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
...
@@ -39,7 +39,7 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
VocabParallelEmbedding
,
get_masked_input_and_mask
)
ParallelLMHead
,
VocabParallelEmbedding
,
get_masked_input_and_mask
)
from
vllm.model_executor.utils
import
set_random_seed
from
vllm.model_executor.utils
import
set_random_seed
from
vllm.
util
s
import
seed_everything
from
vllm.
platform
s
import
current_platform
from
.utils
import
DummyLoRAManager
from
.utils
import
DummyLoRAManager
...
@@ -923,7 +923,7 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device,
...
@@ -923,7 +923,7 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device,
seq_len
)
->
None
:
seq_len
)
->
None
:
dtype
=
torch
.
float16
dtype
=
torch
.
float16
seed
=
0
seed
=
0
seed_everything
(
seed
)
current_platform
.
seed_everything
(
seed
)
torch
.
set_default_device
(
device
)
torch
.
set_default_device
(
device
)
punica_wrapper
=
PunicaWrapper
(
8192
,
256
,
device
)
punica_wrapper
=
PunicaWrapper
(
8192
,
256
,
device
)
max_loras
=
8
max_loras
=
8
...
...
tests/lora/test_punica_sizes.py
View file @
622b7ab9
"""
"""
This script is mainly used to test various hidden_sizes. We have collected the
This script is mainly used to test various hidden_sizes. We have collected the
hidden_sizes included in the LoRA models currently supported by vLLM. It tests
hidden_sizes included in the LoRA models currently supported by vLLM. It tests
whether the corresponding Triton kernel can run normally when tensor parallelism
whether the corresponding Triton kernel can run normally when tensor parallelism
is set to [1, 2, 4, 8, 16, 32, 64].
is set to [1, 2, 4, 8, 16, 32, 64].
...
@@ -15,8 +15,8 @@ from vllm.lora.ops.bgmv_shrink import bgmv_shrink
...
@@ -15,8 +15,8 @@ from vllm.lora.ops.bgmv_shrink import bgmv_shrink
from
vllm.lora.ops.sgmv_expand
import
sgmv_expand
from
vllm.lora.ops.sgmv_expand
import
sgmv_expand
from
vllm.lora.ops.sgmv_expand_slice
import
sgmv_expand_slice
from
vllm.lora.ops.sgmv_expand_slice
import
sgmv_expand_slice
from
vllm.lora.ops.sgmv_shrink
import
sgmv_shrink
from
vllm.lora.ops.sgmv_shrink
import
sgmv_shrink
from
vllm.platforms
import
current_platform
from
vllm.triton_utils.libentry
import
LibEntry
from
vllm.triton_utils.libentry
import
LibEntry
from
vllm.utils
import
seed_everything
from
.utils
import
(
generate_data
,
generate_data_for_expand_nslices
,
from
.utils
import
(
generate_data
,
generate_data_for_expand_nslices
,
ref_torch_groupgemm
)
ref_torch_groupgemm
)
...
@@ -146,7 +146,7 @@ def test_punica_sgmv(
...
@@ -146,7 +146,7 @@ def test_punica_sgmv(
device
:
str
,
device
:
str
,
):
):
torch
.
set_default_device
(
device
)
torch
.
set_default_device
(
device
)
seed_everything
(
seed
)
current_platform
.
seed_everything
(
seed
)
seq_length
=
128
seq_length
=
128
(
(
...
@@ -239,7 +239,7 @@ def test_punica_bgmv(
...
@@ -239,7 +239,7 @@ def test_punica_bgmv(
from
vllm.lora.ops.bgmv_shrink
import
_bgmv_shrink_kernel
from
vllm.lora.ops.bgmv_shrink
import
_bgmv_shrink_kernel
torch
.
set_default_device
(
device
)
torch
.
set_default_device
(
device
)
seed_everything
(
seed
)
current_platform
.
seed_everything
(
seed
)
seq_length
=
1
seq_length
=
1
(
(
...
@@ -327,7 +327,7 @@ def test_punica_expand_nslices(
...
@@ -327,7 +327,7 @@ def test_punica_expand_nslices(
from
vllm.lora.ops.bgmv_expand_slice
import
_bgmv_expand_slice_kernel
from
vllm.lora.ops.bgmv_expand_slice
import
_bgmv_expand_slice_kernel
torch
.
set_default_device
(
device
)
torch
.
set_default_device
(
device
)
seed_everything
(
seed
)
current_platform
.
seed_everything
(
seed
)
seq_length
=
128
if
op_type
==
"sgmv"
else
1
seq_length
=
128
if
op_type
==
"sgmv"
else
1
(
(
...
...
tests/lora/test_punica_variation.py
View file @
622b7ab9
"""
"""
This script is mainly used to test whether Triton kernels can run normally
This script is mainly used to test whether Triton kernels can run normally
under different conditions, including various batches, numbers of LoRA, and
under different conditions, including various batches, numbers of LoRA, and
maximum ranks.
maximum ranks.
"""
"""
from
unittest.mock
import
patch
from
unittest.mock
import
patch
...
@@ -14,8 +14,8 @@ from vllm.lora.ops.bgmv_shrink import bgmv_shrink
...
@@ -14,8 +14,8 @@ from vllm.lora.ops.bgmv_shrink import bgmv_shrink
from
vllm.lora.ops.sgmv_expand
import
sgmv_expand
from
vllm.lora.ops.sgmv_expand
import
sgmv_expand
from
vllm.lora.ops.sgmv_expand_slice
import
sgmv_expand_slice
from
vllm.lora.ops.sgmv_expand_slice
import
sgmv_expand_slice
from
vllm.lora.ops.sgmv_shrink
import
sgmv_shrink
from
vllm.lora.ops.sgmv_shrink
import
sgmv_shrink
from
vllm.platforms
import
current_platform
from
vllm.triton_utils.libentry
import
LibEntry
from
vllm.triton_utils.libentry
import
LibEntry
from
vllm.utils
import
seed_everything
from
.utils
import
(
generate_data
,
generate_data_for_expand_nslices
,
from
.utils
import
(
generate_data
,
generate_data_for_expand_nslices
,
ref_torch_groupgemm
)
ref_torch_groupgemm
)
...
@@ -61,7 +61,7 @@ def test_punica_sgmv(
...
@@ -61,7 +61,7 @@ def test_punica_sgmv(
device
:
str
,
device
:
str
,
):
):
torch
.
set_default_device
(
device
)
torch
.
set_default_device
(
device
)
seed_everything
(
seed
)
current_platform
.
seed_everything
(
seed
)
seq_length
=
128
seq_length
=
128
(
(
...
@@ -154,7 +154,7 @@ def test_punica_bgmv(
...
@@ -154,7 +154,7 @@ def test_punica_bgmv(
from
vllm.lora.ops.bgmv_shrink
import
_bgmv_shrink_kernel
from
vllm.lora.ops.bgmv_shrink
import
_bgmv_shrink_kernel
torch
.
set_default_device
(
device
)
torch
.
set_default_device
(
device
)
seed_everything
(
seed
)
current_platform
.
seed_everything
(
seed
)
seq_length
=
1
seq_length
=
1
(
(
...
@@ -242,7 +242,7 @@ def test_punica_expand_nslices(
...
@@ -242,7 +242,7 @@ def test_punica_expand_nslices(
from
vllm.lora.ops.bgmv_expand_slice
import
_bgmv_expand_slice_kernel
from
vllm.lora.ops.bgmv_expand_slice
import
_bgmv_expand_slice_kernel
torch
.
set_default_device
(
device
)
torch
.
set_default_device
(
device
)
seed_everything
(
seed
)
current_platform
.
seed_everything
(
seed
)
seq_length
=
128
if
op_type
==
"sgmv"
else
1
seq_length
=
128
if
op_type
==
"sgmv"
else
1
(
(
...
...
vllm/model_executor/utils.py
View file @
622b7ab9
...
@@ -4,11 +4,10 @@ from typing import Any, Dict, Optional
...
@@ -4,11 +4,10 @@ from typing import Any, Dict, Optional
import
torch
import
torch
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils
import
seed_everything
def
set_random_seed
(
seed
:
int
)
->
None
:
def
set_random_seed
(
seed
:
int
)
->
None
:
seed_everything
(
seed
)
current_platform
.
seed_everything
(
seed
)
def
set_weight_attrs
(
def
set_weight_attrs
(
...
...
vllm/platforms/interface.py
View file @
622b7ab9
import
enum
import
enum
import
random
from
typing
import
NamedTuple
,
Optional
,
Tuple
,
Union
from
typing
import
NamedTuple
,
Optional
,
Tuple
,
Union
import
numpy
as
np
import
torch
import
torch
...
@@ -111,6 +113,18 @@ class Platform:
...
@@ -111,6 +113,18 @@ class Platform:
"""
"""
return
torch
.
inference_mode
(
mode
=
True
)
return
torch
.
inference_mode
(
mode
=
True
)
@
classmethod
def
seed_everything
(
cls
,
seed
:
int
)
->
None
:
"""
Set the seed of each random module.
`torch.manual_seed` will set seed on all devices.
Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20
"""
random
.
seed
(
seed
)
np
.
random
.
seed
(
seed
)
torch
.
manual_seed
(
seed
)
class
UnspecifiedPlatform
(
Platform
):
class
UnspecifiedPlatform
(
Platform
):
_enum
=
PlatformEnum
.
UNSPECIFIED
_enum
=
PlatformEnum
.
UNSPECIFIED
vllm/utils.py
View file @
622b7ab9
...
@@ -7,7 +7,6 @@ import gc
...
@@ -7,7 +7,6 @@ import gc
import
inspect
import
inspect
import
ipaddress
import
ipaddress
import
os
import
os
import
random
import
socket
import
socket
import
subprocess
import
subprocess
import
sys
import
sys
...
@@ -331,22 +330,6 @@ def get_cpu_memory() -> int:
...
@@ -331,22 +330,6 @@ def get_cpu_memory() -> int:
return
psutil
.
virtual_memory
().
total
return
psutil
.
virtual_memory
().
total
def
seed_everything
(
seed
:
int
)
->
None
:
"""
Set the seed of each random module.
Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20
"""
random
.
seed
(
seed
)
np
.
random
.
seed
(
seed
)
if
current_platform
.
is_cuda_alike
():
torch
.
cuda
.
manual_seed_all
(
seed
)
if
current_platform
.
is_xpu
():
torch
.
xpu
.
manual_seed_all
(
seed
)
def
random_uuid
()
->
str
:
def
random_uuid
()
->
str
:
return
str
(
uuid
.
uuid4
().
hex
)
return
str
(
uuid
.
uuid4
().
hex
)
...
@@ -643,7 +626,7 @@ def create_kv_caches_with_random_flash(
...
@@ -643,7 +626,7 @@ def create_kv_caches_with_random_flash(
seed
:
int
=
0
,
seed
:
int
=
0
,
device
:
Optional
[
str
]
=
"cuda"
,
device
:
Optional
[
str
]
=
"cuda"
,
)
->
Tuple
[
List
[
torch
.
Tensor
],
List
[
torch
.
Tensor
]]:
)
->
Tuple
[
List
[
torch
.
Tensor
],
List
[
torch
.
Tensor
]]:
seed_everything
(
seed
)
current_platform
.
seed_everything
(
seed
)
torch_dtype
=
get_kv_cache_torch_dtype
(
cache_dtype
,
model_dtype
)
torch_dtype
=
get_kv_cache_torch_dtype
(
cache_dtype
,
model_dtype
)
key_value_cache_shape
=
(
num_blocks
,
2
,
block_size
,
num_heads
,
head_size
)
key_value_cache_shape
=
(
num_blocks
,
2
,
block_size
,
num_heads
,
head_size
)
...
@@ -685,7 +668,7 @@ def create_kv_caches_with_random(
...
@@ -685,7 +668,7 @@ def create_kv_caches_with_random(
f
"Does not support key cache of type fp8 with head_size
{
head_size
}
"
f
"Does not support key cache of type fp8 with head_size
{
head_size
}
"
)
)
seed_everything
(
seed
)
current_platform
.
seed_everything
(
seed
)
torch_dtype
=
get_kv_cache_torch_dtype
(
cache_dtype
,
model_dtype
)
torch_dtype
=
get_kv_cache_torch_dtype
(
cache_dtype
,
model_dtype
)
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment