Unverified Commit ca21483b authored by Kunshang Ji's avatar Kunshang Ji Committed by GitHub
Browse files

[MISC] fix pin_memory=torch.cuda.is_available(), use is_pin_memory_available (#37415)


Signed-off-by: default avatarKunshang Ji <kunshang.ji@intel.com>
parent da70c87e
...@@ -11,6 +11,7 @@ from transformers import PreTrainedTokenizerBase ...@@ -11,6 +11,7 @@ from transformers import PreTrainedTokenizerBase
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.utils.import_utils import LazyLoader from vllm.utils.import_utils import LazyLoader
from vllm.utils.platform_utils import is_pin_memory_available
from vllm.v1.structured_output.backend_types import ( from vllm.v1.structured_output.backend_types import (
StructuredOutputBackend, StructuredOutputBackend,
StructuredOutputGrammar, StructuredOutputGrammar,
...@@ -138,7 +139,7 @@ class LMFormatEnforcerBackend(StructuredOutputBackend): ...@@ -138,7 +139,7 @@ class LMFormatEnforcerBackend(StructuredOutputBackend):
(max_num_seqs, (self.vocab_size + 31) // 32), (max_num_seqs, (self.vocab_size + 31) // 32),
-1, -1,
dtype=torch.int32, dtype=torch.int32,
pin_memory=torch.cuda.is_available(), pin_memory=is_pin_memory_available(),
) )
def destroy(self): def destroy(self):
......
...@@ -15,6 +15,7 @@ from regex import escape as regex_escape ...@@ -15,6 +15,7 @@ from regex import escape as regex_escape
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.utils.import_utils import LazyLoader from vllm.utils.import_utils import LazyLoader
from vllm.utils.platform_utils import is_pin_memory_available
from vllm.v1.structured_output.backend_types import ( from vllm.v1.structured_output.backend_types import (
StructuredOutputBackend, StructuredOutputBackend,
StructuredOutputGrammar, StructuredOutputGrammar,
...@@ -96,7 +97,7 @@ class OutlinesBackend(StructuredOutputBackend): ...@@ -96,7 +97,7 @@ class OutlinesBackend(StructuredOutputBackend):
(max_num_seqs, (self.vocab_size + 31) // 32), (max_num_seqs, (self.vocab_size + 31) // 32),
-1, -1,
dtype=torch.int32, dtype=torch.int32,
pin_memory=torch.cuda.is_available(), pin_memory=is_pin_memory_available(),
) )
def destroy(self): def destroy(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment