serial_utils.py 23.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import dataclasses
5
import importlib
6
import pickle
7
from abc import ABC, abstractmethod
8
from collections.abc import Callable, Sequence
9
from functools import partial
10
from inspect import isclass
11
from types import FunctionType
12
from typing import Any, ClassVar, TypeAlias, cast, get_type_hints
13

14
import cloudpickle
15
import msgspec
16
import numpy as np
17
import torch
18
import zmq
19
from msgspec import msgpack
20
21
from pydantic import GetCoreSchemaHandler
from pydantic_core import core_schema
22

23
from vllm import envs
24
from vllm.logger import init_logger
25
26
27
28
29
30
31
32
33
34
35
from vllm.multimodal.inputs import (
    BaseMultiModalField,
    MultiModalBatchedField,
    MultiModalFieldConfig,
    MultiModalFieldElem,
    MultiModalFlatField,
    MultiModalKwargsItem,
    MultiModalKwargsItems,
    MultiModalSharedField,
    NestedTensors,
)
36
from vllm.utils.platform_utils import is_pin_memory_available
37
from vllm.v1.utils import tensor_data
38

39
40
# Module-level logger, named after this module.
logger = init_logger(__name__)

41
42
# msgpack extension type codes written by MsgpackEncoder and dispatched on in
# MsgpackDecoder.ext_hook.
# Payload is a `pickle.dumps` blob (only with insecure serialization enabled).
CUSTOM_TYPE_PICKLE = 1
# Payload is a `cloudpickle.dumps` blob (only with insecure serialization
# enabled); used for plain functions.
CUSTOM_TYPE_CLOUDPICKLE = 2
# Payload is raw array/tensor bytes, handed back as a memoryview on decode.
CUSTOM_TYPE_RAW_VIEW = 3
44

45
46
47
48
49
50
51
52
# MultiModalField class serialization type map.
# Each concrete field class maps to the name of the `MultiModalFieldConfig`
# factory method that rebuilds it: the encoder writes this name, and the
# decoder looks the factory up on `MultiModalFieldConfig` by that name.
# Lookup is by exact class (`field.__class__`), so subclasses are not matched;
# every field type that can appear in a message must be listed here, or
# encoding raises TypeError.
MMF_CLASS_TO_FACTORY: dict[type[BaseMultiModalField], str] = {
    MultiModalFlatField: "flat",
    MultiModalSharedField: "shared",
    MultiModalBatchedField: "batched",
}
53

54
# Buffer-like types that may appear in an encoded message's buffer list, as
# returned by `MsgpackEncoder.encode` and accepted by `MsgpackDecoder.decode`.
bytestr: TypeAlias = bytes | bytearray | memoryview | zmq.Frame
55
56


57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
class OOBTensorConsumer(ABC):
    @abstractmethod
    def __call__(self, tensor: torch.Tensor) -> dict | None:
        """
        Called with tensors for the current message.
        Returns None to reject the tensor (falls back to regular serialization),
        otherwise a dict with arbitrary placeholder data to be included
        in the serialized message.
        """
        return None

    @abstractmethod
    def new_message(self) -> None:
        """Called at the start of each new encoded message."""
        pass


# Decoder-side counterpart of `OOBTensorConsumer`.
# Called as provider(dtype, shape, metadata) -> tensor, where:
#   dtype    - the torch dtype name without the "torch." prefix (e.g. "float16"),
#   shape    - the tensor shape as decoded from the message,
#   metadata - the dict the consumer returned for that tensor at encode time.
OOBTensorProvider = Callable[[str, tuple[int, ...], dict], torch.Tensor]


78
def _log_insecure_serialization_warning():
    """Log, via ``logger.warning_once``, that pickle-based fallback is enabled."""
    message = (
        "Allowing insecure serialization using pickle due to "
        "VLLM_ALLOW_INSECURE_SERIALIZATION=1"
    )
    logger.warning_once(message)
83
84


85
def _typestr(val: Any) -> tuple[str, str] | None:
86
87
88
    if val is None:
        return None
    t = type(val)
89
90
91
    return t.__module__, t.__qualname__


92
93
94
95
96
97
98
99
100
101
102
103
104
def _encode_type_info_recursive(obj: Any) -> Any:
    """Build a type-info skeleton mirroring the list/dict nesting of ``obj``.

    Exact ``list`` and ``dict`` instances (not subclasses) are walked
    recursively. ``None`` stays ``None``. Every other leaf is replaced by its
    ``(module, qualname)`` pair from `_typestr`.
    """
    if obj is None:
        return None
    obj_type = type(obj)
    if obj_type is dict:
        return {
            key: _encode_type_info_recursive(value) for key, value in obj.items()
        }
    if obj_type is list:
        return list(map(_encode_type_info_recursive, obj))
    return _typestr(obj)


def _decode_type_info_recursive(
105
106
    type_info: Any, data: Any, convert_fn: Callable[[Sequence[str], Any], Any]
) -> Any:
107
108
109
110
111
112
113
114
115
116
117
    """Recursively decode type information for nested structures of
    lists/dicts."""
    if type_info is None:
        return data
    if isinstance(type_info, dict):
        assert isinstance(data, dict)
        return {
            k: _decode_type_info_recursive(type_info[k], data[k], convert_fn)
            for k in type_info
        }
    if isinstance(type_info, list) and (
118
119
120
        # Exclude serialized tensors/numpy arrays.
        len(type_info) != 2 or not isinstance(type_info[0], str)
    ):
121
122
123
124
125
126
127
128
        assert isinstance(data, list)
        return [
            _decode_type_info_recursive(ti, d, convert_fn)
            for ti, d in zip(type_info, data)
        ]
    return convert_fn(type_info, data)


129
130
131
132
133
134
135
class UtilityResult:
    """Marker box around a utility call's return value.

    The msgpack encoder/decoder special-case this type. When
    VLLM_ALLOW_INSECURE_SERIALIZATION is set, the wrapped value is sent
    together with per-leaf type info so it can be converted back on receipt.
    """

    def __init__(self, r: Any = None):
        # The wrapped return value; may be of any type.
        self.result: Any = r


136
137
class MsgpackEncoder:
    """Encoder with custom torch tensor and numpy array serialization.

    Note that unlike vanilla `msgspec` Encoders, this interface is generally
    not thread-safe when encoding tensors / numpy arrays.

    Arrays and tensors smaller than ``size_threshold`` bytes are serialized
    inline; the threshold defaults to ``VLLM_MSGPACK_ZERO_COPY_THRESHOLD``.
    Only CPU tensors are ever inlined. Larger buffers are not copied into the
    main message. Instead, their backing buffers are returned as extra
    entries of the list produced by `encode`. Note that this is a per-tensor
    limit.

    When a ``oob_tensor_consumer`` is provided, tensors (CUDA and CPU) that
    are not inlined are offered to it for out-of-band handling.
    """

    def __init__(
        self,
        size_threshold: int | None = None,
        oob_tensor_consumer: OOBTensorConsumer | None = None,
    ):
        if size_threshold is None:
            size_threshold = envs.VLLM_MSGPACK_ZERO_COPY_THRESHOLD
        self.encoder = msgpack.Encoder(enc_hook=self.enc_hook)
        # This is used as a local stash of buffers that we can then access from
        # our custom `msgspec` hook, `enc_hook`. We don't have a way to
        # pass custom data to the hook otherwise. It is None outside of an
        # `encode` / `encode_into` call.
        self.aux_buffers: list[bytestr] | None = None
        self.size_threshold = size_threshold
        self.oob_tensor_consumer = oob_tensor_consumer
        if envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
            _log_insecure_serialization_warning()

    def encode(self, obj: Any) -> Sequence[bytestr]:
        """Encode ``obj`` into a list of buffers.

        Element 0 is the msgpack payload. Later elements are the backing
        buffers of large arrays/tensors, which the payload refers to by their
        index in this list.
        """
        try:
            if self.oob_tensor_consumer is not None:
                self.oob_tensor_consumer.new_message()
            # Slot 0 is reserved for the main payload, filled in below.
            self.aux_buffers = bufs = [b""]
            bufs[0] = self.encoder.encode(obj)
            # This `bufs` list allows us to collect direct pointers to backing
            # buffers of tensors and np arrays, and return them along with the
            # top-level encoded buffer instead of copying their data into the
            # new buffer.
            return bufs
        finally:
            self.aux_buffers = None

    def encode_into(self, obj: Any, buf: bytearray) -> Sequence[bytestr]:
        """Like `encode`, but write the msgpack payload into ``buf``.

        ``buf`` becomes element 0 of the returned list.
        """
        try:
            if self.oob_tensor_consumer is not None:
                self.oob_tensor_consumer.new_message()
            self.aux_buffers = [buf]
            bufs = self.aux_buffers
            self.encoder.encode_into(obj, buf)
            return bufs
        finally:
            self.aux_buffers = None

    def enc_hook(self, obj: Any) -> Any:
        """msgspec fallback hook for objects it cannot encode natively.

        Raises:
            TypeError: for unsupported types when insecure (pickle-based)
                serialization is not enabled.
        """
        if isinstance(obj, torch.Tensor):
            return self._encode_tensor(obj)

        # Fall back to pickle for object or void kind ndarrays.
        if isinstance(obj, np.ndarray) and obj.dtype.kind not in ("O", "V"):
            return self._encode_ndarray(obj)

        if isinstance(obj, slice):
            # We are assuming only int-based values will be used here.
            return tuple(
                int(v) if v is not None else None
                for v in (obj.start, obj.stop, obj.step)
            )

        if isinstance(obj, MultiModalKwargsItem):
            return self._encode_mm_item(obj)

        if isinstance(obj, MultiModalKwargsItems):
            return self._encode_mm_items(obj)

        if isinstance(obj, UtilityResult):
            result = obj.result
            if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
                return None, result
            # Since utility results are not strongly typed, we recursively
            # encode type information for nested structures of lists/dicts
            # to help with correct msgspec deserialization.
            return _encode_type_info_recursive(result), result

        if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
            # The two sentences are separate literals; the ". " keeps them
            # from running together in the final message.
            raise TypeError(
                f"Object of type {type(obj)} is not serializable. "
                "Set VLLM_ALLOW_INSECURE_SERIALIZATION=1 to allow "
                "fallback to pickle-based serialization."
            )

        if isinstance(obj, FunctionType):
            # `pickle` is generally faster than cloudpickle, but can have
            # problems serializing methods.
            return msgpack.Ext(CUSTOM_TYPE_CLOUDPICKLE, cloudpickle.dumps(obj))

        return msgpack.Ext(
            CUSTOM_TYPE_PICKLE, pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
        )

    def _encode_ndarray(
        self, obj: np.ndarray
    ) -> tuple[str, tuple[int, ...], int | msgpack.Ext]:
        """Encode an ndarray as ``(dtype.str, shape, data)``.

        ``data`` is an inline raw-bytes Ext for scalars and small arrays.
        Otherwise it is the index of the array's buffer in `aux_buffers`.
        """
        assert self.aux_buffers is not None
        # If the array is non-contiguous, we need to copy it first
        arr_data = obj.data if obj.flags.c_contiguous else obj.tobytes()
        if not obj.shape or obj.nbytes < self.size_threshold:
            # Encode small arrays and scalars inline. Using this extension type
            # ensures we can avoid copying when decoding.
            data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, arr_data)
        else:
            # Otherwise encode index of backing buffer to avoid copy.
            data = len(self.aux_buffers)
            self.aux_buffers.append(arr_data)

        # We serialize the ndarray as a tuple of native types.
        # The data is either inlined if small, or an index into a list of
        # backing buffers that we've stashed in `aux_buffers`.
        return obj.dtype.str, obj.shape, data

    def _encode_tensor(
        self, obj: torch.Tensor
    ) -> tuple[str, tuple[int, ...], int | dict | msgpack.Ext]:
        """Encode a tensor as ``(dtype_name, shape, data)``.

        ``data`` is one of:
        - an inline raw-bytes Ext, for small CPU tensors;
        - the metadata dict returned by the OOB consumer, if it accepted the
          tensor;
        - the index of the tensor's bytes in `aux_buffers`.
        """
        oob_consumer = self.oob_tensor_consumer
        if obj.nbytes < self.size_threshold and obj.is_cpu:
            # Smaller tensors are encoded inline, just like ndarrays.
            data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, tensor_data(obj))
        elif oob_consumer is not None and (data := oob_consumer(obj)) is not None:
            assert isinstance(data, dict)
        else:
            # Otherwise encode index of backing buffer to avoid copy.
            assert self.aux_buffers is not None
            data = len(self.aux_buffers)
            self.aux_buffers.append(tensor_data(obj))
        # e.g. "torch.float16" -> "float16"; decoded via getattr(torch, ...).
        dtype = str(obj.dtype).removeprefix("torch.")
        return dtype, obj.shape, data

    def _encode_mm_items(self, items: MultiModalKwargsItems) -> dict[str, Any]:
        """Encode items as ``{modality: [encoded item, ...]}``."""
        return {
            modality: [self._encode_mm_item(item) for item in itemlist]
            for modality, itemlist in items.items()
        }

    def _encode_mm_item(self, item: MultiModalKwargsItem) -> dict[str, Any]:
        """Encode an item as ``{key: encoded field element}``."""
        return {key: self._encode_mm_field_elem(elem) for key, elem in item.items()}

    def _encode_mm_field_elem(self, elem: MultiModalFieldElem) -> dict[str, Any]:
        """Encode a field element's data (nested tensors) and field descriptor."""
        return {
            "data": (
                None if elem.data is None else self._encode_nested_tensors(elem.data)
            ),
            "field": self._encode_mm_field(elem.field),
        }

    def _encode_nested_tensors(self, nt: NestedTensors) -> Any:
        """Recursively encode tensors inside nested lists, passing numbers through."""
        if isinstance(nt, torch.Tensor):
            return self._encode_tensor(nt)
        if isinstance(nt, (int, float)):
            # Although it violates NestedTensors type, MultiModalKwargs
            # values are sometimes floats.
            return nt
        return [self._encode_nested_tensors(x) for x in nt]

    def _encode_mm_field(self, field: BaseMultiModalField):
        """Encode a field as ``(factory_name, dataclass field values)``.

        Raises:
            TypeError: if the field's exact class is not in MMF_CLASS_TO_FACTORY.
        """
        # Figure out the factory name for the field type.
        name = MMF_CLASS_TO_FACTORY.get(field.__class__)
        if not name:
            raise TypeError(f"Unsupported field type: {field.__class__}")

        # We just need to copy all of the field values in order
        # which will be then used to reconstruct the field.
        factory_kw = {f.name: getattr(field, f.name) for f in dataclasses.fields(field)}
        return name, factory_kw
311

312
313

class MsgpackDecoder:
    """Decoder with custom torch tensor and numpy array serialization.

    Note that unlike vanilla `msgspec` Decoders, this interface is generally
    not thread-safe when decoding tensors / numpy arrays. The out-of-line
    buffers of the message being decoded are stashed on the instance for the
    duration of `decode`.

    ``oob_tensor_provider`` must be used when an OOBTensorConsumer is used on the
    encoder side.
    """

    def __init__(
        self,
        t: Any | None = None,
        share_mem: bool = True,
        oob_tensor_provider: OOBTensorProvider | None = None,
    ):
        # When False, decoded arrays / aux-buffer tensors are copied rather
        # than viewing the received buffers.
        self.share_mem = share_mem
        self.pin_tensors = is_pin_memory_available()
        args = () if t is None else (t,)
        self.decoder = msgpack.Decoder(
            *args, ext_hook=self.ext_hook, dec_hook=self.dec_hook
        )
        # Buffers of the message currently being decoded (element 0 is the
        # msgpack payload); empty outside of a multi-buffer `decode` call.
        self.aux_buffers: Sequence[bytestr] = ()
        self.oob_tensor_provider = oob_tensor_provider
        if envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
            _log_insecure_serialization_warning()

    def decode(self, bufs: bytestr | Sequence[bytestr]) -> Any:
        """Decode a single buffer, or a buffer list produced by `MsgpackEncoder`.

        For a list, element 0 is the msgpack payload. Integer ``data`` entries
        inside it index into the list.
        """
        if isinstance(bufs, bytestr):  # type: ignore
            return self.decoder.decode(bufs)

        self.aux_buffers = bufs
        try:
            return self.decoder.decode(bufs[0])
        finally:
            self.aux_buffers = ()

    def dec_hook(self, t: type, obj: Any) -> Any:
        """msgspec hook converting native values back into the custom types."""
        # Given native types in `obj`, convert to type `t`.
        if isclass(t):
            if issubclass(t, np.ndarray):
                return self._decode_ndarray(obj)
            if issubclass(t, torch.Tensor):
                return self._decode_tensor(obj)
            if t is slice:
                return slice(*obj)
            if issubclass(t, MultiModalKwargsItem):
                return self._decode_mm_item(obj)
            if issubclass(t, MultiModalKwargsItems):
                return self._decode_mm_items(obj)
            if t is UtilityResult:
                return self._decode_utility_result(obj)
        return obj

    def _decode_utility_result(self, obj: Any) -> UtilityResult:
        """Rebuild a UtilityResult from ``(type_info, value)``.

        Raises:
            TypeError: if type info is present but insecure serialization is
                not enabled.
        """
        result_type, result = obj
        if result_type is not None:
            if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
                raise TypeError(
                    "VLLM_ALLOW_INSECURE_SERIALIZATION must "
                    "be set to use custom utility result types"
                )
            # Use recursive decoding to handle nested structures
            result = _decode_type_info_recursive(
                result_type, result, self._convert_result
            )
        return UtilityResult(result)

    def _convert_result(self, result_type: Sequence[str], result: Any) -> Any:
        """Convert ``result`` to the type named by ``(module, qualname)``.

        This imports the named module, which is only reached when insecure
        serialization is enabled (see `_decode_utility_result`).
        """
        if result_type is None:
            return result
        mod_name, qualname = result_type
        # `_typestr` records `__qualname__`, which is dotted for nested classes
        # (e.g. "Outer.Inner"), so resolve it attribute by attribute.
        target: Any = importlib.import_module(mod_name)
        for attr in qualname.split("."):
            target = getattr(target, attr)
        return msgspec.convert(result, target, dec_hook=self.dec_hook)

    def _decode_ndarray(self, arr: Any) -> np.ndarray:
        """Rebuild an ndarray from ``(dtype.str, shape, data)``."""
        dtype, shape, data = arr
        # zero-copy decode. We assume the ndarray will not be kept around,
        # as it now locks the whole received message buffer in memory.
        buffer = self.aux_buffers[data] if isinstance(data, int) else data
        arr = np.frombuffer(buffer, dtype=dtype)
        if not self.share_mem:
            arr = arr.copy()
        return arr.reshape(shape)

    def _decode_tensor(self, arr: Any) -> torch.Tensor:
        """Rebuild a tensor from ``(dtype_name, shape, data)``.

        A dict ``data`` means the tensor was sent out of band and is obtained
        from ``oob_tensor_provider``.
        """
        dtype, shape, data = arr
        if isinstance(data, dict):
            assert self.oob_tensor_provider, (
                "Received OOB tensor but tensor provider is not set"
            )
            return self.oob_tensor_provider(dtype, shape, data)

        is_aux = isinstance(data, int)
        buffer = self.aux_buffers[data] if is_aux else data
        buffer = buffer if isinstance(buffer, memoryview) else memoryview(buffer)
        # Reject dtype names that resolve to non-dtype attributes of `torch`.
        torch_dtype = getattr(torch, dtype)
        assert isinstance(torch_dtype, torch.dtype)
        if not buffer.nbytes:  # torch.frombuffer doesn't like empty buffers
            assert 0 in shape
            return torch.empty(shape, dtype=torch_dtype)
        # Create uint8 array
        arr = torch.frombuffer(buffer, dtype=torch.uint8)
        # Clone ensures tensor is backed by pytorch-owned memory for safe
        # future async CPU->GPU transfer.
        # Pin larger tensors for more efficient CPU->GPU transfer.
        if not is_aux:
            arr = arr.clone()
        elif not self.share_mem:
            arr = arr.pin_memory() if self.pin_tensors else arr.clone()
        # Convert back to proper shape & type
        return arr.view(torch_dtype).view(shape)

    def _decode_mm_items(self, obj: dict[str, Any]) -> MultiModalKwargsItems:
        """Rebuild items from ``{modality: [encoded item, ...]}``."""
        return MultiModalKwargsItems(
            {
                modality: [self._decode_mm_item(item) for item in itemlist]
                for modality, itemlist in obj.items()
            }
        )

    def _decode_mm_item(self, obj: dict[str, Any]) -> MultiModalKwargsItem:
        """Rebuild an item from ``{key: encoded field element}``."""
        return MultiModalKwargsItem(
            {key: self._decode_mm_field_elem(elem) for key, elem in obj.items()}
        )

    def _decode_mm_field_elem(self, obj: dict[str, Any]) -> MultiModalFieldElem:
        """Rebuild a field element; note this mutates ``obj`` in place."""
        if obj["data"] is not None:
            obj["data"] = self._decode_nested_tensors(obj["data"])

        # Reconstruct the field processor using MultiModalFieldConfig
        factory_meth_name, factory_kw = obj["field"]
        factory_meth = getattr(MultiModalFieldConfig, factory_meth_name)

        # Special case: decode the union "slices" field of
        # MultiModalFlatField
        if factory_meth_name == "flat":
            factory_kw["slices"] = self._decode_nested_slices(factory_kw["slices"])

        obj["field"] = factory_meth("", **factory_kw).field
        return MultiModalFieldElem(**obj)

    def _decode_nested_tensors(self, obj: Any) -> NestedTensors:
        """Recursively rebuild tensors inside nested lists, passing numbers through."""
        if isinstance(obj, (int, float)):
            # Although it violates NestedTensors type, MultiModalKwargs
            # values are sometimes floats.
            return obj
        if not isinstance(obj, list):
            raise TypeError(f"Unexpected NestedTensors contents: {type(obj)}")
        # An encoded tensor is a list whose first element is its dtype name.
        if obj and isinstance(obj[0], str):
            return self._decode_tensor(obj)
        return [self._decode_nested_tensors(x) for x in obj]

    def _decode_nested_slices(self, obj: Any) -> Any:
        """Turn (nested) ``[start, stop, step]`` sequences back into slices."""
        assert isinstance(obj, (list, tuple))
        if obj and not isinstance(obj[0], (list, tuple)):
            return slice(*obj)
        return [self._decode_nested_slices(x) for x in obj]

    def ext_hook(self, code: int, data: memoryview) -> Any:
        """msgspec hook for msgpack extension types.

        Raises:
            NotImplementedError: for unknown codes, or for pickle codes when
                insecure serialization is not enabled.
        """
        if code == CUSTOM_TYPE_RAW_VIEW:
            return data

        if envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
            # Unpickling can execute arbitrary code; only reachable when
            # explicitly opted in via the environment variable.
            if code == CUSTOM_TYPE_PICKLE:
                return pickle.loads(data)
            if code == CUSTOM_TYPE_CLOUDPICKLE:
                return cloudpickle.loads(data)

        raise NotImplementedError(f"Extension type code {code} is not supported")
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510


def run_method(
    obj: Any,
    method: str | bytes | Callable,
    args: tuple[Any, ...],
    kwargs: dict[str, Any],
) -> Any:
    """
    Run a method of an object with the given arguments and keyword arguments.
    If the method is string, it will be converted to a method using getattr.
    If the method is serialized bytes and will be deserialized using
    cloudpickle.
    If the method is a callable, it will be called directly.
    """
    if isinstance(method, bytes):
        func = partial(cloudpickle.loads(method), obj)
    elif isinstance(method, str):
        try:
            func = getattr(obj, method)
        except AttributeError:
            raise NotImplementedError(
                f"Method {method!r} is not implemented."
            ) from None
    else:
        func = partial(method, obj)  # type: ignore
    return func(*args, **kwargs)
511
512
513


class PydanticMsgspecMixin:
    """Bridge ``msgspec.Struct`` subclasses into Pydantic.

    Mixing this into a ``msgspec.Struct`` lets Pydantic both **validate**
    input (JSON/dict -> Struct) and **serialize** output (Struct -> JSON-safe
    dict).

    Fields whose names start with ``_`` are always dropped from serialized
    output. Subclasses may also set ``__pydantic_msgspec_exclude__`` (a
    ``set[str]``) to list public field names that should be stripped too.
    """

    # Extra public field names to strip on serialization; override per subclass.
    __pydantic_msgspec_exclude__: ClassVar[set[str]] = set()

    @classmethod
    def __get_pydantic_core_schema__(
        cls, source_type: Any, handler: GetCoreSchemaHandler
    ) -> core_schema.CoreSchema:
        """
        Build the Pydantic core schema for the msgspec struct ``source_type``.

        Input is accepted either as an existing instance, or as a dict that is
        validated field by field (respecting struct defaults) and then
        converted to the struct. Used when exposing msgspec.Struct to the API
        as input or in `/docs`. Note this is cached by Pydantic and not called
        on every validation.
        """
        struct_fields = {
            info.name: info for info in msgspec.structs.fields(source_type)
        }
        schema_fields = {}
        for field_name, annotation in get_type_hints(source_type).items():
            # Skip ClassVar and other non-struct annotations, plus private
            # fields: they are excluded from serialization and should not
            # appear in the generated JSON/OpenAPI schema.
            if field_name not in struct_fields or field_name.startswith("_"):
                continue
            info = struct_fields[field_name]
            inner_schema = handler(annotation)

            # default_factory takes precedence over a plain default.
            default_kwargs: dict[str, Any] = {}
            if info.default_factory is not msgspec.NODEFAULT:
                default_kwargs["default_factory"] = info.default_factory
            elif info.default is not msgspec.NODEFAULT:
                default_kwargs["default"] = info.default

            if not default_kwargs:
                # No default, so Pydantic will treat it as required
                schema_fields[field_name] = core_schema.typed_dict_field(inner_schema)
                continue

            # Mark fields with defaults as not required so the generated
            # JSON Schema stays consistent with ``omit_defaults=True``
            # serialization (fields at their default value may be absent).
            schema_fields[field_name] = core_schema.typed_dict_field(
                core_schema.with_default_schema(schema=inner_schema, **default_kwargs),
                required=False,
            )

        dict_to_struct = core_schema.no_info_after_validator_function(
            cls._validate_msgspec,
            core_schema.typed_dict_schema(schema_fields),
        )
        # Accept either an already-constructed msgspec.Struct instance or a
        # JSON/dict-like payload; serialize by stripping private/excluded keys.
        return core_schema.union_schema(
            [core_schema.is_instance_schema(source_type), dict_to_struct],
            serialization=core_schema.plain_serializer_function_ser_schema(
                cls._serialize_msgspec,
                info_arg=False,
            ),
        )

    @classmethod
    def _validate_msgspec(cls, value: Any) -> Any:
        """Return ``value`` as an instance of ``cls``.

        Instances pass through unchanged and dicts are splatted into the
        constructor. Anything else goes through ``msgspec.convert``.
        """
        if isinstance(value, cls):
            return value
        return cls(**value) if isinstance(value, dict) else msgspec.convert(value, type=cls)

    @staticmethod
    def _serialize_msgspec(value: Any) -> Any:
        """Serialize a msgspec.Struct to a JSON-compatible dict.

        Top-level keys that start with ``_`` or appear in the struct type's
        ``__pydantic_msgspec_exclude__`` are dropped. Conversion uses
        ``msgspec.to_builtins``, which honours the struct's ``omit_defaults``
        setting. Non-dict results are returned unchanged.
        """
        builtins_value = msgspec.to_builtins(value)
        if isinstance(builtins_value, dict):
            excluded: set[str] = cast(
                set[str],
                getattr(type(value), "__pydantic_msgspec_exclude__", set()),
            )
            builtins_value = {
                key: item
                for key, item in builtins_value.items()
                if not key.startswith("_") and key not in excluded
            }
        return builtins_value