Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
622b7ab9
Unverified
Commit
622b7ab9
authored
Oct 29, 2024
by
wangshuai09
Committed by
GitHub
Oct 29, 2024
Browse files
[Hardware] using current_platform.seed_everything (#9785)
Signed-off-by:
wangshuai09
<
391746016@qq.com
>
parent
09500f7d
Changes
27
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
34 additions
and
37 deletions
+34
-37
tests/kernels/test_prefix_prefill.py
tests/kernels/test_prefix_prefill.py
+4
-3
tests/lora/test_layers.py
tests/lora/test_layers.py
+2
-2
tests/lora/test_punica_sizes.py
tests/lora/test_punica_sizes.py
+5
-5
tests/lora/test_punica_variation.py
tests/lora/test_punica_variation.py
+6
-6
vllm/model_executor/utils.py
vllm/model_executor/utils.py
+1
-2
vllm/platforms/interface.py
vllm/platforms/interface.py
+14
-0
vllm/utils.py
vllm/utils.py
+2
-19
No files found.
tests/kernels/test_prefix_prefill.py
View file @
622b7ab9
...
@@ -9,7 +9,8 @@ from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask
...
@@ -9,7 +9,8 @@ from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask
from
vllm.attention.backends.xformers
import
_make_alibi_bias
from
vllm.attention.backends.xformers
import
_make_alibi_bias
from
vllm.attention.ops.prefix_prefill
import
context_attention_fwd
from
vllm.attention.ops.prefix_prefill
import
context_attention_fwd
from
vllm.utils
import
STR_DTYPE_TO_TORCH_DTYPE
,
seed_everything
from
vllm.platforms
import
current_platform
from
vllm.utils
import
STR_DTYPE_TO_TORCH_DTYPE
NUM_HEADS
=
[
64
]
NUM_HEADS
=
[
64
]
NUM_QUERIES_PER_KV
=
[
1
,
8
,
64
]
NUM_QUERIES_PER_KV
=
[
1
,
8
,
64
]
...
@@ -39,7 +40,7 @@ def test_contexted_kv_attention(
...
@@ -39,7 +40,7 @@ def test_contexted_kv_attention(
kv_cache_dtype
:
str
,
kv_cache_dtype
:
str
,
device
:
str
,
device
:
str
,
)
->
None
:
)
->
None
:
seed_everything
(
0
)
current_platform
.
seed_everything
(
0
)
torch
.
set_default_device
(
device
)
torch
.
set_default_device
(
device
)
# Need this, otherwise when we capture the graph the process
# Need this, otherwise when we capture the graph the process
...
@@ -234,7 +235,7 @@ def test_contexted_kv_attention_alibi(
...
@@ -234,7 +235,7 @@ def test_contexted_kv_attention_alibi(
kv_cache_dtype
:
str
,
kv_cache_dtype
:
str
,
device
:
str
,
device
:
str
,
)
->
None
:
)
->
None
:
seed_everything
(
0
)
current_platform
.
seed_everything
(
0
)
torch
.
set_default_device
(
device
)
torch
.
set_default_device
(
device
)
# Need this, otherwise when we capture the graph the process
# Need this, otherwise when we capture the graph the process
...
...
tests/lora/test_layers.py
View file @
622b7ab9
...
@@ -39,7 +39,7 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
...
@@ -39,7 +39,7 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
VocabParallelEmbedding
,
get_masked_input_and_mask
)
ParallelLMHead
,
VocabParallelEmbedding
,
get_masked_input_and_mask
)
from
vllm.model_executor.utils
import
set_random_seed
from
vllm.model_executor.utils
import
set_random_seed
from
vllm.
util
s
import
seed_everything
from
vllm.
platform
s
import
current_platform
from
.utils
import
DummyLoRAManager
from
.utils
import
DummyLoRAManager
...
@@ -923,7 +923,7 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device,
...
@@ -923,7 +923,7 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device,
seq_len
)
->
None
:
seq_len
)
->
None
:
dtype
=
torch
.
float16
dtype
=
torch
.
float16
seed
=
0
seed
=
0
seed_everything
(
seed
)
current_platform
.
seed_everything
(
seed
)
torch
.
set_default_device
(
device
)
torch
.
set_default_device
(
device
)
punica_wrapper
=
PunicaWrapper
(
8192
,
256
,
device
)
punica_wrapper
=
PunicaWrapper
(
8192
,
256
,
device
)
max_loras
=
8
max_loras
=
8
...
...
tests/lora/test_punica_sizes.py
View file @
622b7ab9
...
@@ -15,8 +15,8 @@ from vllm.lora.ops.bgmv_shrink import bgmv_shrink
...
@@ -15,8 +15,8 @@ from vllm.lora.ops.bgmv_shrink import bgmv_shrink
from
vllm.lora.ops.sgmv_expand
import
sgmv_expand
from
vllm.lora.ops.sgmv_expand
import
sgmv_expand
from
vllm.lora.ops.sgmv_expand_slice
import
sgmv_expand_slice
from
vllm.lora.ops.sgmv_expand_slice
import
sgmv_expand_slice
from
vllm.lora.ops.sgmv_shrink
import
sgmv_shrink
from
vllm.lora.ops.sgmv_shrink
import
sgmv_shrink
from
vllm.platforms
import
current_platform
from
vllm.triton_utils.libentry
import
LibEntry
from
vllm.triton_utils.libentry
import
LibEntry
from
vllm.utils
import
seed_everything
from
.utils
import
(
generate_data
,
generate_data_for_expand_nslices
,
from
.utils
import
(
generate_data
,
generate_data_for_expand_nslices
,
ref_torch_groupgemm
)
ref_torch_groupgemm
)
...
@@ -146,7 +146,7 @@ def test_punica_sgmv(
...
@@ -146,7 +146,7 @@ def test_punica_sgmv(
device
:
str
,
device
:
str
,
):
):
torch
.
set_default_device
(
device
)
torch
.
set_default_device
(
device
)
seed_everything
(
seed
)
current_platform
.
seed_everything
(
seed
)
seq_length
=
128
seq_length
=
128
(
(
...
@@ -239,7 +239,7 @@ def test_punica_bgmv(
...
@@ -239,7 +239,7 @@ def test_punica_bgmv(
from
vllm.lora.ops.bgmv_shrink
import
_bgmv_shrink_kernel
from
vllm.lora.ops.bgmv_shrink
import
_bgmv_shrink_kernel
torch
.
set_default_device
(
device
)
torch
.
set_default_device
(
device
)
seed_everything
(
seed
)
current_platform
.
seed_everything
(
seed
)
seq_length
=
1
seq_length
=
1
(
(
...
@@ -327,7 +327,7 @@ def test_punica_expand_nslices(
...
@@ -327,7 +327,7 @@ def test_punica_expand_nslices(
from
vllm.lora.ops.bgmv_expand_slice
import
_bgmv_expand_slice_kernel
from
vllm.lora.ops.bgmv_expand_slice
import
_bgmv_expand_slice_kernel
torch
.
set_default_device
(
device
)
torch
.
set_default_device
(
device
)
seed_everything
(
seed
)
current_platform
.
seed_everything
(
seed
)
seq_length
=
128
if
op_type
==
"sgmv"
else
1
seq_length
=
128
if
op_type
==
"sgmv"
else
1
(
(
...
...
tests/lora/test_punica_variation.py
View file @
622b7ab9
...
@@ -14,8 +14,8 @@ from vllm.lora.ops.bgmv_shrink import bgmv_shrink
...
@@ -14,8 +14,8 @@ from vllm.lora.ops.bgmv_shrink import bgmv_shrink
from
vllm.lora.ops.sgmv_expand
import
sgmv_expand
from
vllm.lora.ops.sgmv_expand
import
sgmv_expand
from
vllm.lora.ops.sgmv_expand_slice
import
sgmv_expand_slice
from
vllm.lora.ops.sgmv_expand_slice
import
sgmv_expand_slice
from
vllm.lora.ops.sgmv_shrink
import
sgmv_shrink
from
vllm.lora.ops.sgmv_shrink
import
sgmv_shrink
from
vllm.platforms
import
current_platform
from
vllm.triton_utils.libentry
import
LibEntry
from
vllm.triton_utils.libentry
import
LibEntry
from
vllm.utils
import
seed_everything
from
.utils
import
(
generate_data
,
generate_data_for_expand_nslices
,
from
.utils
import
(
generate_data
,
generate_data_for_expand_nslices
,
ref_torch_groupgemm
)
ref_torch_groupgemm
)
...
@@ -61,7 +61,7 @@ def test_punica_sgmv(
...
@@ -61,7 +61,7 @@ def test_punica_sgmv(
device
:
str
,
device
:
str
,
):
):
torch
.
set_default_device
(
device
)
torch
.
set_default_device
(
device
)
seed_everything
(
seed
)
current_platform
.
seed_everything
(
seed
)
seq_length
=
128
seq_length
=
128
(
(
...
@@ -154,7 +154,7 @@ def test_punica_bgmv(
...
@@ -154,7 +154,7 @@ def test_punica_bgmv(
from
vllm.lora.ops.bgmv_shrink
import
_bgmv_shrink_kernel
from
vllm.lora.ops.bgmv_shrink
import
_bgmv_shrink_kernel
torch
.
set_default_device
(
device
)
torch
.
set_default_device
(
device
)
seed_everything
(
seed
)
current_platform
.
seed_everything
(
seed
)
seq_length
=
1
seq_length
=
1
(
(
...
@@ -242,7 +242,7 @@ def test_punica_expand_nslices(
...
@@ -242,7 +242,7 @@ def test_punica_expand_nslices(
from
vllm.lora.ops.bgmv_expand_slice
import
_bgmv_expand_slice_kernel
from
vllm.lora.ops.bgmv_expand_slice
import
_bgmv_expand_slice_kernel
torch
.
set_default_device
(
device
)
torch
.
set_default_device
(
device
)
seed_everything
(
seed
)
current_platform
.
seed_everything
(
seed
)
seq_length
=
128
if
op_type
==
"sgmv"
else
1
seq_length
=
128
if
op_type
==
"sgmv"
else
1
(
(
...
...
vllm/model_executor/utils.py
View file @
622b7ab9
...
@@ -4,11 +4,10 @@ from typing import Any, Dict, Optional
...
@@ -4,11 +4,10 @@ from typing import Any, Dict, Optional
import
torch
import
torch
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils
import
seed_everything
def
set_random_seed
(
seed
:
int
)
->
None
:
def
set_random_seed
(
seed
:
int
)
->
None
:
seed_everything
(
seed
)
current_platform
.
seed_everything
(
seed
)
def
set_weight_attrs
(
def
set_weight_attrs
(
...
...
vllm/platforms/interface.py
View file @
622b7ab9
import
enum
import
enum
import
random
from
typing
import
NamedTuple
,
Optional
,
Tuple
,
Union
from
typing
import
NamedTuple
,
Optional
,
Tuple
,
Union
import
numpy
as
np
import
torch
import
torch
...
@@ -111,6 +113,18 @@ class Platform:
...
@@ -111,6 +113,18 @@ class Platform:
"""
"""
return
torch
.
inference_mode
(
mode
=
True
)
return
torch
.
inference_mode
(
mode
=
True
)
@
classmethod
def
seed_everything
(
cls
,
seed
:
int
)
->
None
:
"""
Set the seed of each random module.
`torch.manual_seed` will set seed on all devices.
Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20
"""
random
.
seed
(
seed
)
np
.
random
.
seed
(
seed
)
torch
.
manual_seed
(
seed
)
class
UnspecifiedPlatform
(
Platform
):
class
UnspecifiedPlatform
(
Platform
):
_enum
=
PlatformEnum
.
UNSPECIFIED
_enum
=
PlatformEnum
.
UNSPECIFIED
vllm/utils.py
View file @
622b7ab9
...
@@ -7,7 +7,6 @@ import gc
...
@@ -7,7 +7,6 @@ import gc
import
inspect
import
inspect
import
ipaddress
import
ipaddress
import
os
import
os
import
random
import
socket
import
socket
import
subprocess
import
subprocess
import
sys
import
sys
...
@@ -331,22 +330,6 @@ def get_cpu_memory() -> int:
...
@@ -331,22 +330,6 @@ def get_cpu_memory() -> int:
return
psutil
.
virtual_memory
().
total
return
psutil
.
virtual_memory
().
total
def
seed_everything
(
seed
:
int
)
->
None
:
"""
Set the seed of each random module.
Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20
"""
random
.
seed
(
seed
)
np
.
random
.
seed
(
seed
)
if
current_platform
.
is_cuda_alike
():
torch
.
cuda
.
manual_seed_all
(
seed
)
if
current_platform
.
is_xpu
():
torch
.
xpu
.
manual_seed_all
(
seed
)
def
random_uuid
()
->
str
:
def
random_uuid
()
->
str
:
return
str
(
uuid
.
uuid4
().
hex
)
return
str
(
uuid
.
uuid4
().
hex
)
...
@@ -643,7 +626,7 @@ def create_kv_caches_with_random_flash(
...
@@ -643,7 +626,7 @@ def create_kv_caches_with_random_flash(
seed
:
int
=
0
,
seed
:
int
=
0
,
device
:
Optional
[
str
]
=
"cuda"
,
device
:
Optional
[
str
]
=
"cuda"
,
)
->
Tuple
[
List
[
torch
.
Tensor
],
List
[
torch
.
Tensor
]]:
)
->
Tuple
[
List
[
torch
.
Tensor
],
List
[
torch
.
Tensor
]]:
seed_everything
(
seed
)
current_platform
.
seed_everything
(
seed
)
torch_dtype
=
get_kv_cache_torch_dtype
(
cache_dtype
,
model_dtype
)
torch_dtype
=
get_kv_cache_torch_dtype
(
cache_dtype
,
model_dtype
)
key_value_cache_shape
=
(
num_blocks
,
2
,
block_size
,
num_heads
,
head_size
)
key_value_cache_shape
=
(
num_blocks
,
2
,
block_size
,
num_heads
,
head_size
)
...
@@ -685,7 +668,7 @@ def create_kv_caches_with_random(
...
@@ -685,7 +668,7 @@ def create_kv_caches_with_random(
f
"Does not support key cache of type fp8 with head_size
{
head_size
}
"
f
"Does not support key cache of type fp8 with head_size
{
head_size
}
"
)
)
seed_everything
(
seed
)
current_platform
.
seed_everything
(
seed
)
torch_dtype
=
get_kv_cache_torch_dtype
(
cache_dtype
,
model_dtype
)
torch_dtype
=
get_kv_cache_torch_dtype
(
cache_dtype
,
model_dtype
)
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment