module.py 1.45 KB
Newer Older
wangxj's avatar
wangxj committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.

from transformers import AutoConfig, AutoModel

from megatron.core.transformer.module import MegatronModule


class HuggingFaceModule(MegatronModule):
    """Base Megatron wrapper class for HuggingFace-backed models.

    Subclasses hold the actual HuggingFace model; this base only wires the
    config into MegatronModule and provides the pipeline-parallel hook.
    """

    def __init__(self, config):
        # Forward the config to MegatronModule so standard Megatron plumbing works.
        super().__init__(config=config)

    def set_input_tensor(self, input_tensor):
        """No-op pipeline hook: stash the tensor so Megatron's schedule can call this."""
        self.input_tensor = input_tensor


class AutoHuggingFaceModel(HuggingFaceModule):
    """Generic wrapper that loads any model via HuggingFace ``AutoModel``."""

    def __init__(self, config):
        super().__init__(config)
        # Resolve and download/load weights from the configured checkpoint path.
        name_or_path = config.huggingface_model_name_or_path
        self.model = AutoModel.from_pretrained(name_or_path)

    def forward(self, *args, **kwargs):
        """Delegate the forward pass directly to the wrapped HuggingFace model."""
        return self.model(*args, **kwargs)


def build_hf_model(config):
    """Build the appropriate HuggingFace wrapper model for the given config.

    Inspects the HuggingFace ``model_type`` of the configured checkpoint and
    dispatches to the matching wrapper class. Imports are deferred so that
    only the needed wrapper's module is loaded.

    Raises:
        NotImplementedError: if the checkpoint's model type is unsupported.
    """
    hf_config = AutoConfig.from_pretrained(config.huggingface_model_name_or_path)
    model_type = hf_config.model_type

    if "qwen" in model_type:
        from megatron.core.models.huggingface.qwen_model import QwenHuggingFaceModel

        return QwenHuggingFaceModel(config)

    if "vit" in model_type:
        from megatron.core.models.huggingface.clip_model import ClipHuggingFaceModel

        return ClipHuggingFaceModel(config)

    raise NotImplementedError(f"Huggingface model type {hf_config.model_type} is not supported")