import torch
import transformers

from ..registry import ModelAttribute, model_zoo
from .chatglm2_6b.configuration_chatglm import ChatGLMConfig
from .chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel

# ================================
# Register single-sentence ChatGLM
# ================================


def data_gen():
    """Build a single tokenized sentence as ChatGLM model inputs.

    Returns:
        dict with ``input_ids`` (1x6 int64 tensor of token ids) and a
        matching all-ones ``attention_mask``.
    """
    token_ids = [5941, 15, 2670, 3543, 632, 2075]
    input_ids = torch.tensor([token_ids], dtype=torch.int64)
    attention_mask = torch.ones_like(input_ids)
    return dict(input_ids=input_ids, attention_mask=attention_mask)


# define output transform function
def output_transform_fn(x):
    """Identity transform: return the model output unchanged.

    Defined with ``def`` instead of an assigned lambda (PEP 8 E731) so the
    function has a proper name in tracebacks.
    """
    return x

# define loss function
def loss_fn_for_chatglm_model(x):
    """Scalar pseudo-loss for the bare ChatGLMModel output.

    Reduces ``x.last_hidden_state`` to its mean so tests have a single
    differentiable scalar to backprop through. Defined with ``def`` instead
    of an assigned lambda (PEP 8 E731).
    """
    return x.last_hidden_state.mean()
def loss_fn(x):
    """Scalar pseudo-loss for ChatGLMForConditionalGeneration output.

    Reduces ``x.logits`` to its mean, giving a single differentiable scalar.
    Defined with ``def`` instead of an assigned lambda (PEP 8 E731).
    """
    return x.logits.mean()
# Deliberately tiny ChatGLM configuration (1 layer, 64-dim hidden state) so
# registered test models build and run quickly.
config = ChatGLMConfig(
    num_layers=1,
    padded_vocab_size=65024,
    hidden_size=64,
    num_attention_heads=8,
    rmsnorm=False,
    original_rope=True,
    use_cache=True,
)

# Register the bare ChatGLMModel (hidden states only, no LM head).
model_zoo.register(
    name='transformers_chatglm',
    model_fn=lambda: ChatGLMModel(config, empty_init=False),
    data_gen_fn=data_gen,
    output_transform_fn=output_transform_fn,
    loss_fn=loss_fn_for_chatglm_model,
    model_attribute=ModelAttribute(has_control_flow=True),
)

# Register the generation variant (ChatGLMModel plus LM head producing logits).
model_zoo.register(
    name="transformers_chatglm_for_conditional_generation",
    model_fn=lambda: ChatGLMForConditionalGeneration(config, empty_init=False),
    data_gen_fn=data_gen,
    output_transform_fn=output_transform_fn,
    loss_fn=loss_fn,
    model_attribute=ModelAttribute(has_control_flow=True),
)