# chatglm2.py: registers small ChatGLM2 models in the model zoo.
import torch

from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig
from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel

from ..registry import ModelAttribute, model_zoo

# ================================
# Register single-sentence ChatGLM
# ================================


def data_gen():
    """Build one fixed single-sentence batch for the ChatGLM models.

    Returns:
        dict: ``input_ids`` — an int64 tensor of shape (1, 8) holding fixed
        token ids — and ``attention_mask`` — an all-ones int64 tensor of the
        same shape, i.e. no position is masked out.
    """
    input_ids = torch.tensor([[5941, 15, 2670, 3543, 632, 2075, 632, 2075]], dtype=torch.int64)
    # Derive the mask from the ids so the two tensors can never drift apart in shape.
    attention_mask = torch.ones_like(input_ids)
    return dict(input_ids=input_ids, attention_mask=attention_mask)


def data_gen_for_conditional_generation():
    """Build the single-sentence batch plus ``labels`` for conditional generation.

    Returns:
        dict: everything ``data_gen()`` returns, with an extra ``labels``
        entry that is an independent copy of ``input_ids``.
    """
    # conditional generation data gen
    # `labels` holds token ids (a clone of `input_ids`), not class indices
    data = data_gen()
    # clone() so that later in-place edits to the labels cannot touch the inputs
    data["labels"] = data["input_ids"].clone()
    return data


# define output transform function
def output_transform_fn(x):
    """Return the model output ``x`` unchanged (identity transform)."""
    return x


# define loss function
def loss_fn_for_chatglm_model(x):
    """Return the MSE between ``x.last_hidden_state`` and an all-ones tensor.

    Args:
        x: model output exposing a ``last_hidden_state`` tensor.

    Returns:
        A scalar tensor: the mean squared error against ones of the same
        shape and dtype as ``x.last_hidden_state``.
    """
    hidden = x.last_hidden_state
    return torch.nn.functional.mse_loss(hidden, torch.ones_like(hidden))


def loss_fn(x):
    """Return the loss already stored on the model output (``x.loss``)."""
    return x.loss
# Small two-layer ChatGLM configuration; both model_zoo registrations in this
# file build their models from it.
config = ChatGLMConfig(
    num_layers=2,
    padded_vocab_size=65024,
    hidden_size=64,
    num_attention_heads=8,
    kv_channels=16,
    rmsnorm=True,
    original_rope=True,
    use_cache=True,
    torch_dtype=torch.float32,
)

# Second small configuration: differs from `config` in hidden_size (128) and in
# enabling multi-query attention with two groups.
# NOTE(review): not referenced by the registrations visible in this file —
# presumably imported elsewhere; confirm before removing.
infer_config = ChatGLMConfig(
    num_layers=2,
    padded_vocab_size=65024,
    hidden_size=128,
    num_attention_heads=8,
    multi_query_attention=True,
    multi_query_group_num=2,
    kv_channels=16,
    rmsnorm=True,
    original_rope=True,
    use_cache=True,
    torch_dtype=torch.float32,
)

# Register the bare ChatGLMModel under "transformers_chatglm".
# model_fn is a zero-argument factory, so a model is only built from `config`
# when the factory is called; the loss is the MSE of last_hidden_state against ones.
model_zoo.register(
    name="transformers_chatglm",
    model_fn=lambda: ChatGLMModel(config, empty_init=False),
    data_gen_fn=data_gen,
    output_transform_fn=output_transform_fn,
    loss_fn=loss_fn_for_chatglm_model,
    model_attribute=ModelAttribute(has_control_flow=True),
)

# Register ChatGLMForConditionalGeneration; its data generator also supplies
# `labels`, and the loss is read directly from the model output (`x.loss`).
model_zoo.register(
    name="transformers_chatglm_for_conditional_generation",
    model_fn=lambda: ChatGLMForConditionalGeneration(config, empty_init=False),
    data_gen_fn=data_gen_for_conditional_generation,
    output_transform_fn=output_transform_fn,
    loss_fn=loss_fn,
    model_attribute=ModelAttribute(has_control_flow=True),
)