Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
67c4637c
Unverified
Commit
67c4637c
authored
Feb 09, 2025
by
Nick Hill
Committed by
GitHub
Feb 10, 2025
Browse files
[V1] Use msgpack for core request serialization (#12918)
Signed-off-by:
Nick Hill
<
nhill@redhat.com
>
parent
aa0ca5eb
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
62 additions
and
95 deletions
+62
-95
vllm/v1/engine/__init__.py
vllm/v1/engine/__init__.py
+14
-28
vllm/v1/engine/core.py
vllm/v1/engine/core.py
+26
-35
vllm/v1/engine/core_client.py
vllm/v1/engine/core_client.py
+11
-16
vllm/v1/serial_utils.py
vllm/v1/serial_utils.py
+11
-16
No files found.
vllm/v1/engine/__init__.py
View file @
67c4637c
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
import
enum
import
enum
from
dataclasses
import
dataclass
from
typing
import
List
,
Optional
,
Union
from
typing
import
TYPE_CHECKING
,
List
,
Optional
,
Union
import
msgspec
import
msgspec
from
vllm.lora.request
import
LoRARequest
from
vllm.multimodal
import
MultiModalKwargs
from
vllm.multimodal.inputs
import
PlaceholderRange
from
vllm.sampling_params
import
SamplingParams
from
vllm.v1.metrics.stats
import
SchedulerStats
from
vllm.v1.metrics.stats
import
SchedulerStats
from
vllm.v1.outputs
import
LogprobsLists
,
LogprobsTensors
from
vllm.v1.outputs
import
LogprobsLists
,
LogprobsTensors
if
TYPE_CHECKING
:
from
vllm.lora.request
import
LoRARequest
from
vllm.multimodal
import
MultiModalKwargs
from
vllm.multimodal.inputs
import
PlaceholderRange
from
vllm.sampling_params
import
SamplingParams
# These are possible values of RequestOutput.finish_reason,
# These are possible values of RequestOutput.finish_reason,
# so form part of the external API.
# so form part of the external API.
FINISH_REASON_STRINGS
=
(
"stop"
,
"length"
,
"abort"
)
FINISH_REASON_STRINGS
=
(
"stop"
,
"length"
,
"abort"
)
...
@@ -39,8 +36,11 @@ class FinishReason(enum.IntEnum):
...
@@ -39,8 +36,11 @@ class FinishReason(enum.IntEnum):
return
FINISH_REASON_STRINGS
[
self
.
value
]
return
FINISH_REASON_STRINGS
[
self
.
value
]
@
dataclass
class
EngineCoreRequest
(
class
EngineCoreRequest
:
msgspec
.
Struct
,
array_like
=
True
,
# type: ignore[call-arg]
omit_defaults
=
True
,
# type: ignore[call-arg]
gc
=
False
):
# type: ignore[call-arg]
# NOTE: prompt and prompt_token_ids should be DecoderOnlyInput,
# NOTE: prompt and prompt_token_ids should be DecoderOnlyInput,
# but this object is currently not playing well with msgspec
# but this object is currently not playing well with msgspec
...
@@ -51,13 +51,13 @@ class EngineCoreRequest:
...
@@ -51,13 +51,13 @@ class EngineCoreRequest:
# Detokenizer, but set to None when it is added to EngineCoreClient.
# Detokenizer, but set to None when it is added to EngineCoreClient.
prompt
:
Optional
[
str
]
prompt
:
Optional
[
str
]
prompt_token_ids
:
List
[
int
]
prompt_token_ids
:
List
[
int
]
mm_inputs
:
Optional
[
List
[
Optional
[
"
MultiModalKwargs
"
]]]
mm_inputs
:
Optional
[
List
[
Optional
[
MultiModalKwargs
]]]
mm_hashes
:
Optional
[
List
[
str
]]
mm_hashes
:
Optional
[
List
[
str
]]
mm_placeholders
:
Optional
[
List
[
"
PlaceholderRange
"
]]
mm_placeholders
:
Optional
[
List
[
PlaceholderRange
]]
sampling_params
:
"
SamplingParams
"
sampling_params
:
SamplingParams
eos_token_id
:
Optional
[
int
]
eos_token_id
:
Optional
[
int
]
arrival_time
:
float
arrival_time
:
float
lora_request
:
Optional
[
"
LoRARequest
"
]
lora_request
:
Optional
[
LoRARequest
]
class
EngineCoreOutput
(
class
EngineCoreOutput
(
...
@@ -94,16 +94,6 @@ class EngineCoreOutputs(
...
@@ -94,16 +94,6 @@ class EngineCoreOutputs(
scheduler_stats
:
SchedulerStats
scheduler_stats
:
SchedulerStats
@
dataclass
class
EngineCoreProfile
:
is_start
:
bool
@
dataclass
class
EngineCoreResetPrefixCache
:
pass
class
EngineCoreRequestType
(
enum
.
Enum
):
class
EngineCoreRequestType
(
enum
.
Enum
):
"""
"""
Request types defined as hex byte strings, so it can be sent over sockets
Request types defined as hex byte strings, so it can be sent over sockets
...
@@ -113,7 +103,3 @@ class EngineCoreRequestType(enum.Enum):
...
@@ -113,7 +103,3 @@ class EngineCoreRequestType(enum.Enum):
ABORT
=
b
'
\x01
'
ABORT
=
b
'
\x01
'
PROFILE
=
b
'
\x02
'
PROFILE
=
b
'
\x02
'
RESET_PREFIX_CACHE
=
b
'
\x03
'
RESET_PREFIX_CACHE
=
b
'
\x03
'
EngineCoreRequestUnion
=
Union
[
EngineCoreRequest
,
EngineCoreProfile
,
EngineCoreResetPrefixCache
,
List
[
str
]]
vllm/v1/engine/core.py
View file @
67c4637c
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
import
pickle
import
queue
import
queue
import
signal
import
signal
import
threading
import
threading
import
time
import
time
from
multiprocessing.connection
import
Connection
from
multiprocessing.connection
import
Connection
from
typing
import
List
,
Tuple
,
Type
from
typing
import
Any
,
List
,
Tuple
,
Type
import
psutil
import
psutil
import
zmq
import
zmq
...
@@ -19,13 +18,12 @@ from vllm.transformers_utils.config import (
...
@@ -19,13 +18,12 @@ from vllm.transformers_utils.config import (
from
vllm.utils
import
get_exception_traceback
,
zmq_socket_ctx
from
vllm.utils
import
get_exception_traceback
,
zmq_socket_ctx
from
vllm.v1.core.kv_cache_utils
import
get_kv_cache_config
from
vllm.v1.core.kv_cache_utils
import
get_kv_cache_config
from
vllm.v1.core.scheduler
import
Scheduler
from
vllm.v1.core.scheduler
import
Scheduler
from
vllm.v1.engine
import
(
EngineCoreOutputs
,
EngineCoreProfile
,
from
vllm.v1.engine
import
(
EngineCoreOutputs
,
EngineCoreRequest
,
EngineCoreRequest
,
EngineCoreRequestType
,
EngineCoreRequestType
)
EngineCoreRequestUnion
,
EngineCoreResetPrefixCache
)
from
vllm.v1.engine.mm_input_mapper
import
MMInputMapperServer
from
vllm.v1.engine.mm_input_mapper
import
MMInputMapperServer
from
vllm.v1.executor.abstract
import
Executor
from
vllm.v1.executor.abstract
import
Executor
from
vllm.v1.request
import
Request
,
RequestStatus
from
vllm.v1.request
import
Request
,
RequestStatus
from
vllm.v1.serial_utils
import
Msgpack
En
coder
,
Pickle
Encoder
from
vllm.v1.serial_utils
import
Msgpack
De
coder
,
Msgpack
Encoder
from
vllm.version
import
__version__
as
VLLM_VERSION
from
vllm.version
import
__version__
as
VLLM_VERSION
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -161,7 +159,8 @@ class EngineCoreProc(EngineCore):
...
@@ -161,7 +159,8 @@ class EngineCoreProc(EngineCore):
# and to overlap some serialization/deserialization with the
# and to overlap some serialization/deserialization with the
# model forward pass.
# model forward pass.
# Threads handle Socket <-> Queues and core_busy_loop uses Queue.
# Threads handle Socket <-> Queues and core_busy_loop uses Queue.
self
.
input_queue
:
queue
.
Queue
[
EngineCoreRequestUnion
]
=
queue
.
Queue
()
self
.
input_queue
:
queue
.
Queue
[
Tuple
[
EngineCoreRequestType
,
Any
]]
=
queue
.
Queue
()
self
.
output_queue
:
queue
.
Queue
[
EngineCoreOutputs
]
=
queue
.
Queue
()
self
.
output_queue
:
queue
.
Queue
[
EngineCoreOutputs
]
=
queue
.
Queue
()
threading
.
Thread
(
target
=
self
.
process_input_socket
,
threading
.
Thread
(
target
=
self
.
process_input_socket
,
args
=
(
input_path
,
),
args
=
(
input_path
,
),
...
@@ -223,7 +222,7 @@ class EngineCoreProc(EngineCore):
...
@@ -223,7 +222,7 @@ class EngineCoreProc(EngineCore):
while
True
:
while
True
:
try
:
try
:
req
=
self
.
input_queue
.
get
(
timeout
=
POLLING_TIMEOUT_S
)
req
=
self
.
input_queue
.
get
(
timeout
=
POLLING_TIMEOUT_S
)
self
.
_handle_client_request
(
req
)
self
.
_handle_client_request
(
*
req
)
break
break
except
queue
.
Empty
:
except
queue
.
Empty
:
logger
.
debug
(
"EngineCore busy loop waiting."
)
logger
.
debug
(
"EngineCore busy loop waiting."
)
...
@@ -233,10 +232,10 @@ class EngineCoreProc(EngineCore):
...
@@ -233,10 +232,10 @@ class EngineCoreProc(EngineCore):
except
BaseException
:
except
BaseException
:
raise
raise
# 2) Handle any new client requests
(Abort or Add)
.
# 2) Handle any new client requests.
while
not
self
.
input_queue
.
empty
():
while
not
self
.
input_queue
.
empty
():
req
=
self
.
input_queue
.
get_nowait
()
req
=
self
.
input_queue
.
get_nowait
()
self
.
_handle_client_request
(
req
)
self
.
_handle_client_request
(
*
req
)
# 3) Step the engine core.
# 3) Step the engine core.
outputs
=
self
.
step
()
outputs
=
self
.
step
()
...
@@ -244,48 +243,40 @@ class EngineCoreProc(EngineCore):
...
@@ -244,48 +243,40 @@ class EngineCoreProc(EngineCore):
# 5) Put EngineCoreOutputs into the output queue.
# 5) Put EngineCoreOutputs into the output queue.
self
.
output_queue
.
put_nowait
(
outputs
)
self
.
output_queue
.
put_nowait
(
outputs
)
def
_handle_client_request
(
self
,
request
:
EngineCoreRequestUnion
)
->
None
:
def
_handle_client_request
(
self
,
request_type
:
EngineCoreRequestType
,
"""Handle EngineCoreRequest or EngineCoreABORT from Client."""
request
:
Any
)
->
None
:
"""Dispatch request from client."""
if
isinstance
(
request
,
EngineCoreRequest
)
:
if
request_type
==
EngineCoreRequest
Type
.
ADD
:
self
.
add_request
(
request
)
self
.
add_request
(
request
)
elif
isinstance
(
request
,
EngineCoreProfile
):
elif
request_type
==
EngineCoreRequestType
.
ABORT
:
self
.
model_executor
.
profile
(
request
.
is_start
)
elif
isinstance
(
request
,
EngineCoreResetPrefixCache
):
self
.
reset_prefix_cache
()
else
:
# TODO: make an EngineCoreAbort wrapper
assert
isinstance
(
request
,
list
)
self
.
abort_requests
(
request
)
self
.
abort_requests
(
request
)
elif
request_type
==
EngineCoreRequestType
.
RESET_PREFIX_CACHE
:
self
.
reset_prefix_cache
()
elif
request_type
==
EngineCoreRequestType
.
PROFILE
:
self
.
model_executor
.
profile
(
request
)
def
process_input_socket
(
self
,
input_path
:
str
):
def
process_input_socket
(
self
,
input_path
:
str
):
"""Input socket IO thread."""
"""Input socket IO thread."""
# Msgpack serialization decoding.
# Msgpack serialization decoding.
decoder_add_req
=
PickleEncoder
(
)
add_request_decoder
=
MsgpackDecoder
(
EngineCoreRequest
)
decoder_abort_req
=
PickleEn
coder
()
generic_decoder
=
MsgpackDe
coder
()
with
zmq_socket_ctx
(
input_path
,
zmq
.
constants
.
PULL
)
as
socket
:
with
zmq_socket_ctx
(
input_path
,
zmq
.
constants
.
PULL
)
as
socket
:
while
True
:
while
True
:
# (RequestType, RequestData)
# (RequestType, RequestData)
type_frame
,
data_frame
=
socket
.
recv_multipart
(
copy
=
False
)
type_frame
,
data_frame
=
socket
.
recv_multipart
(
copy
=
False
)
request_type
=
type_frame
.
buffer
request_type
=
EngineCoreRequestType
(
bytes
(
type_frame
.
buffer
))
request_data
=
data_frame
.
buffer
# Deserialize the request data.
# Deserialize the request data.
if
request_type
==
EngineCoreRequestType
.
ADD
.
value
:
decoder
=
add_request_decoder
if
(
request
=
decoder_add_req
.
decode
(
request_data
)
request_type
elif
request_type
==
EngineCoreRequestType
.
ABORT
.
value
:
==
EngineCoreRequestType
.
ADD
)
else
generic_decoder
request
=
decoder_abort_req
.
decode
(
request_data
)
request
=
decoder
.
decode
(
data_frame
.
buffer
)
elif
request_type
in
(
EngineCoreRequestType
.
PROFILE
.
value
,
EngineCoreRequestType
.
RESET_PREFIX_CACHE
.
value
):
request
=
pickle
.
loads
(
request_data
)
else
:
raise
ValueError
(
f
"Unknown RequestType:
{
request_type
}
"
)
# Push to input queue for core busy loop.
# Push to input queue for core busy loop.
self
.
input_queue
.
put_nowait
(
request
)
self
.
input_queue
.
put_nowait
(
(
request
_type
,
request
)
)
def
process_output_socket
(
self
,
output_path
:
str
):
def
process_output_socket
(
self
,
output_path
:
str
):
"""Output socket IO thread."""
"""Output socket IO thread."""
...
...
vllm/v1/engine/core_client.py
View file @
67c4637c
...
@@ -5,7 +5,7 @@ import os
...
@@ -5,7 +5,7 @@ import os
import
signal
import
signal
import
weakref
import
weakref
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
typing
import
List
,
Optional
,
Type
from
typing
import
Any
,
List
,
Optional
,
Type
import
zmq
import
zmq
import
zmq.asyncio
import
zmq.asyncio
...
@@ -14,12 +14,11 @@ from vllm.config import VllmConfig
...
@@ -14,12 +14,11 @@ from vllm.config import VllmConfig
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.utils
import
(
get_open_zmq_ipc_path
,
kill_process_tree
,
from
vllm.utils
import
(
get_open_zmq_ipc_path
,
kill_process_tree
,
make_zmq_socket
)
make_zmq_socket
)
from
vllm.v1.engine
import
(
EngineCoreOutputs
,
EngineCoreProfile
,
from
vllm.v1.engine
import
(
EngineCoreOutputs
,
EngineCoreRequest
,
EngineCoreRequest
,
EngineCoreRequestType
,
EngineCoreRequestType
)
EngineCoreRequestUnion
,
EngineCoreResetPrefixCache
)
from
vllm.v1.engine.core
import
EngineCore
,
EngineCoreProc
from
vllm.v1.engine.core
import
EngineCore
,
EngineCoreProc
from
vllm.v1.executor.abstract
import
Executor
from
vllm.v1.executor.abstract
import
Executor
from
vllm.v1.serial_utils
import
MsgpackDecoder
,
Pickle
Encoder
from
vllm.v1.serial_utils
import
MsgpackDecoder
,
Msgpack
Encoder
from
vllm.v1.utils
import
BackgroundProcHandle
from
vllm.v1.utils
import
BackgroundProcHandle
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -161,7 +160,7 @@ class MPClient(EngineCoreClient):
...
@@ -161,7 +160,7 @@ class MPClient(EngineCoreClient):
signal
.
signal
(
signal
.
SIGUSR1
,
sigusr1_handler
)
signal
.
signal
(
signal
.
SIGUSR1
,
sigusr1_handler
)
# Serialization setup.
# Serialization setup.
self
.
encoder
=
Pickle
Encoder
()
self
.
encoder
=
Msgpack
Encoder
()
self
.
decoder
=
MsgpackDecoder
(
EngineCoreOutputs
)
self
.
decoder
=
MsgpackDecoder
(
EngineCoreOutputs
)
# ZMQ setup.
# ZMQ setup.
...
@@ -220,7 +219,7 @@ class SyncMPClient(MPClient):
...
@@ -220,7 +219,7 @@ class SyncMPClient(MPClient):
return
self
.
decoder
.
decode
(
frame
.
buffer
)
return
self
.
decoder
.
decode
(
frame
.
buffer
)
def
_send_input
(
self
,
request_type
:
EngineCoreRequestType
,
def
_send_input
(
self
,
request_type
:
EngineCoreRequestType
,
request
:
EngineCoreRequestUnion
)
->
None
:
request
:
Any
)
->
None
:
# (RequestType, SerializedRequest)
# (RequestType, SerializedRequest)
msg
=
(
request_type
.
value
,
self
.
encoder
.
encode
(
request
))
msg
=
(
request_type
.
value
,
self
.
encoder
.
encode
(
request
))
...
@@ -237,12 +236,10 @@ class SyncMPClient(MPClient):
...
@@ -237,12 +236,10 @@ class SyncMPClient(MPClient):
self
.
_send_input
(
EngineCoreRequestType
.
ABORT
,
request_ids
)
self
.
_send_input
(
EngineCoreRequestType
.
ABORT
,
request_ids
)
def
profile
(
self
,
is_start
:
bool
=
True
)
->
None
:
def
profile
(
self
,
is_start
:
bool
=
True
)
->
None
:
self
.
_send_input
(
EngineCoreRequestType
.
PROFILE
,
self
.
_send_input
(
EngineCoreRequestType
.
PROFILE
,
is_start
)
EngineCoreProfile
(
is_start
))
def
reset_prefix_cache
(
self
)
->
None
:
def
reset_prefix_cache
(
self
)
->
None
:
self
.
_send_input
(
EngineCoreRequestType
.
RESET_PREFIX_CACHE
,
self
.
_send_input
(
EngineCoreRequestType
.
RESET_PREFIX_CACHE
,
None
)
EngineCoreResetPrefixCache
())
class
AsyncMPClient
(
MPClient
):
class
AsyncMPClient
(
MPClient
):
...
@@ -277,7 +274,7 @@ class AsyncMPClient(MPClient):
...
@@ -277,7 +274,7 @@ class AsyncMPClient(MPClient):
return
self
.
decoder
.
decode
(
await
self
.
outputs_queue
.
get
())
return
self
.
decoder
.
decode
(
await
self
.
outputs_queue
.
get
())
async
def
_send_input
(
self
,
request_type
:
EngineCoreRequestType
,
async
def
_send_input
(
self
,
request_type
:
EngineCoreRequestType
,
request
:
EngineCoreRequestUnion
)
->
None
:
request
:
Any
)
->
None
:
msg
=
(
request_type
.
value
,
self
.
encoder
.
encode
(
request
))
msg
=
(
request_type
.
value
,
self
.
encoder
.
encode
(
request
))
await
self
.
input_socket
.
send_multipart
(
msg
,
copy
=
False
)
await
self
.
input_socket
.
send_multipart
(
msg
,
copy
=
False
)
...
@@ -293,9 +290,7 @@ class AsyncMPClient(MPClient):
...
@@ -293,9 +290,7 @@ class AsyncMPClient(MPClient):
await
self
.
_send_input
(
EngineCoreRequestType
.
ABORT
,
request_ids
)
await
self
.
_send_input
(
EngineCoreRequestType
.
ABORT
,
request_ids
)
async
def
profile_async
(
self
,
is_start
:
bool
=
True
)
->
None
:
async
def
profile_async
(
self
,
is_start
:
bool
=
True
)
->
None
:
await
self
.
_send_input
(
EngineCoreRequestType
.
PROFILE
,
await
self
.
_send_input
(
EngineCoreRequestType
.
PROFILE
,
is_start
)
EngineCoreProfile
(
is_start
))
async
def
reset_prefix_cache_async
(
self
)
->
None
:
async
def
reset_prefix_cache_async
(
self
)
->
None
:
await
self
.
_send_input
(
EngineCoreRequestType
.
RESET_PREFIX_CACHE
,
await
self
.
_send_input
(
EngineCoreRequestType
.
RESET_PREFIX_CACHE
,
None
)
EngineCoreResetPrefixCache
())
vllm/v1/serial_utils.py
View file @
67c4637c
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
import
pickle
import
pickle
from
typing
import
Any
from
typing
import
Any
,
Optional
import
torch
import
torch
from
msgspec
import
msgpack
from
msgspec
import
msgpack
CUSTOM_TYPE_CODE_PICKLE
=
1
CUSTOM_TYPE_TENSOR
=
1
CUSTOM_TYPE_PICKLE
=
2
class
PickleEncoder
:
def
encode
(
self
,
obj
:
Any
):
return
pickle
.
dumps
(
obj
)
def
decode
(
self
,
data
:
Any
):
return
pickle
.
loads
(
data
)
class
MsgpackEncoder
:
class
MsgpackEncoder
:
...
@@ -34,8 +26,9 @@ class MsgpackEncoder:
...
@@ -34,8 +26,9 @@ class MsgpackEncoder:
class
MsgpackDecoder
:
class
MsgpackDecoder
:
"""Decoder with custom torch tensor serialization."""
"""Decoder with custom torch tensor serialization."""
def
__init__
(
self
,
t
:
Any
):
def
__init__
(
self
,
t
:
Optional
[
Any
]
=
None
):
self
.
decoder
=
msgpack
.
Decoder
(
t
,
ext_hook
=
custom_ext_hook
)
args
=
()
if
t
is
None
else
(
t
,
)
self
.
decoder
=
msgpack
.
Decoder
(
*
args
,
ext_hook
=
custom_ext_hook
)
def
decode
(
self
,
obj
:
Any
):
def
decode
(
self
,
obj
:
Any
):
return
self
.
decoder
.
decode
(
obj
)
return
self
.
decoder
.
decode
(
obj
)
...
@@ -46,13 +39,15 @@ def custom_enc_hook(obj: Any) -> Any:
...
@@ -46,13 +39,15 @@ def custom_enc_hook(obj: Any) -> Any:
# NOTE(rob): it is fastest to use numpy + pickle
# NOTE(rob): it is fastest to use numpy + pickle
# when serializing torch tensors.
# when serializing torch tensors.
# https://gist.github.com/tlrmchlsmth/8067f1b24a82b6e2f90450e7764fa103 # noqa: E501
# https://gist.github.com/tlrmchlsmth/8067f1b24a82b6e2f90450e7764fa103 # noqa: E501
return
msgpack
.
Ext
(
CUSTOM_TYPE_
CODE_PICKLE
,
pickle
.
dumps
(
obj
.
numpy
()))
return
msgpack
.
Ext
(
CUSTOM_TYPE_
TENSOR
,
pickle
.
dumps
(
obj
.
numpy
()))
r
aise
NotImplementedError
(
f
"Objects of type
{
type
(
obj
)
}
are not supported"
)
r
eturn
msgpack
.
Ext
(
CUSTOM_TYPE_PICKLE
,
pickle
.
dumps
(
obj
)
)
def
custom_ext_hook
(
code
:
int
,
data
:
memoryview
)
->
Any
:
def
custom_ext_hook
(
code
:
int
,
data
:
memoryview
)
->
Any
:
if
code
==
CUSTOM_TYPE_
CODE_PICKLE
:
if
code
==
CUSTOM_TYPE_
TENSOR
:
return
torch
.
from_numpy
(
pickle
.
loads
(
data
))
return
torch
.
from_numpy
(
pickle
.
loads
(
data
))
if
code
==
CUSTOM_TYPE_PICKLE
:
return
pickle
.
loads
(
data
)
raise
NotImplementedError
(
f
"Extension type code
{
code
}
is not supported"
)
raise
NotImplementedError
(
f
"Extension type code
{
code
}
is not supported"
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment