Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
614aa512
Unverified
Commit
614aa512
authored
Jun 30, 2024
by
youkaichao
Committed by
GitHub
Jun 30, 2024
Browse files
[misc][cuda] use nvml to avoid accidentally cuda initialization (#6007)
parent
af9ad46f
Changes
13
Hide whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
86 additions
and
68 deletions
+86
-68
tests/kernels/test_cutlass.py
tests/kernels/test_cutlass.py
+2
-1
tests/quantization/utils.py
tests/quantization/utils.py
+2
-1
vllm/attention/ops/blocksparse_attention/interface.py
vllm/attention/ops/blocksparse_attention/interface.py
+3
-3
vllm/attention/ops/prefix_prefill.py
vllm/attention/ops/prefix_prefill.py
+3
-1
vllm/distributed/device_communicators/custom_all_reduce.py
vllm/distributed/device_communicators/custom_all_reduce.py
+5
-53
vllm/lora/punica.py
vllm/lora/punica.py
+2
-1
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
...ers/quantization/compressed_tensors/compressed_tensors.py
+2
-1
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+2
-2
vllm/model_executor/layers/quantization/gptq_marlin.py
vllm/model_executor/layers/quantization/gptq_marlin.py
+2
-1
vllm/model_executor/layers/quantization/utils/marlin_utils.py
.../model_executor/layers/quantization/utils/marlin_utils.py
+2
-1
vllm/model_executor/model_loader/loader.py
vllm/model_executor/model_loader/loader.py
+2
-2
vllm/utils.py
vllm/utils.py
+57
-0
vllm/worker/worker.py
vllm/worker/worker.py
+2
-1
No files found.
tests/kernels/test_cutlass.py
View file @
614aa512
...
...
@@ -8,12 +8,13 @@ import pytest
import
torch
from
vllm
import
_custom_ops
as
ops
from
vllm.utils
import
get_device_capability_stateless
CUDA_DEVICES
=
[
f
"cuda:
{
i
}
"
for
i
in
range
(
1
if
torch
.
cuda
.
device_count
()
==
1
else
2
)
]
capability
=
torch
.
cuda
.
get_device_capability
()
capability
=
get_device_capability
_stateless
()
capability
=
capability
[
0
]
*
10
+
capability
[
1
]
...
...
tests/quantization/utils.py
View file @
614aa512
import
torch
from
vllm.model_executor.layers.quantization
import
QUANTIZATION_METHODS
from
vllm.utils
import
get_device_capability_stateless
def
is_quant_method_supported
(
quant_method
:
str
)
->
bool
:
...
...
@@ -8,7 +9,7 @@ def is_quant_method_supported(quant_method: str) -> bool:
if
not
torch
.
cuda
.
is_available
():
return
False
capability
=
torch
.
cuda
.
get_device_capability
()
capability
=
get_device_capability
_stateless
()
capability
=
capability
[
0
]
*
10
+
capability
[
1
]
return
(
capability
>=
QUANTIZATION_METHODS
[
quant_method
].
get_min_capability
())
vllm/attention/ops/blocksparse_attention/interface.py
View file @
614aa512
...
...
@@ -2,13 +2,13 @@ import math
import
torch
from
vllm.utils
import
is_cpu
,
is_hip
from
vllm.utils
import
get_device_capability_stateless
,
is_cpu
,
is_hip
from
.utils
import
(
dense_to_crow_col
,
get_head_sliding_step
,
get_sparse_attn_mask
)
IS_COMPUTE_8_OR_ABOVE
=
(
torch
.
cuda
.
is_available
()
and
torch
.
cuda
.
get_device_capability
()[
0
]
>=
8
)
and
get_device_capability
_stateless
()[
0
]
>=
8
)
if
IS_COMPUTE_8_OR_ABOVE
:
from
.blocksparse_attention_kernel
import
blocksparse_flash_attn_varlen_fwd
...
...
@@ -235,4 +235,4 @@ class LocalStridedBlockSparseAttn(torch.nn.Module):
v
,
cu_seqlens_k
,
cu_seqlens_q
=
cu_seqlens_q
,
sm_scale
=
sm_scale
)
\ No newline at end of file
sm_scale
=
sm_scale
)
vllm/attention/ops/prefix_prefill.py
View file @
614aa512
...
...
@@ -5,6 +5,8 @@ import torch
import
triton
import
triton.language
as
tl
from
vllm.utils
import
get_device_capability_stateless
if
triton
.
__version__
>=
"2.1.0"
:
@
triton
.
jit
...
...
@@ -683,7 +685,7 @@ if triton.__version__ >= "2.1.0":
alibi_slopes
=
None
,
sliding_window
=
None
):
cap
=
torch
.
cuda
.
get_device_capability
()
cap
=
get_device_capability
_stateless
()
BLOCK
=
128
if
cap
[
0
]
>=
8
else
64
# shape constraints
Lq
,
Lk
,
Lv
=
q
.
shape
[
-
1
],
k
.
shape
[
-
1
],
v
.
shape
[
-
1
]
...
...
vllm/distributed/device_communicators/custom_all_reduce.py
View file @
614aa512
...
...
@@ -11,66 +11,18 @@ from vllm.distributed.device_communicators.custom_all_reduce_utils import (
gpu_p2p_access_check
)
from
vllm.distributed.parallel_state
import
is_in_the_same_node
from
vllm.logger
import
init_logger
from
vllm.utils
import
cuda_device_count_stateless
from
vllm.utils
import
cuda_device_count_stateless
,
is_full_nvlink
try
:
import
pynvml
# Simulate ImportError if custom_ar ops are not supported.
if
not
ops
.
is_custom_op_supported
(
"_C_custom_ar::meta_size"
):
raise
ImportError
(
"custom_ar"
,
__file__
)
assert
ops
.
is_custom_op_supported
(
"_C_custom_ar::meta_size"
)
custom_ar
=
True
@
contextmanager
def
_nvml
():
try
:
pynvml
.
nvmlInit
()
yield
finally
:
pynvml
.
nvmlShutdown
()
except
ImportError
:
# For AMD GPUs
except
Exception
:
# For AMD GPUs and CPUs
custom_ar
=
False
pynvml
=
None
@
contextmanager
def
_nvml
():
try
:
yield
finally
:
pass
logger
=
init_logger
(
__name__
)
@
_nvml
()
def
_is_full_nvlink
(
device_ids
:
List
[
int
])
->
bool
:
"""
query if the set of gpus are fully connected by nvlink (1 hop)
Note that `pynvml` is not affected by `CUDA_VISIBLE_DEVICES`,
so it works on real physical device ids.
"""
handles
=
[
pynvml
.
nvmlDeviceGetHandleByIndex
(
i
)
for
i
in
device_ids
]
for
i
,
handle
in
enumerate
(
handles
):
for
j
,
peer_handle
in
enumerate
(
handles
):
if
i
<
j
:
try
:
p2p_status
=
pynvml
.
nvmlDeviceGetP2PStatus
(
handle
,
peer_handle
,
pynvml
.
NVML_P2P_CAPS_INDEX_NVLINK
)
if
p2p_status
!=
pynvml
.
NVML_P2P_STATUS_OK
:
return
False
except
pynvml
.
NVMLError
as
error
:
logger
.
error
(
"NVLink detection failed. This is normal if your"
" machine has no NVLink equipped."
,
exc_info
=
error
)
return
False
return
True
def
_can_p2p
(
rank
:
int
,
world_size
:
int
)
->
bool
:
for
i
in
range
(
world_size
):
if
i
==
rank
:
...
...
@@ -161,7 +113,7 @@ class CustomAllreduce:
# test nvlink first, this will filter out most of the cases
# where custom allreduce is not supported
# this checks hardware and driver support for NVLink
full_nvlink
=
_
is_full_nvlink
(
physical_device_ids
)
full_nvlink
=
is_full_nvlink
(
physical_device_ids
)
if
world_size
>
2
and
not
full_nvlink
:
logger
.
warning
(
"Custom allreduce is disabled because it's not supported on"
...
...
vllm/lora/punica.py
View file @
614aa512
...
...
@@ -5,13 +5,14 @@ from typing import Optional
import
torch
from
vllm
import
_custom_ops
as
ops
from
vllm.utils
import
get_device_capability_stateless
def
_check_punica_support
():
if
ops
.
is_custom_op_supported
(
"_punica_C::dispatch_bgmv"
):
return
if
torch
.
cuda
.
get_device_capability
()
<
(
8
,
0
):
if
get_device_capability
_stateless
()
<
(
8
,
0
):
raise
ImportError
(
"punica LoRA kernels require compute capability >= 8.0"
)
else
:
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
View file @
614aa512
...
...
@@ -14,6 +14,7 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
from
vllm.model_executor.layers.quantization.compressed_tensors.utils
import
(
CompressionFormat
,
QuantizationArgs
,
QuantizationStrategy
,
find_first_name_or_class_match
)
from
vllm.utils
import
get_device_capability_stateless
class
CompressedTensorsConfig
(
QuantizationConfig
):
...
...
@@ -84,7 +85,7 @@ class CompressedTensorsConfig(QuantizationConfig):
return
[]
def
_check_gptq_and_marlin_can_run
(
self
):
capability
=
torch
.
cuda
.
get_device_capability
()
capability
=
get_device_capability
_stateless
()
capability
=
capability
[
0
]
*
10
+
capability
[
1
]
if
capability
<
80
:
raise
RuntimeError
(
"The quantization config is not supported for "
,
...
...
vllm/model_executor/layers/quantization/fp8.py
View file @
614aa512
...
...
@@ -10,7 +10,7 @@ from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
,
QuantizeMethodBase
)
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.utils
import
print_warning_once
from
vllm.utils
import
get_device_capability_stateless
,
print_warning_once
ACTIVATION_SCHEMES
=
[
"static"
,
"dynamic"
]
...
...
@@ -18,7 +18,7 @@ logger = init_logger(__name__)
def
cutlass_fp8_supported
()
->
bool
:
capability
=
torch
.
cuda
.
get_device_capability
()
capability
=
get_device_capability
_stateless
()
capability
=
capability
[
0
]
*
10
+
capability
[
1
]
return
ops
.
cutlass_scaled_mm_supports_fp8
(
capability
)
...
...
vllm/model_executor/layers/quantization/gptq_marlin.py
View file @
614aa512
...
...
@@ -11,6 +11,7 @@ from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
set_weight_attrs
)
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.utils
import
get_device_capability_stateless
logger
=
init_logger
(
__name__
)
...
...
@@ -165,7 +166,7 @@ class GPTQMarlinConfig(QuantizationConfig):
return
False
# If the capability of the device is too low, cannot convert.
major
,
minor
=
torch
.
cuda
.
get_device_capability
()
major
,
minor
=
get_device_capability
_stateless
()
device_capability
=
major
*
10
+
minor
if
device_capability
<
cls
.
get_min_capability
():
return
False
...
...
vllm/model_executor/layers/quantization/utils/marlin_utils.py
View file @
614aa512
...
...
@@ -12,8 +12,9 @@ from vllm.model_executor.layers.quantization.utils.marlin_perms import (
marlin_perm
,
marlin_scale_perm
,
marlin_scale_perm_single
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
get_pack_factor
,
quantize_weights
,
sort_weights
)
from
vllm.utils
import
get_device_capability_stateless
__cuda_arch
=
torch
.
cuda
.
get_device_capability
()
__cuda_arch
=
get_device_capability
_stateless
()
MARLIN_TILE
=
16
...
...
vllm/model_executor/model_loader/loader.py
View file @
614aa512
...
...
@@ -35,7 +35,7 @@ from vllm.model_executor.model_loader.weight_utils import (
from
vllm.model_executor.models.interfaces
import
(
supports_lora
,
supports_vision
)
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.utils
import
is_tpu
from
vllm.utils
import
get_device_capability_stateless
,
is_tpu
logger
=
init_logger
(
__name__
)
...
...
@@ -46,7 +46,7 @@ def _get_quantization_config(
"""Get the quantization config."""
if
model_config
.
quantization
is
not
None
:
quant_config
=
get_quant_config
(
model_config
,
load_config
)
capability
=
torch
.
cuda
.
get_device_capability
()
capability
=
get_device_capability
_stateless
()
capability
=
capability
[
0
]
*
10
+
capability
[
1
]
if
capability
<
quant_config
.
get_min_capability
():
raise
ValueError
(
...
...
vllm/utils.py
View file @
614aa512
...
...
@@ -816,6 +816,63 @@ def cuda_device_count_stateless() -> int:
return
_cuda_device_count_stateless
(
envs
.
CUDA_VISIBLE_DEVICES
)
# NVML utils
# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
# all the related functions work on real physical device ids.
# the major benefit of using NVML is that it will not initialize CUDA
try
:
import
pynvml
except
ImportError
:
# For non-NV devices
pynvml
=
None
def
with_nvml_context
(
fn
):
@
wraps
(
fn
)
def
wrapper
(
*
args
,
**
kwargs
):
if
pynvml
is
not
None
:
pynvml
.
nvmlInit
()
try
:
return
fn
(
*
args
,
**
kwargs
)
finally
:
if
pynvml
is
not
None
:
pynvml
.
nvmlShutdown
()
return
wrapper
@
with_nvml_context
def
is_full_nvlink
(
device_ids
:
List
[
int
])
->
bool
:
"""
query if the set of gpus are fully connected by nvlink (1 hop)
"""
handles
=
[
pynvml
.
nvmlDeviceGetHandleByIndex
(
i
)
for
i
in
device_ids
]
for
i
,
handle
in
enumerate
(
handles
):
for
j
,
peer_handle
in
enumerate
(
handles
):
if
i
<
j
:
try
:
p2p_status
=
pynvml
.
nvmlDeviceGetP2PStatus
(
handle
,
peer_handle
,
pynvml
.
NVML_P2P_CAPS_INDEX_NVLINK
)
if
p2p_status
!=
pynvml
.
NVML_P2P_STATUS_OK
:
return
False
except
pynvml
.
NVMLError
as
error
:
logger
.
error
(
"NVLink detection failed. This is normal if your"
" machine has no NVLink equipped."
,
exc_info
=
error
)
return
False
return
True
@
lru_cache
(
maxsize
=
8
)
@
with_nvml_context
def
get_device_capability_stateless
(
device_id
:
int
=
0
)
->
Tuple
[
int
,
int
]:
handle
=
pynvml
.
nvmlDeviceGetHandleByIndex
(
device_id
)
return
pynvml
.
nvmlDeviceGetCudaComputeCapability
(
handle
)
#From: https://stackoverflow.com/a/4104188/2749989
def
run_once
(
f
):
...
...
vllm/worker/worker.py
View file @
614aa512
...
...
@@ -16,6 +16,7 @@ from vllm.lora.request import LoRARequest
from
vllm.model_executor
import
set_random_seed
from
vllm.model_executor.model_loader.tensorizer
import
TensorizerConfig
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.utils
import
get_device_capability_stateless
from
vllm.worker.cache_engine
import
CacheEngine
from
vllm.worker.embedding_model_runner
import
EmbeddingModelRunner
from
vllm.worker.model_runner
import
GPUModelRunnerBase
,
ModelRunner
...
...
@@ -322,7 +323,7 @@ def init_worker_distributed_environment(
def
_check_if_gpu_supports_dtype
(
torch_dtype
:
torch
.
dtype
):
# Check if the GPU supports the dtype.
if
torch_dtype
==
torch
.
bfloat16
:
compute_capability
=
torch
.
cuda
.
get_device_capability
()
compute_capability
=
get_device_capability
_stateless
()
if
compute_capability
[
0
]
<
8
:
gpu_name
=
torch
.
cuda
.
get_device_name
()
raise
ValueError
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment