import os
GPU_NUMBER = [0]
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(s) for s in GPU_NUMBER])
os.environ["NCCL_DEBUG"] = "INFO"

# imports
from collections import Counter
import datetime
import pickle
from datasets import load_from_disk
from sklearn.metrics import accuracy_score, f1_score
from transformers import BertForSequenceClassification
from transformers import Trainer
from transformers.training_args import TrainingArguments

from geneformer import DataCollatorForCellClassification


# load cell type dataset (includes all tissues)
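# The loaded dataset is expected to carry at least "cell_type" and "organ_major"
# columns (both used below), plus the tokenized "input_ids" and "length" columns
# that DataCollatorForCellClassification and the length-grouped batching rely on.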
train_dataset = load_from_disk("/path/to/Genecorpus-30M/example_input_files/cell_classification/cell_type_annotation/cell_type_train_data.dataset/")

dataset_list = []
evalset_list = []
organ_list = []
target_dict_list = []

for organ in Counter(train_dataset["organ_major"]).keys():
    # collect list of tissues for fine-tuning (immune and bone marrow are included together)
    if organ in ["bone_marrow"]:  # skipped here; bone marrow is folded into "immune" below
        continue
    elif organ == "immune":
        organ_ids = ["immune", "bone_marrow"]
        organ_list += ["immune"]
    else:
        organ_ids = [organ]
        organ_list += [organ]

    print(organ)

    # filter datasets for given organ
    def if_organ(example):
        return example["organ_major"] in organ_ids
    trainset_organ = train_dataset.filter(if_organ, num_proc=16)


    # per scDeepsort published method, drop cell types representing <0.5% of cells
    celltype_counter = Counter(trainset_organ["cell_type"])
    total_cells = sum(celltype_counter.values())
    cells_to_keep = [k for k, v in celltype_counter.items() if v > (0.005 * total_cells)]
    def if_not_rare_celltype(example):
        return example["cell_type"] in cells_to_keep
    trainset_organ_subset = trainset_organ.filter(if_not_rare_celltype, num_proc=16)

    # shuffle datasets and rename columns
    trainset_organ_shuffled = trainset_organ_subset.shuffle(seed=42)
    trainset_organ_shuffled = trainset_organ_shuffled.rename_column("cell_type", "label")
    trainset_organ_shuffled = trainset_organ_shuffled.remove_columns("organ_major")

    # create dictionary of cell types : label ids
    target_names = list(Counter(trainset_organ_shuffled["label"]).keys())
    target_name_id_dict = dict(zip(target_names, range(len(target_names))))
    target_dict_list += [target_name_id_dict]
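    # e.g. {"B cell": 0, "T cell": 1, ...} (illustrative names only; the real
    # keys are whatever cell_type values survived the rarity filter above)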

    # change labels to numerical ids
    def classes_to_ids(example):
        example["label"] = target_name_id_dict[example["label"]]
        return example
    labeled_trainset = trainset_organ_shuffled.map(classes_to_ids, num_proc=16)
    
    # create 80/20 train/eval splits
    labeled_train_split = labeled_trainset.select(range(0, round(len(labeled_trainset) * 0.8)))
    labeled_eval_split = labeled_trainset.select(range(round(len(labeled_trainset) * 0.8), len(labeled_trainset)))
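    # (datasets' built-in labeled_trainset.train_test_split(test_size=0.2) would
    # be an equivalent alternative; the manual select() above keeps the original
    # approach of slicing the already-shuffled dataset)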

    # filter eval split to the cell types present in the corresponding training split
    trained_labels = list(Counter(labeled_train_split["label"]).keys())
    def if_trained_label(example):
        return example["label"] in trained_labels
    labeled_eval_split_subset = labeled_eval_split.filter(if_trained_label, num_proc=16)

    dataset_list += [labeled_train_split]
    evalset_list += [labeled_eval_split_subset]

trainset_dict = dict(zip(organ_list, dataset_list))
traintargetdict_dict = dict(zip(organ_list, target_dict_list))

evalset_dict = dict(zip(organ_list, evalset_list))


def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    # calculate accuracy and macro f1 using sklearn's function
    acc = accuracy_score(labels, preds)
    macro_f1 = f1_score(labels, preds, average='macro')
    return {
        'accuracy': acc,
        'macro_f1': macro_f1
    }
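# Macro F1 averages per-class F1 with equal weight, so rare cell types count as
# much as common ones; accuracy alone can look strong while rare classes are
# systematically misclassified.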


# set model parameters
# max input size
max_input_size = 2 ** 11  # 2048

# set training hyperparameters
# max learning rate
max_lr = 5e-5
# how many pretrained layers to freeze
freeze_layers = 0
# number gpus
num_gpus = 1
# number cpu cores
num_proc = 16
# batch size for training and eval
geneformer_batch_size = 12
# learning schedule
lr_schedule_fn = "linear"
# warmup steps
warmup_steps = 500
# number of epochs
epochs = 10
# optimizer
optimizer = "adamw"
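# NB: num_gpus is not used anywhere below, and freeze_layers is only recorded in
# the output directory name; see the commented sketch after model loading for
# one way to actually freeze encoder layers.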


for organ in organ_list:
    print(organ)
    organ_trainset = trainset_dict[organ]
    organ_evalset = evalset_dict[organ]
    organ_label_dict = traintargetdict_dict[organ]

    # set logging steps (roughly ten log events per epoch; floored at 1 so the
    # Trainer never receives logging_steps=0)
    logging_steps = max(1, round(len(organ_trainset) / geneformer_batch_size / 10))

    # reload pretrained model ("/path/to/Geneformer" is a placeholder for the
    # pretrained Geneformer checkpoint directory)
    model = BertForSequenceClassification.from_pretrained(
        "/path/to/Geneformer",
        num_labels=len(organ_label_dict.keys()),
        output_attentions=False,
        output_hidden_states=False,
    ).to("cuda")
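    # A minimal sketch for actually applying freeze_layers, assuming the
    # standard BertForSequenceClassification layout (model.bert.encoder.layer):
    # if freeze_layers > 0:
    #     for layer_module in model.bert.encoder.layer[:freeze_layers]:
    #         for param in layer_module.parameters():
    #             param.requires_grad = False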
    # define output directory path
    current_date = datetime.datetime.now()
    datestamp = f"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}"
    output_dir = f"/path/to/models/{datestamp}_geneformer_CellClassifier_{organ}_L{max_input_size}_B{geneformer_batch_size}_LR{max_lr}_LS{lr_schedule_fn}_WU{warmup_steps}_E{epochs}_O{optimizer}_F{freeze_layers}/"

    # ensure not overwriting previously saved model
    saved_model_test = os.path.join(output_dir, "pytorch_model.bin")
    if os.path.isfile(saved_model_test):
        raise Exception("Model already saved to this directory.")

    # make output directory
    os.makedirs(output_dir, exist_ok=True)

    # set training arguments
    training_args = {
        "learning_rate": max_lr,
        "do_train": True,
        "do_eval": True,
        "evaluation_strategy": "epoch",
        "save_strategy": "epoch",
        "logging_steps": logging_steps,
        "group_by_length": True,
        "length_column_name": "length",
        "disable_tqdm": False,
        "lr_scheduler_type": lr_schedule_fn,
        "warmup_steps": warmup_steps,
        "weight_decay": 0.001,
        "per_device_train_batch_size": geneformer_batch_size,
        "per_device_eval_batch_size": geneformer_batch_size,
        "num_train_epochs": epochs,
        "load_best_model_at_end": True,
        "output_dir": output_dir,
    }
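    # load_best_model_at_end requires matching evaluation and save strategies
    # (both "epoch" here); with no metric_for_best_model set, the Trainer picks
    # the checkpoint with the lowest eval loss.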

    training_args_init = TrainingArguments(**training_args)

    # create the trainer
    trainer = Trainer(
        model=model,
        args=training_args_init,
        data_collator=DataCollatorForCellClassification(),
        train_dataset=organ_trainset,
        eval_dataset=organ_evalset,
        compute_metrics=compute_metrics
    )
    # train the cell type classifier
    trainer.train()
    predictions = trainer.predict(organ_evalset)
    with open(f"{output_dir}predictions.pickle", "wb") as fp:
        pickle.dump(predictions, fp)
    trainer.save_metrics("eval", predictions.metrics)
    trainer.save_model(output_dir)
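
# To inspect a finished run afterwards, a minimal sketch (variable names as in
# the loop above; the pickle path mirrors the dump during training):
# with open(f"{output_dir}predictions.pickle", "rb") as fp:
#     predictions = pickle.load(fp)
# id_to_celltype = {v: k for k, v in organ_label_dict.items()}
# predicted = [id_to_celltype[i] for i in predictions.predictions.argmax(-1)]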