# opt.py: model_zoo registrations for Hugging Face ``transformers`` OPT models.
import torch
import transformers

from ..registry import ModelAttribute, model_zoo

# -------------------------------------------------------------
# OPT registrations (single-sentence inputs)
# -------------------------------------------------------------
BATCH_SIZE: int = 2  # NOTE(review): not referenced in this file; data_gen builds a 1-row batch
SEQ_LENGTH: int = 16  # NOTE(review): not referenced in this file; data_gen emits 8 tokens


def data_gen():
    """Build one fixed, unpadded token sequence as OPT model inputs.

    Returns:
        dict with keys ``input_ids`` and ``attention_mask``. Both are
        ``torch.long`` tensors of shape ``(1, 8)``; the mask is all ones.
    """
    # Fixed token ids for a single 8-token sequence (batch of one).
    input_ids = torch.tensor([[1, 15043, 29892, 590, 11203, 338, 274, 1082]], dtype=torch.long)
    # Every position is a real token (no padding), so nothing is masked out.
    attention_mask = torch.ones_like(input_ids)
    return dict(input_ids=input_ids, attention_mask=attention_mask)


def data_gen_for_causal_lm():
    """Build causal-LM inputs: the ``data_gen()`` sample plus ``labels``.

    The sample has no padding, so ``labels`` is a copy of ``input_ids``.
    It is cloned so the two entries are independent tensors.
    """
    data = data_gen()
    data["labels"] = data["input_ids"].clone()
    return data


def data_gen_for_sequence_classification():
    """Build sequence-classification inputs: the ``data_gen()`` sample plus ``labels``.

    ``labels`` holds a single class index (``1``) for the one sequence in the batch.
    """
    data = data_gen()
    data["labels"] = torch.tensor([1])
    return data


def data_gen_for_question_answering():
    """Build question-answering inputs: the ``data_gen()`` sample plus an answer span.

    Adds ``start_positions`` (``[0]``) and ``end_positions`` (``[1]``), one
    entry for the single sequence in the batch.
    """
    data = data_gen()
    data["start_positions"] = torch.tensor([0])
    data["end_positions"] = torch.tensor([1])
    return data

def output_transform_fn(x):
    """Return the model output unchanged (identity transform)."""
    return x


def loss_fn_for_opt_model(x):
    """Compute a dummy loss for the bare ``OPTModel`` registration.

    The output is scored through ``x.last_hidden_state`` rather than a
    ``loss`` field: MSE between the hidden states and an all-ones tensor of
    the same shape and dtype.
    """
    hidden = x.last_hidden_state
    return torch.nn.functional.mse_loss(hidden, torch.ones_like(hidden))


def loss_fn_for_lm(x):
    """Return the loss the task-head model computed itself (``x.loss``)."""
    return x.loss
# One small OPT configuration used by every registration in this module
# (2 layers, 4 heads, hidden size 128), with dropout switched off.
# NOTE(review): presumably sized to keep the test models fast to build.
config = transformers.OPTConfig(
    num_hidden_layers=2,
    num_attention_heads=4,
    hidden_size=128,
    dropout=0,
)

# Register the following models, all built from the module-level ``config``:
#   - transformers.OPTModel
#   - transformers.OPTForCausalLM
#   - transformers.OPTForQuestionAnswering
# Each entry pairs a model factory with its input generator and loss function.
model_zoo.register(
    name="transformers_opt",
    model_fn=lambda: transformers.OPTModel(config),
    data_gen_fn=data_gen,
    output_transform_fn=output_transform_fn,
    loss_fn=loss_fn_for_opt_model,
    model_attribute=ModelAttribute(has_control_flow=True),
)
model_zoo.register(
    name="transformers_opt_for_causal_lm",
    model_fn=lambda: transformers.OPTForCausalLM(config),
    data_gen_fn=data_gen_for_causal_lm,
    output_transform_fn=output_transform_fn,
    loss_fn=loss_fn_for_lm,
    model_attribute=ModelAttribute(has_control_flow=True),
)
model_zoo.register(
    name="transformers_opt_for_question_answering",
    model_fn=lambda: transformers.OPTForQuestionAnswering(config),
    data_gen_fn=data_gen_for_question_answering,
    output_transform_fn=output_transform_fn,
    loss_fn=loss_fn_for_lm,
    model_attribute=ModelAttribute(has_control_flow=True),
)

# TODO: The loss and gradient checks for OPTForSequenceClassification fail in the tests; re-enable this registration once they are fixed.
# model_zoo.register(name='transformers_opt_for_sequence_classification',
#                    model_fn=lambda: transformers.OPTForSequenceClassification(config),
#                    data_gen_fn=data_gen_for_sequence_classification,
#                    output_transform_fn=output_transform_fn,
#                    loss_fn=loss_fn_for_lm,
#                    model_attribute=ModelAttribute(has_control_flow=True))