deepseek_vl2.py 23.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
# adapted from https://github.com/deepseek-ai/DeepSeek-VL2/blob/faf18023f24b962b32d9f0a2d89e402a8d383a78/deepseek_vl2/models/modeling_deepseek_vl_v2.py
"""Inference-only Deepseek-VL2 model compatible with HuggingFace weights."""
import math
7
from collections.abc import Iterable, Mapping, Sequence
8
from typing import Annotated, Literal, Optional, Union
9
10
11
12
13

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
14
from transformers import BatchFeature
15
16

from vllm.config import VllmConfig
17
from vllm.distributed import get_tensor_model_parallel_world_size
18
19
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
20
from vllm.model_executor.models.transformers import replace_linear_class
21
from vllm.multimodal import MULTIMODAL_REGISTRY
22
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
23
                                    MultiModalKwargsItems, MultiModalUUIDDict)
24
25
26
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
                                   ImageSize, MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
27
28
                                        BaseProcessingInfo,
                                        MultiModalProcessingInfo,
29
                                        PromptReplacement, PromptUpdate)
30
from vllm.multimodal.profiling import BaseDummyInputsBuilder
31
32
33
34
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.deepseek_vl2 import (DeepseekVLV2Config,
                                                          MlpProjectorConfig,
                                                          VisionEncoderConfig)
35
36
from vllm.transformers_utils.processors.deepseek_vl2 import (
    DeepseekVLV2Processor)
37
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
38
from vllm.utils import is_list_of
39
from vllm.utils.tensor_schema import TensorSchema, TensorShape
40

41
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
42
from .utils import (AutoWeightsLoader, WeightsMapper,
43
                    init_vllm_registered_model, maybe_prefix)
44
45
46
47
48

# The image token id may be various
_IMAGE_TOKEN = "<image>"


49
class DeepseekVL2ImagePixelInputs(TensorSchema):
    """
    Dimensions:
        - bnp: Batch size * number of images * number of patches
        - bn: Batch size * number of images
        - c: Number of channels (3)
        - h: Height of each image
        - w: Width of each image

    Each image contributes ``num_width_tiles * num_height_tiles + 1``
    patches to ``bnp``: one global view followed by its local tiles.
    """
    type: Literal["pixel_values"]

    # All tiles of all images, stacked along the first dimension.
    data: Annotated[torch.Tensor,
                    TensorShape("bnp", 3, "h", "w", dynamic_dims={"bnp"})]
    # One (num_width_tiles, num_height_tiles) row per image.
    images_spatial_crop: Annotated[torch.Tensor, TensorShape("bn", 2)]


64
class DeepseekVL2VImageEmbeddingInputs(TensorSchema):
    """
    Dimensions:
        - bn: Batch size * number of images
        - f: Image feature size
        - h: Hidden size (must match language model backbone)

    ``data`` may be a single 3D tensor or a list of per-image 2D tensors.
    """
    type: Literal["image_embeds"]
    data: Annotated[Union[torch.Tensor, list[torch.Tensor]],
                    TensorShape("bn", "f", "h")]


# An image input arrives either as raw pixel tiles or as precomputed
# embeddings; the ``type`` field of each schema tells them apart.
DeepseekVL2ImageInputs = Union[
    DeepseekVL2ImagePixelInputs,
    DeepseekVL2VImageEmbeddingInputs,
]


class MlpProjector(nn.Module):
    """Map vision-encoder patch tokens into the language embedding space.

    Every ``downsample_ratio x downsample_ratio`` window of the square patch
    grid is concatenated into a single token. The result goes through a
    GELU-separated stack of linear layers: ``cfg.depth`` layers, with a
    minimum of two.
    """

    def __init__(self, cfg: MlpProjectorConfig):

        super().__init__()

        self.cfg = cfg
        assert not cfg.token_pooling, (
            "Token pooling is not supported currently.")

        if cfg.projector_type != "downsample_mlp_gelu":
            raise NotImplementedError(
                f"Unsupported projector type: {cfg.projector_type}")

        hidden_dim = cfg.n_embed * cfg.mlp_ratio
        # Input width: each downsampled token holds ratio**2 concatenated
        # input vectors.
        in_dim = cfg.input_dim * cfg.downsample_ratio * cfg.downsample_ratio
        layers: list[nn.Module] = [nn.Linear(in_dim, hidden_dim)]
        for _ in range(1, cfg.depth - 1):
            layers.append(nn.GELU())
            layers.append(nn.Linear(hidden_dim, hidden_dim))
        layers.append(nn.GELU())
        layers.append(nn.Linear(hidden_dim, cfg.n_embed))

        # Submodule order (and therefore state_dict keys layers.0, ...)
        # matches the checkpoint layout.
        self.layers = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Downsample and project a batch of patch tokens.

        Args:
            x: ``(batch, h * w, input_dim)`` tokens laid out row-major on a
                square ``h x w`` grid.

        Returns:
            ``(batch, ceil(h / r) ** 2, n_embed)`` where ``r`` is
            ``cfg.downsample_ratio``.

        Raises:
            ValueError: if the number of tokens is not a perfect square.
        """
        bs, hw, input_dim = x.shape
        ratio = self.cfg.downsample_ratio

        # Exact integer square root; ``hw ** 0.5`` goes through float pow.
        h = w = math.isqrt(hw)
        if h * w != hw:
            raise ValueError(
                f"Expected a square grid of patch tokens, got {hw} tokens")

        # Zero-pad bottom/right so the grid side is a multiple of ``ratio``.
        pad = -h % ratio
        x = x.reshape(bs, h, w, input_dim)
        if pad > 0:
            # F.pad pairs apply from the last dim backwards: C, W, H.
            x = F.pad(x, (0, 0, 0, pad, 0, pad), "constant", 0)

        # Concatenate each ratio x ratio window into one token
        # ("4 to 1" for ratio 2).
        x = x.permute(0, 3, 1, 2)  # B, C, H, W
        x = F.unfold(x, kernel_size=ratio, stride=ratio,
                     padding=0)  # B, C * ratio**2, HW // ratio**2
        x = x.permute(0, 2, 1)

        return self.layers(x)


class DeepseekVL2ProcessingInfo(BaseProcessingInfo):
    """Config, processor and token-count metadata for DeepSeek-VL2."""

    def get_hf_config(self):
        return self.ctx.get_hf_config(DeepseekVLV2Config)

    def get_hf_processor(self, **kwargs: object):
        return self.ctx.get_hf_processor(DeepseekVLV2Processor, **kwargs)

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        # No fixed cap on the number of images per prompt.
        return {"image": None}

    def get_num_image_tokens(self,
                             *,
                             image_width: int,
                             image_height: int,
                             cropping: bool = True) -> int:
        """Return how many placeholder tokens one image expands to.

        Args:
            image_width: Image width in pixels.
            image_height: Image height in pixels.
            cropping: If True, tile the image at the processor's best-fitting
                candidate resolution; otherwise use a single local tile.

        Returns:
            The global view (``h`` rows of ``w + 1`` tokens, the extra token
            being a row newline), plus the local tile grid (one newline per
            row of the whole grid), plus 1 for the view separator.
        """
        hf_processor = self.get_hf_processor()
        image_size = hf_processor.image_size
        patch_size = hf_processor.patch_size
        downsample_ratio = hf_processor.downsample_ratio

        if cropping:
            best_width, best_height = hf_processor.select_best_resolution(
                (image_width, image_height))
            num_width_tiles, num_height_tiles = (best_width // image_size,
                                                 best_height // image_size)
        else:
            num_width_tiles = num_height_tiles = 1

        # Side length (in tokens) of one tile after the projector's
        # downsampling.
        h = w = math.ceil((image_size // patch_size) / downsample_ratio)

        global_views_tokens = h * (w + 1)
        local_views_tokens = (num_height_tiles * h) * (num_width_tiles * w + 1)
        return global_views_tokens + local_views_tokens + 1

    def get_image_size_with_most_features(self) -> ImageSize:
        """Return the candidate resolution yielding the most image tokens."""
        hf_config = self.get_hf_config()
        candidate_resolutions = hf_config.candidate_resolutions
        # Candidate entries are read as (height, width).
        height, width = max(candidate_resolutions,
                            key=lambda x: self.get_num_image_tokens(
                                image_width=x[1], image_height=x[0]))
        return ImageSize(width=width, height=height)


class DeepseekVL2DummyInputsBuilder(
        BaseDummyInputsBuilder[DeepseekVL2ProcessingInfo]):
    """Builds dummy prompts/images used for memory profiling."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one image placeholder token per requested image."""
        num_images = mm_counts.get("image", 0)

        processor = self.info.get_hf_processor()
        image_token = processor.image_token

        return image_token * num_images

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> MultiModalDataDict:
        """Return dummy images sized to produce the most image tokens."""
        num_images = mm_counts.get("image", 0)

        max_image_size = self.info.get_image_size_with_most_features()

        return {
            "image":
            self._get_dummy_images(width=max_image_size.width,
                                   height=max_image_size.height,
                                   num_images=num_images)
        }


class DeepseekVL2MultiModalProcessor(
        BaseMultiModalProcessor[DeepseekVL2ProcessingInfo]):
    """Multimodal processor that expands each image token in the prompt."""

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor and attach per-image tile counts.

        Text-only requests skip the HF processor and are only tokenized.
        """
        if not mm_data:
            tokenizer = self.info.get_tokenizer()
            return tokenizer(prompt,
                             add_special_tokens=True,
                             return_tensors="pt")

        processed_outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )

        # One global tile plus width * height local tiles per image.
        processed_outputs["num_patches"] = (
            processed_outputs["images_spatial_crop"].prod(-1) + 1)

        return processed_outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        num_patches = hf_inputs.get("num_patches", torch.empty(0))

        return dict(
            # pixel_values is flat over all tiles; split it back per image
            # using the tile counts computed in _call_hf_processor.
            pixel_values=MultiModalFieldConfig.flat_from_sizes(
                "image", num_patches),
            images_spatial_crop=MultiModalFieldConfig.batched("image"),
            image_embeds=MultiModalFieldConfig.batched("image"),
        )

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Replace each image token with its full run of image tokens."""
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)

        image_token_id = hf_processor.image_token_id
        assert isinstance(image_token_id, int)

        def get_replacement_deepseek_vl2(item_idx: int):
            images = mm_items.get_items(
                "image", (ImageEmbeddingItems, ImageProcessorItems))

            if isinstance(images, ImageEmbeddingItems):
                num_image_tokens = images.get_feature_size(item_idx)
            else:
                image_size = images.get_image_size(item_idx)

                num_image_tokens = self.info.get_num_image_tokens(
                    image_width=image_size.width,
                    image_height=image_size.height,
                    # Cropping only applies when there are at most 2 images.
                    cropping=len(images) <= 2,
                )
            return [image_token_id] * num_image_tokens

        return [
            PromptReplacement(
                modality="image",
                target=[image_token_id],
                replacement=get_replacement_deepseek_vl2,
            )
        ]

    def _cached_apply_hf_processor(
        self,
        prompt: Union[str, list[int]],
        mm_data_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Mapping[str, object],
        mm_uuids: Optional[MultiModalUUIDDict] = None,
    ) -> tuple[list[int], MultiModalProcessingInfo, bool]:
        # The processor logic is different for len(images) <= 2 vs > 2
        # Since the processing cache assumes that the processor output is
        # invariant of how many images are passed per prompt, we only
        # perform caching for the most common case
        if mm_data_items.get_count("image", strict=False) > 2:
            return self._apply_hf_processor(
                prompt=prompt,
                mm_data_items=mm_data_items,
                hf_processor_mm_kwargs=hf_processor_mm_kwargs,
                tokenization_kwargs=tokenization_kwargs,
                mm_uuids=mm_uuids,
            )

        return super()._cached_apply_hf_processor(
            prompt=prompt,
            mm_data_items=mm_data_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
            tokenization_kwargs=tokenization_kwargs,
            mm_uuids=mm_uuids,
        )

313
314
315
316
317
318

@MULTIMODAL_REGISTRY.register_processor(
    DeepseekVL2MultiModalProcessor,
    info=DeepseekVL2ProcessingInfo,
    dummy_inputs=DeepseekVL2DummyInputsBuilder)
class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
    """DeepSeek-VL2 vision-language model.

    A timm SigLIP ViT encodes image tiles and ``MlpProjector`` maps them into
    the text embedding space. A DeepSeek language model (architecture chosen
    from the text config) then consumes the result.
    """

    merge_by_field_config = True

    # Checkpoints store the text backbone under the ``language.`` prefix.
    hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={
        "language.": "language_model.",
    })

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
        """Return the prompt placeholder for an item of ``modality``."""
        if modality.startswith("image"):
            return "<image>"

        raise ValueError("Only image modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config: DeepseekVLV2Config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config

        self.config = config
        self.multimodal_config = multimodal_config

        self.vision_config = config.vision_config
        self.projector_config = config.projector_config
        self.text_config = config.text_config

        model_config = vllm_config.model_config
        tokenizer = cached_tokenizer_from_config(model_config)
        # The placeholder's id depends on the tokenizer; resolve it here.
        self.image_token_id: int = tokenizer.vocab[_IMAGE_TOKEN]

        self.vision = self._init_vision_module(self.vision_config,
                                               quant_config,
                                               maybe_prefix(prefix, "vision"))

        self.projector = MlpProjector(self.projector_config)
        self.tile_tag = config.tile_tag
        self.global_view_pos = config.global_view_pos

        # special token for image token sequence format
        embed_std = 1 / torch.sqrt(
            torch.tensor(self.projector_config.n_embed, dtype=torch.float32))
        if self.tile_tag == "2D":
            # <|view_seperator|>, <|\n|>
            self.image_newline = nn.Parameter(
                torch.randn(self.projector_config.n_embed) * embed_std)
            # "seperator" is a typo from the original implementation; the
            # attribute name is kept so it matches the checkpoint weight name.
            self.view_seperator = nn.Parameter(
                torch.randn(self.projector_config.n_embed) * embed_std)
        else:
            raise ValueError(
                f"Only 2D tile_tag is supported currently, got: {self.tile_tag}"
            )

        # Pick the text backbone implementation from the text config.
        if self.text_config.topk_method == "noaux_tc":
            architectures = ["DeepseekV3ForCausalLM"]
        elif not self.text_config.use_mla:
            architectures = ["DeepseekForCausalLM"]
        else:
            architectures = ["DeepseekV2ForCausalLM"]

        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=self.text_config,
            prefix=maybe_prefix(prefix, "language"),
            architectures=architectures,
        )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    def _get_parent_and_attr(self, root: torch.nn.Module, dotted_name: str):
        """Return (parent_module, final_attr_name) for a dotted module path."""
        names = dotted_name.split('.')
        parent = root
        for n in names[:-1]:
            parent = getattr(parent, n)
        return parent, names[-1]

    #patch for timm ViT instance to support tensor parallel
    def patch_vit_for_tp(self, vit: torch.nn.Module,
                         quant_config: Optional[QuantizationConfig]):
        """Swap the fc1/fc2 linears of timm ``Mlp`` blocks for TP layers.

        ``fc1`` becomes a column-wise and ``fc2`` a row-wise parallel linear
        via ``replace_linear_class``. Returns the patched ``vit``.

        Raises:
            ImportError: if timm is not installed.
        """
        try:
            import timm
        except ImportError as e:
            raise ImportError("Please install timm") from e

        parallel_styles = {"fc1": "colwise", "fc2": "rowwise"}

        # Snapshot the traversal first: modules are replaced below, and
        # mutating the tree while consuming the live named_modules()
        # generator is fragile.
        for name, module in list(vit.named_modules()):
            if not isinstance(module, nn.Linear):
                continue
            parent, attr_name = self._get_parent_and_attr(vit, name)
            style = parallel_styles.get(attr_name)
            if style is None or not isinstance(parent, timm.layers.Mlp):
                continue
            new_linear = replace_linear_class(module,
                                              style,
                                              quant_config,
                                              prefix=name)
            setattr(parent, attr_name, new_linear)

        return vit

    def _init_vision_module(
        self,
        vision_config: VisionEncoderConfig,
        quant_config: Optional[QuantizationConfig],
        prefix: str = "",
    ) -> nn.Module:
        """Build the timm SigLIP ViT (no pretrained weights).

        The model is created under an fp16 default dtype, TP-patched when the
        tensor-parallel world size is > 1, then cast to the current default
        dtype. ``vision_config`` and ``prefix`` are currently unused.
        """
        # TODO: refactor vision model through timm wrapper from transformers
        try:
            import timm
        except ImportError as e:
            raise ImportError("Please install timm") from e

        with set_default_torch_dtype(torch.float16):
            model = timm.create_model(
                "vit_so400m_patch14_siglip_384.webli",
                pretrained=False,
                num_classes=0,
                dynamic_img_size=True,
                dynamic_img_pad=True,
            )

        if get_tensor_model_parallel_world_size() > 1:
            model = self.patch_vit_for_tp(model, quant_config)

        model = model.to(dtype=torch.get_default_dtype())
        return model

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[DeepseekVL2ImageInputs]:
        """Pop image kwargs and wrap them in the matching input schema.

        Returns None when neither pixel values nor embeddings are given.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        images_spatial_crop = kwargs.pop("images_spatial_crop", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            expected_h = expected_w = self.vision_config.image_size
            return DeepseekVL2ImagePixelInputs(
                type="pixel_values",
                data=pixel_values,
                images_spatial_crop=images_spatial_crop,
                resolve_bindings={
                    "h": expected_h,
                    "w": expected_w,
                })

        if image_embeds is not None:
            return DeepseekVL2VImageEmbeddingInputs(
                type="image_embeds",
                data=image_embeds,
            )

        raise AssertionError("This line should be unreachable.")

    def _pixel_values_to_embedding(
        self,
        pixel_values: torch.Tensor,
        images_spatial_crop: torch.Tensor,
    ) -> list[torch.Tensor]:
        """Encode image tiles and lay them out as per-image token sequences.

        Args:
            pixel_values: All tiles of all images stacked along dim 0. Each
                image contributes one global tile followed by its
                ``num_width_tiles * num_height_tiles`` local tiles.
            images_spatial_crop: One ``(num_width_tiles, num_height_tiles)``
                row per image; a row containing a zero stops processing.

        Returns:
            One ``(num_tokens, D)`` tensor per image. It holds the global
            view and the local tile grid, each with a newline embedding
            appended to every row, joined by the view separator embedding.
            Their order follows ``self.global_view_pos``.
        """
        # [batch_all_tiles, vit_seq_len, c]
        images_feature = self.vision.forward_features(pixel_values)

        # [batch_all_tiles, hw, D]
        images_embeds = self.projector(images_feature)

        _, hw, n_dim = images_embeds.shape
        h = w = math.isqrt(hw)

        # Loop-invariant pieces shared by every image.
        # [D] -> [h, 1, D]
        new_lines_in_global = repeat(self.image_newline, "d -> h 1 d", h=h)
        # [D] -> [1, D]
        view_separator = self.view_seperator[None, :]

        tile_index = 0
        vision_embeddings = []
        # Convert the tile counts to Python ints once. This avoids
        # tensor -> bool conversions per image and hands plain ints to
        # slicing/einops.
        for num_width_tiles, num_height_tiles in images_spatial_crop.tolist():
            if num_width_tiles == 0 or num_height_tiles == 0:
                break
            num_tiles_in_image = num_width_tiles * num_height_tiles

            # [hw, D]
            global_features = images_embeds[tile_index]

            # [num_height_tiles * num_width_tiles, hw, D]
            local_features = images_embeds[tile_index + 1:tile_index + 1 +
                                           num_tiles_in_image]
            tile_index += num_tiles_in_image + 1

            # ----------------- global view add newline -----------------
            # [hw, D] -> [h, w, D] -> [h, w + 1, D] -> [h * (w + 1), D]
            global_features = global_features.view(h, w, n_dim)
            global_features = torch.cat([global_features, new_lines_in_global],
                                        dim=1)
            global_features = global_features.view(-1, n_dim)

            # ----------------- local view add newline -----------------
            # [num_height_tiles * num_width_tiles, h * w, D] ->
            # [num_height_tiles * h, num_width_tiles * w, D]
            local_features = rearrange(local_features,
                                       "(th tw) (h w) d -> (th h) (tw w) d",
                                       th=num_height_tiles,
                                       tw=num_width_tiles,
                                       h=h,
                                       w=w)

            # [D] -> [num_height_tiles * h, 1, D]
            new_lines_in_local = repeat(self.image_newline,
                                        "d -> (th h) 1 d",
                                        th=num_height_tiles,
                                        h=h)

            # [num_height_tiles * h, num_width_tiles * w + 1, D]
            #   --> [(num_height_tiles * h) * (num_width_tiles * w + 1), D]
            local_features = torch.cat([local_features, new_lines_in_local],
                                       dim=1)
            local_features = local_features.view(-1, n_dim)

            # merge global and local tiles
            if self.global_view_pos == "head":
                parts = [global_features, view_separator, local_features]
            else:
                parts = [local_features, view_separator, global_features]

            vision_embeddings.append(torch.cat(parts))
        return vision_embeddings

    def _process_image_input(
            self, image_input: DeepseekVL2ImageInputs) -> list[torch.Tensor]:
        """Turn parsed image input into one embedding tensor per image."""
        if image_input["type"] == "image_embeds":
            image_data = image_input["data"]
            if is_list_of(image_data, torch.Tensor):
                # it's already a list of tensors
                return image_data
            if len(image_data.shape) == 3:
                # 3D tensor
                return list(torch.unbind(image_data, dim=0))
            raise ValueError(
                "We expect batched 2D tensors; "
                "this can be either a list of 2D tensors or a single 3D tensor."
            )

        pixel_values = image_input["data"]
        images_spatial_crop = image_input["images_spatial_crop"]

        return self._pixel_values_to_embedding(
            pixel_values=pixel_values, images_spatial_crop=images_spatial_crop)

    def get_language_model(self) -> torch.nn.Module:
        return self.language_model

    def get_multimodal_embeddings(self,
                                  **kwargs: object) -> MultiModalEmbeddings:
        """Return per-image embeddings, or ``[]`` if no image input."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return []
        vision_embeddings = self._process_image_input(image_input)
        return vision_embeddings

    def forward(self,
                input_ids: torch.Tensor,
                positions: torch.Tensor,
                intermediate_tensors: Optional[IntermediateTensors] = None,
                inputs_embeds: Optional[torch.Tensor] = None,
                **kwargs: object):
        """Run the language model; extra ``kwargs`` are not used here."""
        # Non-first pipeline stages consume intermediate tensors instead.
        if intermediate_tensors is not None:
            inputs_embeds = None

        hidden_states = self.language_model(input_ids,
                                            positions,
                                            intermediate_tensors,
                                            inputs_embeds=inputs_embeds)

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> Optional[torch.Tensor]:
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, remapping ``language.`` prefixes."""
        loader = AutoWeightsLoader(self)
        autoloaded_weights = loader.load_weights(weights,
                                                 mapper=self.hf_to_vllm_mapper)
        return autoloaded_weights