Commit 9a5601d8 authored by wangsen

change train.py

parent de076fe9
@@ -11,8 +11,8 @@ sudo apt-get install git-lfs
```
#git clone https://hf-mirror.com/datasets/ctheodoris/Genecorpus-30M
mkdir cell_type_train_data.dataset
cd cell_type_train_data.dataset
wget https://hf-mirror.com/datasets/ctheodoris/Genecorpus-30M/resolve/main/example_input_files/cell_classification/cell_type_annotation/cell_type_train_data.dataset/dataset.arrow
wget https://hf-mirror.com/datasets/ctheodoris/Genecorpus-30M/resolve/main/example_input_files/cell_classification/cell_type_annotation/cell_type_train_data.dataset/dataset_info.json
wget https://hf-mirror.com/datasets/ctheodoris/Genecorpus-30M/resolve/main/example_input_files/cell_classification/cell_type_annotation/cell_type_train_data.dataset/state.json
@@ -36,15 +36,109 @@ conda create -n geneformer python=3.10
conda activate geneformer
pip install torch  # the DCU build of torch
pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
```
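
A quick sanity check (a minimal sketch; it assumes the dataset was downloaded into `cell_type_train_data.dataset` as shown above) confirms that the DCU torch build and the dataset both load:
```
import torch
from datasets import load_from_disk

# confirm which torch build was installed (should report the dtk/DCU build string)
print(torch.__version__)

# load the downloaded cell type annotation dataset and inspect its size and columns
train_dataset = load_from_disk("cell_type_train_data.dataset")
print(train_dataset)
```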
## Environment After Deployment
```
accelerate 0.33.0
accumulation_tree 0.6.2
aiohappyeyeballs 2.3.6
aiohttp 3.10.3
aiosignal 1.3.1
anndata 0.10.8
array_api_compat 1.8
async-timeout 4.0.3
attrs 24.2.0
certifi 2024.7.4
charset-normalizer 3.3.2
click 8.1.7
cloudpickle 3.0.0
contourpy 1.2.1
cycler 0.12.1
datasets 2.21.0
dill 0.3.8
exceptiongroup 1.2.2
filelock 3.15.4
fonttools 4.53.1
frozenlist 1.4.1
fsspec 2024.6.1
future 1.0.0
geneformer 0.1.0
h5py 3.11.0
huggingface-hub 0.24.5
hyperopt 0.2.7
idna 3.7
Jinja2 3.1.4
joblib 1.4.2
jsonschema 4.23.0
jsonschema-specifications 2023.12.1
kiwisolver 1.4.5
legacy-api-wrap 1.4
llvmlite 0.43.0
loompy 3.0.7
MarkupSafe 2.1.5
matplotlib 3.9.2
mpmath 1.3.0
msgpack 1.0.8
multidict 6.0.5
multiprocess 0.70.16
natsort 8.4.0
networkx 3.3
numba 0.60.0
numpy 1.26.4
numpy-groupies 0.11.2
packaging 24.1
pandas 2.2.2
patsy 0.5.6
pillow 10.4.0
pip 24.2
protobuf 5.27.3
psutil 6.0.0
py4j 0.10.9.7
pyarrow 17.0.0
pynndescent 0.5.13
pyparsing 3.1.2
python-dateutil 2.9.0.post0
pytz 2024.1
pyudorandom 1.0.0
PyYAML 6.0.2
ray 2.34.0
referencing 0.35.1
regex 2024.7.24
requests 2.32.3
rpds-py 0.20.0
safetensors 0.4.4
scanpy 1.10.2
scikit-learn 1.5.1
scipy 1.14.0
seaborn 0.13.2
session_info 1.0.0
setuptools 72.1.0
six 1.16.0
statsmodels 0.14.2
stdlib-list 0.10.0
sympy 1.13.2
tdigest 0.5.2.2
threadpoolctl 3.5.0
tokenizers 0.19.1
torch 2.1.0+git540102b.abi0.dtk2404
tqdm 4.66.5
transformers 4.44.0
typing_extensions 4.12.2
tzdata 2024.1
umap-learn 0.5.6
urllib3 2.2.2
wheel 0.43.0
xxhash 3.4.1
yarl 1.9.4
```
# Model Training
```
# Run on a single GPU
python geneformer/classifier.py \
    --classifier="cell" \
    --cell_state_dict='{"state_key": "disease", "states": "all"}' \
    --filter_data=filter_data_dict \
@@ -56,15 +150,20 @@ python classifier.py \
    --nproc=1
# For details, see Geneformer/examples/cell_classification.ipynb
```
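
The flags above map onto the constructor arguments of the `Classifier` class in the geneformer package. A minimal sketch of the equivalent Python call (argument values borrowed from the full example at the end of this README):
```
from geneformer import Classifier

# example filter taken from the disease-classification example below
filter_data_dict = {"cell_type": ["Cardiomyocyte1", "Cardiomyocyte2", "Cardiomyocyte3"]}

cc = Classifier(classifier="cell",
                cell_state_dict={"state_key": "disease", "states": "all"},
                filter_data=filter_data_dict,
                max_ncells=None,
                freeze_layers=2,
                num_crossval_splits=1,
                forward_batch_size=200,
                nproc=1)
```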
Alternatively, instead of the `classifier.py` invocation, run:
```
python test_cell_classifier.py  # replace the dataset path inside the .py file first
```
# Model Inference
```
python geneformer/classifier.py --classifier="cell" --cell_state_dict='{"state_key": "disease", "states": "all"}' --forward_batch_size=200 --nproc=1
# Running this directly will raise an error; for the working flow see Geneformer/examples/cell_classification.ipynb
```
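
For reference, inference in the notebook amounts to a `Trainer.predict` call on a tokenized, labeled dataset. A minimal sketch (the checkpoint and dataset paths are placeholders):
```
from datasets import load_from_disk
from transformers import BertForSequenceClassification, Trainer
from transformers.training_args import TrainingArguments
from geneformer import DataCollatorForCellClassification

# placeholders: a fine-tuned checkpoint and a tokenized eval dataset
model = BertForSequenceClassification.from_pretrained("/path/to/fine_tuned_model").to("cuda")
eval_dataset = load_from_disk("/path/to/eval.dataset")

trainer = Trainer(model=model,
                  args=TrainingArguments(output_dir="/tmp/eval", per_device_eval_batch_size=200),
                  data_collator=DataCollatorForCellClassification())
predictions = trainer.predict(eval_dataset)
print(predictions.predictions.argmax(-1)[:10])  # predicted label ids
```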
For details, refer to https://gitee.com/hf-models/Geneformer/blob/main/examples/cell_classification.ipynb. The script below reproduces that example:
```
import os
GPU_NUMBER = [0]
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(s) for s in GPU_NUMBER])
os.environ["NCCL_DEBUG"] = "INFO"
# imports
from collections import Counter
import datetime
import pickle
import subprocess
import seaborn as sns; sns.set()
from datasets import load_from_disk
from sklearn.metrics import accuracy_score, f1_score
from transformers import BertForSequenceClassification
from transformers import Trainer
from transformers.training_args import TrainingArguments
from geneformer import DataCollatorForCellClassification
# load cell type dataset (includes all tissues)
train_dataset = load_from_disk("/genecorpus_30M_2048.dataset")  # change the dataset path
dataset_list = []
evalset_list = []
organ_list = []
target_dict_list = []
for organ in Counter(train_dataset["organ_major"]).keys():
    # collect list of tissues for fine-tuning (immune and bone marrow are included together)
    if organ in ["bone_marrow"]:
        continue
    elif organ == "immune":
        organ_ids = ["immune", "bone_marrow"]
        organ_list += ["immune"]
    else:
        organ_ids = [organ]
        organ_list += [organ]
    print(organ)
    # filter datasets for given organ
    def if_organ(example):
        return example["organ_major"] in organ_ids
    trainset_organ = train_dataset.filter(if_organ, num_proc=16)
    # per scDeepsort published method, drop cell types representing <0.5% of cells
    celltype_counter = Counter(trainset_organ["cell_type"])
    total_cells = sum(celltype_counter.values())
    cells_to_keep = [k for k, v in celltype_counter.items() if v > (0.005 * total_cells)]
    def if_not_rare_celltype(example):
        return example["cell_type"] in cells_to_keep
    trainset_organ_subset = trainset_organ.filter(if_not_rare_celltype, num_proc=16)
    # shuffle datasets and rename columns
    trainset_organ_shuffled = trainset_organ_subset.shuffle(seed=42)
    trainset_organ_shuffled = trainset_organ_shuffled.rename_column("cell_type", "label")
    trainset_organ_shuffled = trainset_organ_shuffled.remove_columns("organ_major")
    # create dictionary of cell types : label ids
    target_names = list(Counter(trainset_organ_shuffled["label"]).keys())
    target_name_id_dict = dict(zip(target_names, [i for i in range(len(target_names))]))
    target_dict_list += [target_name_id_dict]
    # change labels to numerical ids
    def classes_to_ids(example):
        example["label"] = target_name_id_dict[example["label"]]
        return example
    labeled_trainset = trainset_organ_shuffled.map(classes_to_ids, num_proc=16)
    # create 80/20 train/eval splits
    labeled_train_split = labeled_trainset.select([i for i in range(0, round(len(labeled_trainset) * 0.8))])
    labeled_eval_split = labeled_trainset.select([i for i in range(round(len(labeled_trainset) * 0.8), len(labeled_trainset))])
    # filter dataset for cell types in corresponding training set
    trained_labels = list(Counter(labeled_train_split["label"]).keys())
    def if_trained_label(example):
        return example["label"] in trained_labels
    labeled_eval_split_subset = labeled_eval_split.filter(if_trained_label, num_proc=16)
    dataset_list += [labeled_train_split]
    evalset_list += [labeled_eval_split_subset]
trainset_dict = dict(zip(organ_list,dataset_list))
traintargetdict_dict = dict(zip(organ_list,target_dict_list))
evalset_dict = dict(zip(organ_list,evalset_list))
def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    # calculate accuracy and macro f1 using sklearn's function
    acc = accuracy_score(labels, preds)
    macro_f1 = f1_score(labels, preds, average='macro')
    return {
        'accuracy': acc,
        'macro_f1': macro_f1
    }
max_input_size = 2 ** 11 # 2048
# set training hyperparameters
# max learning rate
max_lr = 5e-5
# how many pretrained layers to freeze
freeze_layers = 0
# number gpus
num_gpus = 1
# number cpu cores
num_proc = 16
# batch size for training and eval
geneformer_batch_size = 12
# learning schedule
lr_schedule_fn = "linear"
# warmup steps
warmup_steps = 500
# number of epochs
epochs = 10
# optimizer
optimizer = "adamw"
for organ in organ_list:
    print(organ)
    organ_trainset = trainset_dict[organ]
    organ_evalset = evalset_dict[organ]
    organ_label_dict = traintargetdict_dict[organ]
    # set logging steps
    logging_steps = round(len(organ_trainset) / geneformer_batch_size / 10)
    # reload pretrained model (change to your local Geneformer path)
    model = BertForSequenceClassification.from_pretrained("/home/Geneformer",
                                                          num_labels=len(organ_label_dict.keys()),
                                                          output_attentions=False,
                                                          output_hidden_states=False).to("cuda")
    # define output directory path
    current_date = datetime.datetime.now()
    datestamp = f"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}"
    output_dir = f"/path/to/models/{datestamp}_geneformer_CellClassifier_{organ}_L{max_input_size}_B{geneformer_batch_size}_LR{max_lr}_LS{lr_schedule_fn}_WU{warmup_steps}_E{epochs}_O{optimizer}_F{freeze_layers}/"
    # ensure not overwriting previously saved model
    saved_model_test = os.path.join(output_dir, "pytorch_model.bin")
    if os.path.isfile(saved_model_test):
        raise Exception("Model already saved to this directory.")
    # make output directory
    subprocess.call(f'mkdir {output_dir}', shell=True)
    # set training arguments
    training_args = {
        "learning_rate": max_lr,
        "do_train": True,
        "do_eval": True,
        "evaluation_strategy": "epoch",
        "save_strategy": "epoch",
        "logging_steps": logging_steps,
        "group_by_length": True,
        "length_column_name": "length",
        "disable_tqdm": False,
        "lr_scheduler_type": lr_schedule_fn,
        "warmup_steps": warmup_steps,
        "weight_decay": 0.001,
        "per_device_train_batch_size": geneformer_batch_size,
        "per_device_eval_batch_size": geneformer_batch_size,
        "num_train_epochs": epochs,
        "load_best_model_at_end": True,
        "output_dir": output_dir,
    }
    training_args_init = TrainingArguments(**training_args)
    # create the trainer
    trainer = Trainer(
        model=model,
        args=training_args_init,
        data_collator=DataCollatorForCellClassification(),
        train_dataset=organ_trainset,
        eval_dataset=organ_evalset,
        compute_metrics=compute_metrics
    )
    # train the cell type classifier
    trainer.train()
    predictions = trainer.predict(organ_evalset)
    with open(f"{output_dir}predictions.pickle", "wb") as fp:
        pickle.dump(predictions, fp)
    trainer.save_metrics("eval", predictions.metrics)
    trainer.save_model(output_dir)
```

The second part of the example drives fine-tuning through the `Classifier` class instead:

```
import datetime
from geneformer import Classifier
current_date = datetime.datetime.now()
datestamp = f"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}{current_date.hour:02d}{current_date.minute:02d}{current_date.second:02d}"
datestamp_min = f"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}"
output_prefix = "cm_classifier_test"
output_dir = f"/path/to/output_dir/{datestamp}"
# !mkdir $output_dir
filter_data_dict={"cell_type":["Cardiomyocyte1","Cardiomyocyte2","Cardiomyocyte3"]}
training_args = {
    "num_train_epochs": 0.9,
    "learning_rate": 0.000804,
    "lr_scheduler_type": "polynomial",
    "warmup_steps": 1812,
    "weight_decay": 0.258828,
    "per_device_train_batch_size": 12,
    "seed": 73,
}
cc = Classifier(classifier="cell",
                cell_state_dict={"state_key": "disease", "states": "all"},
                filter_data=filter_data_dict,
                training_args=training_args,
                max_ncells=None,
                freeze_layers=2,
                num_crossval_splits=1,
                forward_batch_size=200,
                nproc=16)
# previously balanced splits with prepare_data and validate functions
# argument attr_to_split set to "individual" and attr_to_balance set to ["disease","lvef","age","sex","length"]
train_ids = ["1447", "1600", "1462", "1558", "1300", "1508", "1358", "1678", "1561", "1304", "1610", "1430", "1472", "1707", "1726", "1504", "1425", "1617", "1631", "1735", "1582", "1722", "1622", "1630", "1290", "1479", "1371", "1549", "1515"]
eval_ids = ["1422", "1510", "1539", "1606", "1702"]
test_ids = ["1437", "1516", "1602", "1685", "1718"]
train_test_id_split_dict = {"attr_key": "individual",
                            "train": train_ids + eval_ids,
                            "test": test_ids}
# Example input_data_file: https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/example_input_files/cell_classification/disease_classification/human_dcm_hcm_nf.dataset
cc.prepare_data(input_data_file="/path/to/Genecorpus-30M/genecorpus_30M_2048.dataset",
                output_directory=output_dir,
                output_prefix=output_prefix,
                split_id_dict=train_test_id_split_dict)
train_valid_id_split_dict = {"attr_key": "individual",
                             "train": train_ids,
                             "eval": eval_ids}
# 6 layer Geneformer: https://huggingface.co/ctheodoris/Geneformer/blob/main/model.safetensors
all_metrics = cc.validate(model_directory="/home/Geneformer",
                          prepared_input_data_file=f"{output_dir}/{output_prefix}_labeled_train.dataset",
                          id_class_dict_file=f"{output_dir}/{output_prefix}_id_class_dict.pkl",
                          output_directory=output_dir,
                          output_prefix=output_prefix,
                          split_id_dict=train_valid_id_split_dict)
# to optimize hyperparameters, set n_hyperopt_trials=100 (or alternative desired # of trials)
```
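
Per that closing comment, hyperparameter search would be enabled by passing `n_hyperopt_trials` to the validation call. A hedged sketch (assuming `cc.validate` accepts this keyword, as the upstream notebook's comment implies):
```
# assumption: cc.validate accepts n_hyperopt_trials, per the comment above
all_metrics = cc.validate(model_directory="/home/Geneformer",
                          prepared_input_data_file=f"{output_dir}/{output_prefix}_labeled_train.dataset",
                          id_class_dict_file=f"{output_dir}/{output_prefix}_id_class_dict.pkl",
                          output_directory=output_dir,
                          output_prefix=output_prefix,
                          split_id_dict=train_valid_id_split_dict,
                          n_hyperopt_trials=100)
```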