"...tests/snapshots/openai_completions__valid_samples-2.snap" did not exist on "ffc6dde1f0c6a45ac2ed72e91139949992c9c55d"
chatglm2.py 2.25 KB
Newer Older
Kun Lin's avatar
Kun Lin committed
1
2
3
import torch
import transformers

Jianghai's avatar
Jianghai committed
4
5
6
from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig
from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel

Kun Lin's avatar
Kun Lin committed
7
from ..registry import ModelAttribute, model_zoo
Jianghai's avatar
Jianghai committed
8

Kun Lin's avatar
Kun Lin committed
9
10
11
12
13
14
15
16
17
18
19
# ================================
# Register single-sentence ChatGLM
# ================================


def data_gen():
    """Build a single tokenized sentence for the ChatGLM model-zoo entries.

    Returns:
        dict: ``input_ids`` (int64 tensor of shape (1, 6)) and an all-ones
        ``attention_mask`` with the same shape, so every token is attended.
    """
    token_ids = [5941, 15, 2670, 3543, 632, 2075]
    batch = {
        'input_ids': torch.tensor([token_ids], dtype=torch.int64),
        'attention_mask': torch.tensor([[1] * len(token_ids)]),
    }
    return batch


20
21
22
23
24
25
26
27
28
def data_gen_for_conditional_generation():
    """Build inputs for ``ChatGLMForConditionalGeneration``.

    Takes the batch from ``data_gen`` and adds ``labels``. The labels are a
    copy of ``input_ids`` (token ids, not class indices), so the model is asked
    to reproduce its own input and can return a loss.

    Returns:
        dict: ``input_ids``, ``attention_mask`` and ``labels``.
    """
    data = data_gen()
    # Labels are the input token ids themselves. Presumably the model shifts
    # them internally for next-token prediction; confirm in modeling_chatglm.
    # ``clone`` gives labels their own storage, so they never alias input_ids.
    data['labels'] = data['input_ids'].clone()
    return data


Kun Lin's avatar
Kun Lin committed
29
30
31
32
def output_transform_fn(x):
    """Return the model output unchanged (identity output transform)."""
    return x


def loss_fn_for_chatglm_model(x):
    """Synthetic loss for the bare ``ChatGLMModel``, which is trained without labels.

    Computes the mean-squared error between ``x.last_hidden_state`` and an
    all-ones tensor of the same shape and dtype. The result is a scalar because
    ``mse_loss`` uses mean reduction by default.
    """
    hidden = x.last_hidden_state
    return torch.nn.functional.mse_loss(hidden, torch.ones_like(hidden))


def loss_fn(x):
    """Return the loss the model computed itself (``x.loss``)."""
    return x.loss
Jianghai's avatar
Jianghai committed
36

37
# Small ChatGLM configuration shared by both registrations below: 2 layers and
# a 64-dim hidden state (presumably to keep model-zoo tests cheap). The padded
# vocabulary (65024) comfortably covers the token ids produced by data_gen.
config = ChatGLMConfig(num_layers=2,
                       padded_vocab_size=65024,
                       hidden_size=64,
                       num_attention_heads=8,
                       rmsnorm=True,
                       original_rope=True,
                       use_cache=True,
                       torch_dtype=torch.float32)

Kun Lin's avatar
Kun Lin committed
46
47
48
49
50
51
# Register the plain ChatGLMModel. Its output is scored through
# ``last_hidden_state`` by loss_fn_for_chatglm_model, so no labels are needed.
model_zoo.register(
    name="transformers_chatglm",
    model_fn=lambda: ChatGLMModel(config, empty_init=False),
    data_gen_fn=data_gen,
    output_transform_fn=output_transform_fn,
    loss_fn=loss_fn_for_chatglm_model,
    model_attribute=ModelAttribute(has_control_flow=True),
)
52
53
54

# Register ChatGLMForConditionalGeneration. data_gen_for_conditional_generation
# adds ``labels``, and loss_fn reads ``.loss`` from the model output, which
# presumably is only populated when labels are passed; confirm in modeling_chatglm.
model_zoo.register(name="transformers_chatglm_for_conditional_generation",
                   model_fn=lambda: ChatGLMForConditionalGeneration(config, empty_init=False),
                   data_gen_fn=data_gen_for_conditional_generation,
                   output_transform_fn=output_transform_fn,
                   loss_fn=loss_fn,
                   model_attribute=ModelAttribute(has_control_flow=True))