chatglm2.py 1.94 KB
Newer Older
Kun Lin's avatar
Kun Lin committed
1
2
import torch

Jianghai's avatar
Jianghai committed
3
4
5
from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig
from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel

Kun Lin's avatar
Kun Lin committed
6
from ..registry import ModelAttribute, model_zoo
Jianghai's avatar
Jianghai committed
7

Kun Lin's avatar
Kun Lin committed
8
9
10
11
12
13
# ================================
# Register single-sentence ChatGLM
# ================================


def data_gen():
    """Build one fixed, unpadded tokenized sentence for the ChatGLM tests.

    Returns:
        dict with two entries:
          - ``input_ids``: int64 tensor of shape ``(1, 8)`` with fixed token ids.
          - ``attention_mask``: tensor of the same shape and dtype, all ones,
            so every token is attended to (no padding).
    """
    input_ids = torch.tensor([[5941, 15, 2670, 3543, 632, 2075, 632, 2075]], dtype=torch.int64)
    # Derive the mask from input_ids so shape and dtype can never drift apart.
    attention_mask = torch.ones_like(input_ids)
    return dict(input_ids=input_ids, attention_mask=attention_mask)


def data_gen_for_conditional_generation():
    """Build inputs for ``ChatGLMForConditionalGeneration``, including labels.

    Starts from :func:`data_gen` and adds ``labels`` as a copy of
    ``input_ids`` (token ids, not class indices). This is the usual
    causal-LM setup; the model presumably shifts labels internally when it
    computes ``loss`` -- confirm in modeling_chatglm.

    Returns:
        dict with ``input_ids``, ``attention_mask`` and ``labels``.
    """
    data = data_gen()
    # Clone so labels do not alias input_ids: an in-place edit of one
    # (e.g. masking label positions) must not leak into the other.
    data["labels"] = data["input_ids"].clone()
    return data


# define output transform function
def output_transform_fn(x):
    """Return the model output unchanged (identity transform)."""
    return x


# define loss functions
def loss_fn_for_chatglm_model(x):
    """Synthetic scalar loss for the base ``ChatGLMModel``.

    The base model is registered with ``data_gen`` (no labels), so the loss is
    the MSE between ``x.last_hidden_state`` and an all-ones tensor of the same
    shape, giving tests something to backpropagate through.
    """
    hidden = x.last_hidden_state
    return torch.nn.functional.mse_loss(hidden, torch.ones_like(hidden))


def loss_fn(x):
    """Return the loss the model computed itself (read from ``x.loss``)."""
    return x.loss

# Small 2-layer / 64-wide config so the registered models are cheap to build.
# padded_vocab_size stays large enough for the token ids used in data_gen
# (max id 5941); presumably it mirrors the real ChatGLM2 vocab -- confirm.
config = ChatGLMConfig(
    num_layers=2,
    padded_vocab_size=65024,
    hidden_size=64,
    num_attention_heads=8,
    rmsnorm=True,
    original_rope=True,
    use_cache=True,
    torch_dtype=torch.float32,
)

# Base model: fed unlabelled data and trained against the synthetic MSE loss.
# NOTE(review): empty_init=False presumably forces real weight initialisation
# instead of skipping it -- confirm in modeling_chatglm.
model_zoo.register(
    name="transformers_chatglm",
    model_fn=lambda: ChatGLMModel(config, empty_init=False),
    data_gen_fn=data_gen,
    output_transform_fn=output_transform_fn,
    loss_fn=loss_fn_for_chatglm_model,
    model_attribute=ModelAttribute(has_control_flow=True),
)

# LM head model: fed labels (copies of input_ids) and uses the model's own loss.
model_zoo.register(
    name="transformers_chatglm_for_conditional_generation",
    model_fn=lambda: ChatGLMForConditionalGeneration(config, empty_init=False),
    data_gen_fn=data_gen_for_conditional_generation,
    output_transform_fn=output_transform_fn,
    loss_fn=loss_fn,
    model_attribute=ModelAttribute(has_control_flow=True),
)