humming.py 36.6 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
import math
from collections.abc import Callable
from typing import TYPE_CHECKING, Any

import regex as re
import torch

from vllm import envs
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.config import (
    FusedMoEConfig,
    FusedMoEQuantConfig,
    FusedMoEQuantDesc,
)
from vllm.model_executor.layers.fused_moe.layer import (
    FusedMoE,
    FusedMoEMethodBase,
)
from vllm.model_executor.layers.fused_moe.unquantized_fused_moe_method import (
    UnquantizedFusedMoEMethod,
)
from vllm.model_executor.layers.linear import (
    LinearBase,
    LinearMethodBase,
    UnquantizedLinearMethod,
)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig,
    QuantizeMethodBase,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.parameter import (
    BasevLLMParameter,
    BlockQuantScaleParameter,
    ChannelQuantScaleParameter,
    GroupQuantScaleParameter,
    ModelWeightParameter,
    PackedvLLMParameter,
    PerTensorScaleParameter,
    RowvLLMParameter,
)
from vllm.model_executor.utils import set_weight_attrs

if TYPE_CHECKING:
    from vllm.model_executor.models.utils import WeightsMapper


try:
    from humming.dtypes import DataType
    from humming.layer import HummingMethod
    from humming.schema import (
        BaseInputSchema,
        BaseWeightSchema,
        HummingInputSchema,
        HummingWeightSchema,
    )
    from humming.utils.weight import quantize_weight

    from vllm.model_executor.layers.fused_moe.fused_humming_moe import (
        BatchedHummingGroupedExperts,
        HummingGroupedExperts,
        HummingIndexedExperts,
        get_humming_moe_gemm_type,
    )
except ModuleNotFoundError:
    HummingMethod = None


def assert_humming_available():
    """Fail fast if the optional ``humming`` package could not be imported.

    ``HummingMethod`` is set to ``None`` by the guarded import at module load
    when ``humming`` is missing; in that case an ``AssertionError`` with
    installation instructions is raised.
    """
    humming_missing = HummingMethod is None
    assert not humming_missing, (
        "humming is not available, please run "
        "'pip install git+https://github.com/inclusionAI/humming' to install it."
    )


def prepare_padded_shape(shape, x):
    """Round ``shape`` up to the nearest multiple of ``x``.

    Args:
        shape: integer size to pad.
        x: integer alignment; must be non-zero.

    Returns:
        ``(padded_shape, padding)``, where ``padded_shape`` is the smallest
        multiple of ``x`` that is ``>= shape`` and ``padding`` is
        ``padded_shape - shape``.
    """
    # Integer ceil-division. ``math.ceil(shape / x)`` goes through a float and
    # can round the wrong way once values exceed 2**53.
    padded_shape = -(-shape // x) * x
    return padded_shape, padded_shape - shape


def prepare_param(tensor, name, extra_attrs):
    """Wrap ``tensor`` in the vLLM parameter class implied by ``extra_attrs``.

    Class selection, in priority order:
      1. ``packed_dim`` present          -> ``PackedvLLMParameter``
      2. ``scale_type`` is a known kind  -> the matching scale parameter class
      3. both ``input_dim``/``output_dim`` -> ``ModelWeightParameter``
      4. only ``input_dim``              -> ``RowvLLMParameter``
      5. only ``output_dim``             -> ``ChannelQuantScaleParameter``
      6. otherwise                       -> ``BasevLLMParameter``

    Constructor keywords (dims, packing info, ``weight_loader``) are taken
    from ``extra_attrs``. All other entries, except ``scale_type``, are
    attached with ``set_weight_attrs``. ``extra_attrs`` itself is not
    modified.

    Returns:
        The new parameter, with ``param_name`` set to ``name`` and
        ``ignore_warning`` set to True.
    """
    attrs = extra_attrs.copy()
    scale_type = attrs.pop("scale_type", None)
    scale_param_classes = {
        "block": BlockQuantScaleParameter,
        "tensor": PerTensorScaleParameter,
        "group": GroupQuantScaleParameter,
        "channel": ChannelQuantScaleParameter,
        "input_scale": PerTensorScaleParameter,
    }
    has_input_dim = "input_dim" in attrs
    has_output_dim = "output_dim" in attrs

    param_cls: type[BasevLLMParameter]
    if "packed_dim" in attrs:
        param_cls = PackedvLLMParameter
    elif scale_type in scale_param_classes:
        param_cls = scale_param_classes[scale_type]
    elif has_input_dim and has_output_dim:
        param_cls = ModelWeightParameter
    elif has_input_dim:
        param_cls = RowvLLMParameter
    elif has_output_dim:
        # Output-dim-only tensors reuse the per-channel scale parameter class.
        param_cls = ChannelQuantScaleParameter
    else:
        param_cls = BasevLLMParameter

    # Keys consumed by the parameter constructor. Everything left in ``attrs``
    # becomes a plain attribute via set_weight_attrs.
    ctor_keys = (
        "input_dim",
        "output_dim",
        "packed_dim",
        "packed_factor",
        "weight_loader",
    )
    ctor_kwargs = {key: attrs.pop(key) for key in ctor_keys if key in attrs}

    param = param_cls(data=tensor, **ctor_kwargs)
    set_weight_attrs(param, attrs)

    param.param_name = name
    param.ignore_warning = True
    if scale_type in ("tensor", "input_scale"):
        param.needs_scalar_to_array = True
    return param


def prepare_moe_param(tensor, name, extra_attrs):
    """Wrap ``tensor`` in a frozen ``torch.nn.Parameter`` for a FusedMoE layer.

    Args:
        tensor: parameter data.
        name: stored on the parameter as ``param_name``.
        extra_attrs: attributes to attach with ``set_weight_attrs``. The dict
            is NOT modified. A copy is augmented with:
              * ``quant_method``, mirroring ``scale_type``, when present;
              * ``is_transposed = input_dim < output_dim``, when both dims
                are given.

    Returns:
        The new parameter (``requires_grad=False``).
    """
    # Work on a copy so keys added here cannot leak back into a dict the
    # caller may reuse for other parameters.
    attrs = dict(extra_attrs)
    param = torch.nn.Parameter(tensor, requires_grad=False)
    if "scale_type" in attrs:
        attrs["quant_method"] = attrs["scale_type"]

    if "input_dim" in attrs and "output_dim" in attrs:
        attrs["is_transposed"] = attrs["input_dim"] < attrs["output_dim"]

    set_weight_attrs(param, attrs)
    param.param_name = name
    return param


def may_pad_loaded_weight(param, loaded_weight):
    pad_shape = getattr(param, "pad_shape", None)
    if pad_shape is None:
        return loaded_weight
    value = 1 if loaded_weight.dtype == torch.float8_e8m0fnu else 0
    padding = []
    for x in pad_shape[::-1][: loaded_weight.ndim]:
        padding += [0, x]
    loaded_weight = torch.nn.functional.pad(
        input=loaded_weight,
        pad=padding,
        value=value,
    )
    return loaded_weight


def compressed_tensors_get_config(config: dict[str, Any], key: str):
    """Extract the Linear-targeted ``weights``/``input_activations`` section.

    Only the FIRST entry of ``config["config_groups"]`` whose ``targets``
    contains ``"Linear"`` is considered. Returns None if no such group
    exists, if that group has no ``"weights"`` entry, or if the requested
    ``key`` is missing or null. Otherwise returns a shallow copy of the
    section, tagged with the top-level ``quant_method``, plus ``format``
    (compressed-tensors) or ``quant_algo`` (modelopt). ``config`` is not
    modified.
    """
    assert key in ("weights", "input_activations")
    linear_group = next(
        (
            group
            for group in config["config_groups"].values()
            if "Linear" in group["targets"]
        ),
        None,
    )
    if linear_group is None or "weights" not in linear_group:
        return None

    section = linear_group.get(key)
    if section is None:
        return None

    result = section.copy()
    quant_method = config["quant_method"]
    result["quant_method"] = quant_method
    if quant_method == "compressed-tensors":
        result["format"] = config["format"]
    elif quant_method == "modelopt":
        result["quant_algo"] = config["quant_algo"]
    return result


class HummingConfig(QuantizationConfig):
    """vLLM quantization config that serves supported layers with humming kernels.

    ``full_config`` holds the checkpoint's quantization config. When it is
    empty, weights are quantized online using
    ``envs.VLLM_HUMMING_ONLINE_QUANT_CONFIG``. When that env config sets
    ``force_requant``, checkpoint weights are additionally re-quantized to
    it. Per-layer settings are resolved by ``get_quant_config_for_layer``.
    """

    packed_modules_mapping = {}

    def __init__(self, full_config: dict[str, Any] | None = None):
        assert_humming_available()
        # Run the base-class initialiser so its instance state is set up.
        # Otherwise lookups such as ``packed_modules_mapping`` resolve to the
        # class-level dict above, which is shared by every instance.
        super().__init__()
        self.full_config: dict[str, Any] = full_config or {}

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        return "humming"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        return 75

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        return []

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "HummingConfig":
        return cls(full_config=config)

    @classmethod
    def override_quantization_method(
        cls, hf_quant_cfg, user_quant, hf_config=None
    ) -> QuantizationMethods | None:
        """Select humming only when the user explicitly asked for it."""
        return "humming" if user_quant == "humming" else None

    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
        """Remember the HF->vLLM name mapper used to translate ignore lists."""
        self.hf_to_vllm_mapper = hf_to_vllm_mapper

    def is_layer_skipped(self, config: dict[str, Any], prefix: str):
        """Return True if the layer at ``prefix`` must stay unquantized.

        A layer is skipped when:
          * any entry of ``ignored_layers``/``ignore``/
            ``modules_to_not_convert`` (mapped through ``hf_to_vllm_mapper``
            if set) is a substring of ``prefix``;
          * ``prefix`` contains ``lm_head``;
          * a ``config["dynamic"]`` key starting with ``-`` has a remainder
            (from the third character on) that ``re.match``es ``prefix``.
        """
        keys = ["ignored_layers", "ignore", "modules_to_not_convert"]
        ignored_layers = self.get_from_keys_or(config, keys, []) or []
        if hasattr(self, "hf_to_vllm_mapper"):
            ignored_layers = self.hf_to_vllm_mapper.apply_list(ignored_layers)

        if any(module_name in prefix for module_name in ignored_layers):
            return True
        if "lm_head" in prefix:
            return True

        # ``or {}`` also covers an explicit ``"dynamic": null`` entry.
        for regex in config.get("dynamic") or {}:
            if regex[:1] != "-":
                continue
            if re.match(regex[2:], prefix):
                return True

        return False

    def get_layer_weight_schema(self, config: dict[str, Any], prefix: str):
        """Return the weight schema for ``prefix`` from ``config``, or None.

        compressed-tensors/modelopt configs are first reduced to their
        Linear-targeted ``weights`` group. ``dynamic`` keys starting with
        ``+`` whose remainder matches ``prefix`` override the base settings
        (first match wins).
        """
        if self.is_layer_skipped(config, prefix):
            return None

        if config["quant_method"] in ["compressed-tensors", "modelopt"]:
            group_config = compressed_tensors_get_config(config, "weights")
            if group_config is None:
                return None
            config = group_config

        layer_config = config
        layer_dynamic = config.get("dynamic", {})
        if not isinstance(layer_dynamic, dict):
            layer_dynamic = {}
        for regex, override_config in layer_dynamic.items():
            if regex[:1] != "+":
                continue
            if re.match(regex[2:], prefix):
                layer_config = config.copy()
                layer_config.update(override_config)
                break

        if "quant_method" in layer_config:
            return BaseWeightSchema.from_config(layer_config)
        return None

    def get_layer_input_schema(self, config: dict[str, Any], prefix: str):
        """Return the activation (input) schema for ``prefix``, or None."""
        if self.is_layer_skipped(config, prefix):
            return None
        if config["quant_method"] in ["compressed-tensors", "modelopt"]:
            group_config = compressed_tensors_get_config(config, "input_activations")
            if group_config is None:
                return None
            config = group_config

        if config.get("quant_method", None) in BaseInputSchema.INPUT_SCHEMA_MAP:
            return BaseInputSchema.from_config(config)
        return None

    def get_quant_config_for_layer(
        self, prefix: str, layer_type: str
    ) -> "HummingLayerQuantizationConfig | None":
        """Resolve the weight/input schemas for the layer at ``prefix``.

        Args:
            prefix: module name of the layer.
            layer_type: ``"moe"``, ``"linear"`` or ``"other"``.

        Returns:
            A ``HummingLayerQuantizationConfig``, or None when the layer
            should not use humming: no weight schema resolved, or a
            ``gpt_oss_mxfp4`` schema on a non-MoE layer.
        """
        weight_schema: BaseWeightSchema | None = None
        force_weight_schema: HummingWeightSchema | None = None

        if self.full_config:
            weight_schema = self.get_layer_weight_schema(self.full_config, prefix)

        is_online_quant = False
        # Copy before mutating so the env-provided dict is never modified
        # (same as the VLLM_HUMMING_INPUT_QUANT_CONFIG handling below).
        online_quant_config = dict(envs.VLLM_HUMMING_ONLINE_QUANT_CONFIG or {})
        if not self.full_config or online_quant_config.get("force_requant", False):
            online_quant_config["quant_method"] = "humming"
            schema = self.get_layer_weight_schema(online_quant_config, prefix)
            if not self.full_config:
                # Unquantized checkpoint: quantize online to ``schema``.
                weight_schema = schema
                is_online_quant = True
            else:
                # Quantized checkpoint: re-quantize to ``schema`` after load.
                force_weight_schema = schema

        if weight_schema is not None:
            if weight_schema.quant_method == "gpt_oss_mxfp4" and layer_type != "moe":
                return None
            input_schema = None
            force_input_schema = None

            if self.full_config:
                input_schema = self.get_layer_input_schema(self.full_config, prefix)

            if envs.VLLM_HUMMING_INPUT_QUANT_CONFIG:
                quant_config = envs.VLLM_HUMMING_INPUT_QUANT_CONFIG.copy()
                quant_config["quant_method"] = "humming"
                force_input_schema = self.get_layer_input_schema(quant_config, prefix)
                if input_schema is None:
                    input_schema = force_input_schema

            if force_weight_schema is not None and force_input_schema is None:
                force_input_schema = HummingInputSchema()

            return HummingLayerQuantizationConfig(
                weight_schema=weight_schema,
                input_schema=input_schema,
                force_weight_schema=force_weight_schema,
                force_input_schema=force_input_schema,
                is_online_quant=is_online_quant,
            )
        return None

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> "QuantizeMethodBase | None":
        """Pick the quantize method for ``layer``.

        FusedMoE/LinearBase layers get humming methods when a layer config
        resolves, and unquantized methods otherwise. Any other layer type
        gets None.
        """
        layer_type = "other"
        if isinstance(layer, FusedMoE):
            layer_type = "moe"
        elif isinstance(layer, LinearBase):
            layer_type = "linear"

        # TODO: remove this after humming moe backend is ready
        # mxfp4 checkpoints with a SWIGLUOAI MoE activation (presumably
        # gpt-oss) are routed to the "gpt_oss_mxfp4" weight schema. This
        # rewrites full_config, so every layer resolved afterwards sees it.
        # (The key was previously misspelled "quan_method", so the override
        # never took effect.)
        quant_method = self.full_config.get("quant_method", None)
        moe_activation = getattr(layer, "activation", None)
        if quant_method == "mxfp4" and moe_activation == MoEActivation.SWIGLUOAI:
            self.full_config["quant_method"] = "gpt_oss_mxfp4"

        quant_config = self.get_quant_config_for_layer(prefix, layer_type)
        if quant_config is None:
            if isinstance(layer, FusedMoE):
                return UnquantizedFusedMoEMethod(layer.moe_config)
            elif isinstance(layer, LinearBase):
                return UnquantizedLinearMethod()
        elif isinstance(layer, LinearBase):
            return HummingLinearMethod(quant_config)
        elif isinstance(layer, FusedMoE):
            return HummingMoEMethod(quant_config, layer.moe_config)
        return None


class HummingLayerQuantizationConfig(HummingConfig):
    """Resolved humming quantization settings for a single layer.

    Built by ``HummingConfig.get_quant_config_for_layer`` and handed to
    ``HummingLinearMethod`` / ``HummingMoEMethod``. ``HummingConfig.__init__``
    is not called, so instances carry no ``full_config``.

    Attributes:
        weight_schema: schema of the weights as loaded.
        input_schema: activation schema. Defaults to ``HummingInputSchema()``.
        force_weight_schema: optional target schema for re-quantization.
        force_input_schema: optional activation schema that overrides
            ``input_schema`` once the weights have been created.
        is_online_quant: whether float weights are quantized while loading.
    """

    def __init__(
        self,
        weight_schema: "BaseWeightSchema",
        input_schema: "BaseInputSchema | None" = None,
        force_weight_schema: "HummingWeightSchema | None" = None,
        force_input_schema: "HummingInputSchema | None" = None,
        is_online_quant: bool = False,
    ):
        self.weight_schema = weight_schema
        self.input_schema = (
            HummingInputSchema() if input_schema is None else input_schema
        )
        self.force_weight_schema = force_weight_schema
        self.force_input_schema = force_input_schema
        self.is_online_quant = is_online_quant

    @classmethod
    def from_config(cls, config):
        """Build a layer config whose weight schema is parsed from ``config``."""
        return cls(BaseWeightSchema.from_config(config))

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> QuantizeMethodBase | None:
        """Always raises ``NotImplementedError``; not usable on layer configs."""
        raise NotImplementedError


class HummingLinearMethod(LinearMethodBase):
    """Linear-layer quantize method that runs layers through humming kernels.

    Lifecycle:
      * ``create_weights`` registers checkpoint-format parameters described by
        the weight/input schemas and installs a wrapping weight loader;
      * the loader quantizes float weights online, or turns the layer into an
        unquantized fallback when a float ``weight`` arrives without online
        quant enabled;
      * ``process_weights_after_loading`` converts to humming format,
        optionally re-quantizes to ``force_weight_schema``, and prepares the
        kernel metadata;
      * ``apply`` calls ``HummingMethod.forward_layer``.
    """

    def __init__(self, quant_config: HummingLayerQuantizationConfig):
        self.quant_config = quant_config
        self.weight_schema = quant_config.weight_schema
        self.input_schema = quant_config.input_schema
        self.force_weight_schema = quant_config.force_weight_schema
        self.force_input_schema = quant_config.force_input_schema
        self.is_online_quant = self.quant_config.is_online_quant

    def prepare_weight_loader(self, layer: torch.nn.Module, weight_loader: Callable):
        """Return a loader that wraps ``weight_loader`` with humming handling.

        Args:
            layer: the linear layer whose parameters are being loaded.
            weight_loader: the layer's original loader, used for tensors that
                are already in the checkpoint's quantized format.

        Returns:
            ``new_weight_loader(param, loaded_weight, shard_id=None)``.
        """

        def new_weight_loader(
            param: torch.nn.Parameter,
            loaded_weight: torch.Tensor,
            shard_id: str | int | None = None,
        ):
            name = param.param_name
            float_dtypes = [torch.float16, torch.bfloat16, torch.float32]
            is_unquantized = name == "weight" and loaded_weight.dtype in float_dtypes
            if is_unquantized and self.is_online_quant:
                # online quant (fp16/bf16 -> quant_type)
                assert isinstance(self.weight_schema, HummingWeightSchema)
                f16_dtype = DataType.from_torch_dtype(layer.param_dtype)
                has_global_scale = "TENSOR" in str(self.weight_schema.weight_scale_type)
                tensor_list = quantize_weight(
                    weight=loaded_weight,
                    dtype=self.weight_schema.b_dtype,
                    scale_dtype=self.weight_schema.bs_dtype or f16_dtype,
                    group_size=self.weight_schema.weight_scale_group_size,
                    has_zero_point=self.weight_schema.has_zero_point,
                    has_global_scale=has_global_scale,
                    is_fp_zero_point=self.weight_schema.is_fp_zero_point,
                    pack=True,
                )

                key_list = ["weight", "weight_scale", "zero_point", "global_scale"]
                for key, tensor in zip(key_list, tensor_list):
                    if tensor is None or tensor.nelement() == 0:
                        continue
                    param = getattr(layer, key)
                    # The parameter's loader was installed by create_weights,
                    # so this presumably re-enters new_weight_loader with the
                    # already-quantized tensor.
                    param.weight_loader(param, tensor, shard_id)

                return None
            elif is_unquantized and not self.is_online_quant:
                # fallback to unquantized linear
                # some model skip some layer when quantizing model, but
                # don't mark the layer as unquantized.
                if not layer.is_fallback:
                    layer.is_fallback = True
                    for param_name, _ in list(layer.named_parameters()):
                        if param_name != "bias":
                            delattr(layer, param_name)
                    delattr(layer, "locks")
                    # Swap this method object's class so that later hooks
                    # (process_weights_after_loading/apply) run the
                    # unquantized implementation.
                    self.__class__ = UnquantizedLinearMethod  # type: ignore
                    tensor = torch.empty(
                        (
                            layer.output_partition_sizes_sum,
                            layer.input_size_per_partition,
                        ),
                        dtype=layer.param_dtype,
                        device=param.device,
                    )
                    extra_weight_attrs = layer.extra_weight_attrs.copy()
                    orig_weight_loader = extra_weight_attrs.pop("weight_loader")
                    layer.weight = ModelWeightParameter(
                        data=tensor,
                        input_dim=1,
                        output_dim=0,
                        weight_loader=orig_weight_loader,
                    )
                    layer.weight.tp_size = layer.tp_size
                    layer.weight.tp_rank = layer.tp_rank
                    set_weight_attrs(layer.weight, extra_weight_attrs)

                param = layer.weight
                if shard_id is not None:
                    return layer.weight.weight_loader(param, loaded_weight, shard_id)
                return layer.weight.weight_loader(param, loaded_weight)

            # weight processing logic for specific quantization schema
            loaded_weight = self.weight_schema.process_loaded_weight(
                tensor=loaded_weight,
                name=name,
            )
            if shard_id is not None:
                return weight_loader(param, loaded_weight, shard_id)
            return weight_loader(param, loaded_weight)

        return new_weight_loader

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register checkpoint-format parameters on ``layer``.

        Stores the shape/dtype bookkeeping used later by the loader and
        ``process_weights_after_loading``. Creates one parameter per tensor
        described by the weight and input schemas, plus a ``locks`` int32
        buffer of 1024 zeros. A ``force_input_schema`` replaces
        ``input_schema`` only after the input tensors have been created.
        """
        layer.is_fallback = False
        layer.param_dtype = params_dtype
        layer.input_size = input_size
        layer.output_size = output_size
        layer.input_size_per_partition = input_size_per_partition
        layer.output_partition_sizes_sum = sum(output_partition_sizes)
        layer.output_partition_sizes = output_partition_sizes
        layer.extra_weight_attrs = extra_weight_attrs.copy()

        weight_loader = extra_weight_attrs.get("weight_loader", default_weight_loader)
        new_weight_loader = self.prepare_weight_loader(layer, weight_loader)
        extra_weight_attrs["weight_loader"] = new_weight_loader

        for key in ["weight_block_size", "block_structure"]:
            block_size = getattr(self.weight_schema, key, None)
            if block_size is not None:
                layer.weight_block_size = block_size

        weight_tensor_attrs = self.weight_schema.get_tensors_attrs(
            shape_n=layer.output_partition_sizes_sum,
            shape_k=layer.input_size_per_partition,
            param_dtype=params_dtype,
            stack_size=len(layer.output_partition_sizes),
        )

        input_tensor_attrs = self.input_schema.get_tensors_attrs(
            shape_k=layer.input_size_per_partition,
            param_dtype=params_dtype,
            stack_size=len(layer.output_partition_sizes),
        )

        tensors_attrs = weight_tensor_attrs | input_tensor_attrs

        for name, attrs in tensors_attrs.items():
            tensor = torch.empty(attrs["shape"], dtype=attrs["dtype"])
            extra_attrs = attrs.get("extra_attrs", {}).copy()
            extra_attrs.update(extra_weight_attrs)
            param = prepare_param(tensor, name, extra_attrs)
            setattr(layer, name, param)

        locks = torch.zeros(1024, dtype=torch.int32)
        layer.register_buffer("locks", locks)

        if self.force_input_schema is not None:
            self.input_schema = self.force_input_schema

        # Ensure a ``weight`` parameter always exists (a scalar placeholder
        # when the schema defines none), so the wrapping loader is reachable
        # via ``layer.weight``.
        if not hasattr(layer, "weight"):
            param = prepare_param(torch.tensor(0), "weight", extra_weight_attrs)
            layer.weight = param

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Convert loaded parameters into humming's runtime layout.

        Steps:
          1. If the weight schema is not yet a ``HummingWeightSchema``,
             convert checkpoint tensors with ``convert_humming``.
          2. If ``force_weight_schema`` is set and differs, re-quantize to it.
          3. Run ``HummingMethod.prepare_layer_meta`` and
             ``transform_humming_layer``.
          4. Serialize the kernel ``compute_config`` used by ``apply``.

        Returns immediately for layers marked ``is_fallback``.
        """
        if layer.is_fallback:
            return None

        # convert from checkpoint format to humming format
        if not isinstance(self.weight_schema, HummingWeightSchema):
            self.weight_schema, tensors = self.weight_schema.convert_humming(
                tensors=layer.state_dict(),
                shape_n_stacks=layer.output_partition_sizes,
                shape_k_stacks=[layer.input_size_per_partition],
                param_dtype=layer.param_dtype,
            )

            self.input_schema, _ = self.input_schema.convert_humming(
                tensors=layer.state_dict(),
                shape_n_stacks=layer.output_partition_sizes,
                shape_k_stacks=[layer.input_size_per_partition],
                param_dtype=layer.param_dtype,
            )

            # Keep ``bias`` (as the requant and fallback paths do). It is
            # still replaced below if ``convert_humming`` returns a "bias"
            # entry, but is no longer silently dropped if it does not.
            for name, _ in list(layer.named_parameters()):
                if name != "bias":
                    delattr(layer, name)

            for name, tensor in tensors.items():
                param = torch.nn.Parameter(tensor, requires_grad=False)
                setattr(layer, name, param)

            del tensors

        # force requant (origin quant setting -> fp16/bf16 -> new_quant setting)
        assert isinstance(self.weight_schema, HummingWeightSchema)
        force_requant = self.force_weight_schema is not None
        if force_requant and self.weight_schema != self.force_weight_schema:
            tensors = self.weight_schema.requant_tensors(
                tensors=layer.state_dict(),
                target_weight_schema=self.force_weight_schema,
                param_dtype=layer.param_dtype,
            )

            self.weight_schema = self.force_weight_schema

            for name, _ in list(layer.named_parameters()):
                if name != "bias":
                    delattr(layer, name)

            for name, tensor in tensors.items():
                param = torch.nn.Parameter(tensor, requires_grad=False)
                setattr(layer, name, param)

            del tensors

        # prepare layer config from humming kernel
        HummingMethod.prepare_layer_meta(
            layer=layer,
            shape_n=layer.output_partition_sizes_sum,
            shape_k=layer.input_size_per_partition,
            weight_schema=self.weight_schema,
            input_schema=self.input_schema,
            pad_n_to_multiple=256,
            pad_k_to_multiple=128,
            has_bias=layer.has_bias,
            torch_dtype=layer.param_dtype,
        )

        # preprocess weight for inference
        HummingMethod.transform_humming_layer(layer)

        # compute_config: kernel configs that do not directly affect weights
        # but significantly impact kernel behavior or computation precision.
        # see https://github.com/inclusionAI/humming/blob/main/docs/config.md
        compute_config = {
            "use_batch_invariant": envs.VLLM_BATCH_INVARIANT,
            "use_f16_accum": envs.VLLM_HUMMING_USE_F16_ACCUM,
            "gemm_type": "dense",
        }
        self.compute_config = json.dumps(compute_config)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the humming GEMM on ``x`` and restore its leading dimensions.

        ``x`` is flattened to 2-D over its last dimension. The output keeps
        ``x.shape[:-1]`` with the kernel's output width as the last dim.
        The ``bias`` argument is not used here; bias handling is configured
        via ``has_bias=layer.has_bias`` in ``process_weights_after_loading``.
        NOTE(review): confirm behavior when the caller passes ``bias=None``
        to request skip_bias_add.
        """
        # reshape (not view) so non-contiguous inputs do not raise; it is a
        # no-copy view whenever ``x`` is already contiguous.
        flatten_inputs = x.reshape(-1, x.size(-1))
        output = HummingMethod.forward_layer(
            layer=layer,
            inputs=flatten_inputs,
            compute_config=self.compute_config,
        )
        output = output.view(*x.shape[:-1], output.size(-1))
        return output


class HummingMoEMethod(FusedMoEMethodBase):
    def __init__(
        self, quant_config: HummingLayerQuantizationConfig, moe: "FusedMoEConfig"
    ) -> None:
        """Store the quantization config and cache its schemas on the method.

        Args:
            quant_config: humming quantization config for this layer. Its
                weight/input schemas and optional ``force_*`` schemas are
                copied onto ``self``.
            moe: fused-MoE layer config, forwarded to the base class and kept
                as ``self.moe``.
        """
        super().__init__(moe)
        self.quant_config = quant_config
        self.moe = moe
        config = quant_config
        # Checkpoint schemas. When not None, the ``force_*`` variants replace
        # them later (input schema in create_weights, weight schema via
        # requantization in process_weights_after_loading).
        self.weight_schema = config.weight_schema
        self.input_schema = config.input_schema
        self.force_weight_schema = config.force_weight_schema
        self.force_input_schema = config.force_input_schema

    def prepare_weight_loader(self, layer, weight_loader):
        """Wrap ``weight_loader`` with humming-specific weight handling.

        The returned loader accepts the same arguments as ``weight_loader`` and
        dispatches on the target parameter's ``param_name``:

        * Online quantization: if the target is ``weight`` and the checkpoint
          tensor is fp16/bf16/fp32, the tensor is quantized with
          ``quantize_weight`` into ``self.weight_schema``. Every non-empty
          result (weight, weight_scale, zero_point, global_scale) is then
          loaded into the matching ``w13_*`` / ``w2_*`` parameter of ``layer``
          through that parameter's own ``weight_loader``. This path returns
          the combined success flag when ``return_success`` is set, else
          ``None``.
        * Otherwise the tensor goes through
          ``self.weight_schema.process_loaded_weight``, and the wrapped
          ``weight_loader``'s return value is passed through.
        """
        # Invariant across calls: hoisted out of the per-call closure body.
        float_dtypes = (torch.float16, torch.bfloat16, torch.float32)
        # Order matches the tensors returned by quantize_weight.
        quantized_keys = ("weight", "weight_scale", "zero_point", "global_scale")

        def new_weight_loader(
            param: torch.nn.Parameter,
            loaded_weight: torch.Tensor,
            weight_name: str,
            shard_id: str,
            expert_id: int | None = None,
            return_success: bool = False,
        ):
            name = param.param_name
            is_unquantized = name == "weight" and loaded_weight.dtype in float_dtypes
            # online quant (fp16/bf16/fp32 -> quant_type)
            if is_unquantized:
                assert isinstance(self.weight_schema, HummingWeightSchema)
                f16_dtype = DataType.from_torch_dtype(layer.param_dtype)
                has_global_scale = "TENSOR" in str(self.weight_schema.weight_scale_type)
                tensor_list = quantize_weight(
                    weight=loaded_weight,
                    dtype=self.weight_schema.b_dtype,
                    scale_dtype=self.weight_schema.bs_dtype or f16_dtype,
                    group_size=self.weight_schema.weight_scale_group_size,
                    has_zero_point=self.weight_schema.has_zero_point,
                    has_global_scale=has_global_scale,
                    is_fp_zero_point=self.weight_schema.is_fp_zero_point,
                    pack=True,
                )

                # Every shard other than w2 belongs to the fused w13 sublayer.
                sublayer_name = "w2" if shard_id == "w2" else "w13"
                success = True
                for key, tensor in zip(quantized_keys, tensor_list):
                    # Skip parts that are missing or empty for this schema.
                    if tensor is None or tensor.nelement() == 0:
                        continue
                    target = getattr(layer, f"{sublayer_name}_{key}")
                    part_success = target.weight_loader(
                        param=target,
                        loaded_weight=tensor.cpu(),
                        weight_name=f"{shard_id}_{key}",
                        shard_id=shard_id,
                        expert_id=expert_id,
                        return_success=return_success,
                    )
                    success = success and part_success

                return success if return_success else None

            # weight processing logic for specific quantization schema
            loaded_weight = self.weight_schema.process_loaded_weight(
                tensor=loaded_weight,
                name=name,
            )
            return weight_loader(
                param,
                loaded_weight,
                weight_name,
                shard_id=shard_id,
                expert_id=expert_id,
                return_success=return_success,
            )

        return new_weight_loader

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Allocate the humming parameters of both MoE sublayers on ``layer``.

        Side effects on ``layer``:

        * sets ``num_experts``, ``param_dtype`` and ``intermediate_size``;
        * stores per-sublayer shapes and padded tensor attributes in
          ``layer.sublayer_configs``;
        * registers one parameter per tensor as ``{sublayer}_{name}``, each
          built by ``prepare_moe_param`` with the schema's extra attrs plus
          ``extra_weight_attrs`` (whose ``weight_loader`` is replaced by the
          wrapper from ``prepare_weight_loader``);
        * registers a zero-initialised int32 ``locks`` buffer of 1024 entries.

        Also replaces ``self.input_schema`` with ``force_input_schema`` when
        one is configured.
        """
        layer.num_experts = num_experts
        layer.param_dtype = params_dtype
        layer.intermediate_size = intermediate_size_per_partition
        weight_loader = extra_weight_attrs.get("weight_loader", default_weight_loader)
        weight_loader = self.prepare_weight_loader(layer, weight_loader)
        extra_weight_attrs["weight_loader"] = weight_loader

        # (shape_n, shape_k) of each sublayer GEMM. w13 stacks two projections
        # along n, hence the factor of 2.
        sublayer_shapes = {
            "w13": (intermediate_size_per_partition * 2, hidden_size),
            "w2": (hidden_size, intermediate_size_per_partition),
        }

        # sublayer: a layer contains multiple sets of weights for quantized GEMM
        # (e.g., weight, weight_scale, etc.).
        # The weight names of sublayer start with the prefix "{sublayer_name}_"
        layer.sublayer_configs = {
            sublayer_name: {
                "shape_n": shape_n,
                "shape_k": shape_k,
                "tensors_attrs": self.weight_schema.get_padded_tensors_attrs(
                    shape_n=shape_n,
                    shape_k=shape_k,
                    num_experts=num_experts,
                    param_dtype=params_dtype,
                    has_bias=self.moe.has_bias,
                ),
            }
            for sublayer_name, (shape_n, shape_k) in sublayer_shapes.items()
        }

        for sublayer_name, configs in layer.sublayer_configs.items():
            for name, attrs in configs["tensors_attrs"].items():
                tensor = torch.empty(attrs["shape"], dtype=attrs["dtype"])
                # Schema-provided attrs first; caller-supplied attrs win on conflict.
                extra_attrs = {**attrs.get("extra_attrs", {}), **extra_weight_attrs}
                param = prepare_moe_param(tensor, name, extra_attrs)
                setattr(layer, f"{sublayer_name}_{name}", param)

        if self.force_input_schema is not None:
            self.input_schema = self.force_input_schema

        locks = torch.zeros(1024, dtype=torch.int32)
        layer.register_buffer("locks", locks)

    def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> FusedMoEQuantConfig:
        """Describe the humming quantization of ``layer`` as a FusedMoEQuantConfig.

        Weights are processed first. Later calls are no-ops because
        ``process_weights_after_loading`` guards itself with ``processed``.
        This guarantees that the per-sublayer schemas and the final
        ``w13_*`` / ``w2_*`` tensors exist.

        Each descriptor is built from the schema of the GEMM it belongs to:
        ``_a1`` / ``_w1`` from the "w13" sublayer and ``_a2`` / ``_w2`` from
        the "w2" sublayer. The previous version built all four from w13.
        """
        self.process_weights_after_loading(layer)

        def activation_desc(sublayer_name: str) -> FusedMoEQuantDesc:
            # Unspecified or 16-bit activation dtypes get a dtype=None descriptor.
            a_dtype = self.input_schemas[sublayer_name].a_dtype
            if a_dtype is None or a_dtype.num_bits == 16:
                return FusedMoEQuantDesc(dtype=None)
            shape = GroupShape(row=1, col=-1)
            return FusedMoEQuantDesc(dtype=str(a_dtype), shape=shape)

        def weight_desc(sublayer_name: str) -> FusedMoEQuantDesc:
            schema = self.weight_schemas[sublayer_name]
            group_size = schema.weight_scale_group_size
            group_size_n = schema.weight_scale_group_size_n
            if group_size_n > 1:
                group_shape = GroupShape(row=group_size, col=group_size_n)
            elif group_size == 0:
                # a group size of 0 is encoded as row=-1
                group_shape = GroupShape(row=-1, col=1)
            else:
                group_shape = GroupShape(row=group_size, col=1)
            return FusedMoEQuantDesc(
                dtype=str(schema.b_dtype),
                shape=group_shape,
                scale=getattr(layer, f"{sublayer_name}_weight_scale", None),
                alpha_or_gscale=getattr(layer, f"{sublayer_name}_global_scale", None),
                zp=getattr(layer, f"{sublayer_name}_zero_point", None),
                bias=getattr(layer, f"{sublayer_name}_bias", None),
            )

        return FusedMoEQuantConfig(
            _a1=activation_desc("w13"),
            _a2=activation_desc("w2"),
            _w1=weight_desc("w13"),
            _w2=weight_desc("w2"),
        )

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Turn the loaded parameters of every MoE sublayer into kernel-ready form.

        For each sublayer in ``layer.sublayer_configs``:

        1. If ``self.weight_schema`` is not a ``HummingWeightSchema``, convert
           the checkpoint tensors (and the input schema) to humming format and
           replace all of the sublayer's parameters with the converted tensors.
        2. If ``force_weight_schema`` is set and differs from the current
           schema, requantize the tensors into it. The sublayer's ``bias``
           parameter is not deleted.
        3. Record the resulting schemas in ``self.weight_schemas`` /
           ``self.input_schemas``. This is done for every sublayer and after
           step 2, so the dicts always reflect how the tensors are actually
           stored. ``HummingMethod`` then builds the layer metadata and
           transforms the weights.

        Finally ``self.experts`` is created according to
        ``get_humming_moe_gemm_type()``.

        ``self.processed`` is set before any work, so any later call
        (e.g. from ``get_fused_moe_quant_config``) returns immediately.
        """
        if getattr(self, "processed", False):
            return
        self.processed = True
        self.weight_schemas = {}
        self.input_schemas = {}
        for sublayer_name, configs in layer.sublayer_configs.items():
            prefix = f"{sublayer_name}_"
            input_schema = self.input_schema
            weight_schema = self.weight_schema
            # convert from checkpoint format to humming format
            if not isinstance(weight_schema, HummingWeightSchema):
                tensors: dict[str, torch.Tensor] = {
                    key.removeprefix(prefix): value
                    for key, value in layer.state_dict().items()
                    if key.startswith(prefix)
                }

                shape_k_stacks = [configs["shape_k"]]
                shape_n_stacks = [configs["shape_n"]]
                if sublayer_name == "w13":
                    # w13 holds two equally sized projections stacked along n
                    shape_n_stacks = [configs["shape_n"] // 2] * 2

                weight_schema, tensors = weight_schema.convert_humming(
                    tensors=tensors,
                    shape_n_stacks=shape_n_stacks,
                    shape_k_stacks=shape_k_stacks,
                    param_dtype=layer.param_dtype,
                    num_experts=layer.num_experts,
                )

                input_schema, _ = input_schema.convert_humming(
                    tensors=tensors,
                    shape_n_stacks=shape_n_stacks,
                    shape_k_stacks=shape_k_stacks,
                    param_dtype=layer.param_dtype,
                    num_experts=layer.num_experts,
                )

                for name, _ in list(layer.named_parameters()):
                    if name.startswith(prefix):
                        delattr(layer, name)

                for name, tensor in tensors.items():
                    param = torch.nn.Parameter(tensor, requires_grad=False)
                    setattr(layer, prefix + name, param)

            # force requant (origin quant setting -> fp16/bf16 -> new_quant setting)
            assert isinstance(weight_schema, HummingWeightSchema)
            force_requant = self.force_weight_schema is not None
            if force_requant and weight_schema != self.force_weight_schema:
                tensors = {
                    key.removeprefix(prefix): value
                    for key, value in layer.state_dict().items()
                    if key.startswith(prefix)
                }

                tensors = weight_schema.requant_tensors(
                    tensors=tensors,
                    target_weight_schema=self.force_weight_schema,
                    param_dtype=layer.param_dtype,
                )

                weight_schema = self.force_weight_schema

                # delete the sublayer's parameters except its bias
                for name, _ in list(layer.named_parameters()):
                    if name.startswith(prefix) and name != prefix + "bias":
                        delattr(layer, name)

                for name, tensor in tensors.items():
                    param = torch.nn.Parameter(tensor, requires_grad=False)
                    setattr(layer, prefix + name, param)

                del tensors

            # Record the final schemas for EVERY sublayer (including checkpoints
            # already in humming format) and only after any forced requant;
            # get_fused_moe_quant_config reads these dicts.
            self.weight_schemas[sublayer_name] = weight_schema
            self.input_schemas[sublayer_name] = input_schema

            # prepare layer config from humming kernel
            HummingMethod.prepare_layer_meta(
                layer=layer,
                shape_n=configs["shape_n"],
                shape_k=configs["shape_k"],
                pad_n_to_multiple=256,
                pad_k_to_multiple=128,
                input_schema=input_schema,
                weight_schema=weight_schema,
                has_bias=self.moe.has_bias,
                num_experts=layer.num_experts,
                torch_dtype=layer.param_dtype,
                sublayer_name=sublayer_name,
            )

            # preprocess weight for inference
            HummingMethod.transform_humming_layer(layer, sublayer_name=sublayer_name)

        # use moe modular
        experts: HummingIndexedExperts | HummingGroupedExperts
        if get_humming_moe_gemm_type() == "indexed":
            experts = HummingIndexedExperts(layer, self)
        else:
            experts = HummingGroupedExperts(layer, self)
        self.experts = experts

    def select_gemm_impl(
        self,
        prepare_finalize,
        layer: torch.nn.Module,
    ):
        """Build the humming experts implementation matching ``prepare_finalize``.

        A batched activation format selects ``BatchedHummingGroupedExperts``.
        Otherwise ``get_humming_moe_gemm_type()`` decides: "indexed" selects
        ``HummingIndexedExperts`` and any other value selects
        ``HummingGroupedExperts``. The chosen class is constructed as
        ``cls(layer, self, prepare_finalize)``.
        """
        from vllm.model_executor.layers.fused_moe import modular_kernel as mk

        fmt = prepare_finalize.activation_format
        if fmt == mk.FusedMoEActivationFormat.BatchedExperts:
            experts_cls = BatchedHummingGroupedExperts
        elif get_humming_moe_gemm_type() == "indexed":
            experts_cls = HummingIndexedExperts
        else:
            experts_cls = HummingGroupedExperts
        return experts_cls(layer, self, prepare_finalize)

    def apply(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        shared_experts_input: torch.Tensor | None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run the routed experts for ``x`` through ``self.experts``.

        Args:
            layer: the FusedMoE layer; only ``layer.activation`` is read here.
            x: hidden states passed to ``main_apply``.
            topk_weights: routing weights passed to ``main_apply``.
            topk_ids: routing ids; ``size(0)`` and ``size(1)`` size the
                workspaces (``M`` and ``topk``).
            shared_experts_input: accepted for interface compatibility; unused.

        Returns:
            The ``output`` tensor from ``make_workspaces``. It is asserted to
            start at the same address as ``workspace1``, which is the buffer
            ``main_apply`` receives.
        """
        num_rows = topk_ids.size(0)
        num_topk = topk_ids.size(1)
        ws1, ws2, result = self.experts.make_workspaces(
            M=num_rows,
            topk=num_topk,
            activation=layer.activation,
        )

        # main_apply is given workspace1, not `result`. Returning `result` relies
        # on the two aliasing the same memory (presumably main_apply writes its
        # output into workspace1).
        assert workspace_aliases_output(ws1, result) if False else ws1.data_ptr() == result.data_ptr()

        self.experts.main_apply(
            hidden_states=x,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            workspace1=ws1,
            workspace2=ws2,
            expert_tokens_meta=None,
        )
        return result