Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
04ff1e43
Unverified
Commit
04ff1e43
authored
Aug 27, 2025
by
Woosuk Kwon
Committed by
GitHub
Aug 27, 2025
Browse files
[Misc] Move CpuGpuBuffer to vllm/v1/utils.py (#23728)
Signed-off-by:
Woosuk Kwon
<
woosuk.kwon@berkeley.edu
>
parent
6578e873
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
33 additions
and
33 deletions
+33
-33
vllm/v1/utils.py
vllm/v1/utils.py
+29
-0
vllm/v1/worker/cpu_model_runner.py
vllm/v1/worker/cpu_model_runner.py
+1
-1
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+3
-3
vllm/v1/worker/utils.py
vllm/v1/worker/utils.py
+0
-29
No files found.
vllm/v1/utils.py
View file @
04ff1e43
...
...
@@ -96,6 +96,35 @@ class ConstantList(Generic[T], Sequence):
return
f
"ConstantList(
{
self
.
_x
}
)"
class
CpuGpuBuffer
:
def
__init__
(
self
,
*
args
,
dtype
:
torch
.
dtype
,
device
:
torch
.
device
,
pin_memory
:
bool
,
):
self
.
cpu
=
torch
.
zeros
(
*
args
,
dtype
=
dtype
,
device
=
"cpu"
,
pin_memory
=
pin_memory
)
self
.
np
=
self
.
cpu
.
numpy
()
self
.
gpu
=
self
.
cpu
.
to
(
device
)
def
copy_to_gpu
(
self
,
n
:
Optional
[
int
]
=
None
)
->
torch
.
Tensor
:
if
n
is
None
:
return
self
.
gpu
.
copy_
(
self
.
cpu
,
non_blocking
=
True
)
return
self
.
gpu
[:
n
].
copy_
(
self
.
cpu
[:
n
],
non_blocking
=
True
)
def
copy_to_cpu
(
self
,
n
:
Optional
[
int
]
=
None
)
->
torch
.
Tensor
:
"""NOTE: Because this method is non-blocking, explicit synchronization
is needed to ensure the data is copied to CPU."""
if
n
is
None
:
return
self
.
cpu
.
copy_
(
self
.
gpu
,
non_blocking
=
True
)
return
self
.
cpu
[:
n
].
copy_
(
self
.
gpu
[:
n
],
non_blocking
=
True
)
def
get_engine_client_zmq_addr
(
local_only
:
bool
,
host
:
str
,
port
:
int
=
0
)
->
str
:
...
...
vllm/v1/worker/cpu_model_runner.py
View file @
04ff1e43
...
...
@@ -10,8 +10,8 @@ from vllm.config import VllmConfig
from
vllm.logger
import
init_logger
from
vllm.model_executor.model_loader
import
get_model
from
vllm.v1.attention.backends.cpu_attn
import
TorchSDPAMetadataBuilderV1
from
vllm.v1.utils
import
CpuGpuBuffer
from
vllm.v1.worker.gpu_model_runner
import
GPUModelRunner
from
vllm.v1.worker.utils
import
CpuGpuBuffer
if
TYPE_CHECKING
:
from
vllm.v1.core.sched.output
import
SchedulerOutput
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
04ff1e43
...
...
@@ -78,14 +78,14 @@ from vllm.v1.spec_decode.eagle import EagleProposer
from
vllm.v1.spec_decode.medusa
import
MedusaProposer
from
vllm.v1.spec_decode.metadata
import
SpecDecodeMetadata
from
vllm.v1.spec_decode.ngram_proposer
import
NgramProposer
from
vllm.v1.utils
import
CpuGpuBuffer
from
vllm.v1.worker.gpu_input_batch
import
CachedRequestState
,
InputBatch
from
vllm.v1.worker.kv_connector_model_runner_mixin
import
(
KVConnectorModelRunnerMixin
,
KVConnectorOutput
)
from
vllm.v1.worker.lora_model_runner_mixin
import
LoRAModelRunnerMixin
from
.utils
import
(
AttentionGroup
,
CpuGpuBuffer
,
MultiModalBudget
,
bind_kv_cache
,
gather_mm_placeholders
,
initialize_kv_cache_for_kv_sharing
,
from
.utils
import
(
AttentionGroup
,
MultiModalBudget
,
bind_kv_cache
,
gather_mm_placeholders
,
initialize_kv_cache_for_kv_sharing
,
sanity_check_mm_encoder_outputs
,
scatter_mm_placeholders
)
if
TYPE_CHECKING
:
...
...
vllm/v1/worker/utils.py
View file @
04ff1e43
...
...
@@ -303,32 +303,3 @@ def bind_kv_cache(
for
layer_name
,
kv_cache
in
kv_caches
.
items
():
# NOTE: Use list because of v0 PP virtual engine.
forward_context
[
layer_name
].
kv_cache
=
[
kv_cache
]
class
CpuGpuBuffer
:
def
__init__
(
self
,
*
args
,
dtype
:
torch
.
dtype
,
device
:
torch
.
device
,
pin_memory
:
bool
,
):
self
.
cpu
=
torch
.
zeros
(
*
args
,
dtype
=
dtype
,
device
=
"cpu"
,
pin_memory
=
pin_memory
)
self
.
np
=
self
.
cpu
.
numpy
()
self
.
gpu
=
self
.
cpu
.
to
(
device
)
def
copy_to_gpu
(
self
,
n
:
Optional
[
int
]
=
None
)
->
None
:
if
n
is
None
:
return
self
.
gpu
.
copy_
(
self
.
cpu
,
non_blocking
=
True
)
return
self
.
gpu
[:
n
].
copy_
(
self
.
cpu
[:
n
],
non_blocking
=
True
)
def
copy_to_cpu
(
self
,
n
:
Optional
[
int
]
=
None
)
->
None
:
"""NOTE: Because this method is non-blocking, explicit synchronization
is needed to ensure the data is copied to CPU."""
if
n
is
None
:
return
self
.
cpu
.
copy_
(
self
.
gpu
,
non_blocking
=
True
)
return
self
.
cpu
[:
n
].
copy_
(
self
.
gpu
[:
n
],
non_blocking
=
True
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment