# Model-zoo registrations for HuggingFace OPT models (base model and task heads).
import torch
import transformers

from ..registry import ModelAttribute, model_zoo

# ===============================
# Register single-sentence OPT
# ===============================
# Nominal batch size and sequence length for this model family.
# NOTE(review): the generators in this file do not read these; data_gen builds a
# fixed 1 x 8 batch. Confirm whether other modules import them before changing.
BATCH_SIZE, SEQ_LENGTH = 2, 16


def data_gen():
    """Build one tokenized sentence shared by every OPT model-zoo entry.

    Returns:
        dict with ``input_ids`` and ``attention_mask``. Both are ``torch.long``
        tensors of shape (1, 8). The mask is all ones because the sample has
        no padding.
    """
    input_ids = torch.tensor([[1, 15043, 29892, 590, 11203, 338, 274, 1082]], dtype=torch.long)
    # Every position is a real token, so attend to all of them.
    attention_mask = torch.ones_like(input_ids)
    return dict(input_ids=input_ids, attention_mask=attention_mask)


def data_gen_for_causal_lm():
    """Extend ``data_gen()`` with ``labels`` for causal language modeling.

    The sample has no padding, so every token is a valid target. ``labels`` is
    therefore a copy of ``input_ids``. It is cloned so that the two tensors do
    not share storage.
    """
    data = data_gen()
    data['labels'] = data['input_ids'].clone()
    return data


def data_gen_for_sequence_classification():
    """Extend ``data_gen()`` with a sequence-classification target.

    Adds ``labels`` as a one-element integer tensor holding class index 1.
    That is one label for the single sequence in the batch.
    """
    data = data_gen()
    data['labels'] = torch.tensor([1])
    return data


def data_gen_for_question_answering():
    """Extend ``data_gen()`` with extractive question-answering span targets.

    Adds ``start_positions`` = [0] and ``end_positions`` = [1]. Together they
    give one answer span (token positions) for the single sequence in the batch.
    """
    data = data_gen()
    data['start_positions'] = torch.tensor([0])
    data['end_positions'] = torch.tensor([1])
    return data

# Output / loss adapters passed to ``model_zoo.register``.


def output_transform_fn(x):
    """Return the model output unchanged (identity transform)."""
    return x


def loss_fn_for_opt_model(x):
    """Dummy scalar loss for the bare ``OPTModel`` output.

    Computes the mean-squared error between ``x.last_hidden_state`` and an
    all-ones tensor of the same shape. The target is arbitrary; it only
    provides a differentiable scalar to backpropagate through.
    """
    return torch.nn.functional.mse_loss(x.last_hidden_state, torch.ones_like(x.last_hidden_state))


def loss_fn_for_lm(x):
    """Return the loss the model computed itself (``x.loss``)."""
    return x.loss
# Small OPT configuration shared by every ``model_zoo.register`` call below.
# dropout=0 is presumably there to make forward passes deterministic across runs.
config = transformers.OPTConfig(
    num_hidden_layers=2,
    num_attention_heads=4,
    hidden_size=128,
    dropout=0,
)
# Register the following models:
#   transformers.OPTModel
#   transformers.OPTForCausalLM
#   transformers.OPTForQuestionAnswering
#   transformers.OPTForSequenceClassification
# Each entry pairs a factory built from the shared ``config`` with the data
# generator and loss function that match that model's head.
model_zoo.register(name='transformers_opt',
                   model_fn=lambda: transformers.OPTModel(config),
                   data_gen_fn=data_gen,
                   output_transform_fn=output_transform_fn,
                   loss_fn=loss_fn_for_opt_model,
                   model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_opt_for_causal_lm',
                   model_fn=lambda: transformers.OPTForCausalLM(config),
                   data_gen_fn=data_gen_for_causal_lm,
                   output_transform_fn=output_transform_fn,
                   loss_fn=loss_fn_for_lm,
                   model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_opt_for_question_answering',
                   model_fn=lambda: transformers.OPTForQuestionAnswering(config),
                   data_gen_fn=data_gen_for_question_answering,
                   output_transform_fn=output_transform_fn,
                   loss_fn=loss_fn_for_lm,
                   model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_opt_for_sequence_classification',
                   model_fn=lambda: transformers.OPTForSequenceClassification(config),
                   data_gen_fn=data_gen_for_sequence_classification,
                   output_transform_fn=output_transform_fn,
                   loss_fn=loss_fn_for_lm,
                   model_attribute=ModelAttribute(has_control_flow=True))