"driver/src/driver.cpp" did not exist on "05e046654c9a226444091806a418a77fe0e4a4c2"
Commit bce4d8a1 authored by wangsen's avatar wangsen
Browse files

readme.md

parent 9a5601d8
...@@ -138,16 +138,7 @@ yarl 1.9.4 ...@@ -138,16 +138,7 @@ yarl 1.9.4
``` ```
#单卡运行 #单卡运行
python geneformer/classifier.py \ python train.py
--Classifierclassifier="cell"\
--cell_state_dict = {"state_key": "disease", "states": "all"}\
--filter_data=filter_data_dict\
--training_args=training_args\
--max_ncells=None\
--freeze_layers = 2 \
--num_crossval_splits = 1\
--forward_batch_size=200\
--nproc=1
#详情请参考 Geneformer/examples/cell_classification.ipynb #详情请参考 Geneformer/examples/cell_classification.ipynb
......
...@@ -196,4 +196,3 @@ for organ in organ_list: ...@@ -196,4 +196,3 @@ for organ in organ_list:
pickle.dump(predictions, fp) pickle.dump(predictions, fp)
trainer.save_metrics("eval",predictions.metrics) trainer.save_metrics("eval",predictions.metrics)
trainer.save_model(output_dir) trainer.save_model(output_dir)
import datetime import datetime
import pickle
from geneformer import Classifier from geneformer import Classifier
import os
current_date = datetime.datetime.now() current_date = datetime.datetime.now()
datestamp = f"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}{current_date.hour:02d}{current_date.minute:02d}{current_date.second:02d}" datestamp = f"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}{current_date.hour:02d}{current_date.minute:02d}{current_date.second:02d}"
datestamp_min = f"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}" datestamp_min = f"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}"
output_prefix = "cm_classifier_test" output_prefix = "tf_dosage_sens_test"
output_dir = f"/path/to/output_dir/{datestamp}" output_dir = f"/path/to/output_dir/{datestamp}"
# !mkdir $output_dir os.makedirs(output_dir)
with open("/path/to/Genecorpus-30M/dosage_sensitivity_TFs.pickle", "rb") as fp:
gene_class_dict = pickle.load(fp)
filter_data_dict={"cell_type":["Cardiomyocyte1","Cardiomyocyte2","Cardiomyocyte3"]} cc = Classifier(classifier="gene",
training_args = { gene_class_dict = gene_class_dict,
"num_train_epochs": 0.9, max_ncells = 10_000,
"learning_rate": 0.000804, freeze_layers = 4,
"lr_scheduler_type": "polynomial", num_crossval_splits = 5,
"warmup_steps": 1812,
"weight_decay":0.258828,
"per_device_train_batch_size": 12,
"seed": 73,
}
cc = Classifier(classifier="cell",
cell_state_dict = {"state_key": "disease", "states": "all"},
filter_data=filter_data_dict,
training_args=training_args,
max_ncells=None,
freeze_layers = 2,
num_crossval_splits = 1,
forward_batch_size=200, forward_batch_size=200,
nproc=16) nproc=16)
# previously balanced splits with prepare_data and validate functions cc.prepare_data(input_data_file="/path/to/Genecorpus-30M/dosage_sensitive_tfs",
# argument attr_to_split set to "individual" and attr_to_balance set to ["disease","lvef","age","sex","length"]
train_ids = ["1447", "1600", "1462", "1558", "1300", "1508", "1358", "1678", "1561", "1304", "1610", "1430", "1472", "1707", "1726", "1504", "1425", "1617", "1631", "1735", "1582", "1722", "1622", "1630", "1290", "1479", "1371", "1549", "1515"]
eval_ids = ["1422", "1510", "1539", "1606", "1702"]
test_ids = ["1437", "1516", "1602", "1685", "1718"]
train_test_id_split_dict = {"attr_key": "individual",
"train": train_ids+eval_ids,
"test": test_ids}
# Example input_data_file: https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/example_input_files/cell_classification/disease_classification/human_dcm_hcm_nf.dataset
cc.prepare_data(input_data_file="/path/to/Genecorpus-30M/genecorpus_30M_2048.dataset",
output_directory=output_dir, output_directory=output_dir,
output_prefix=output_prefix, output_prefix=output_prefix)
split_id_dict=train_test_id_split_dict)
train_valid_id_split_dict = {"attr_key": "individual",
"train": train_ids,
"eval": eval_ids}
# 6 layer Geneformer: https://huggingface.co/ctheodoris/Geneformer/blob/main/model.safetensors
all_metrics = cc.validate(model_directory="/home/Geneformer", all_metrics = cc.validate(model_directory="/home/Geneformer",
prepared_input_data_file=f"{output_dir}/{output_prefix}_labeled_train.dataset", prepared_input_data_file=f"{output_dir}/{output_prefix}_labeled.dataset",
id_class_dict_file=f"{output_dir}/{output_prefix}_id_class_dict.pkl", id_class_dict_file=f"{output_dir}/{output_prefix}_id_class_dict.pkl",
output_directory=output_dir, output_directory=output_dir,
output_prefix=output_prefix, output_prefix=output_prefix)
split_id_dict=train_valid_id_split_dict)
# to optimize hyperparameters, set n_hyperopt_trials=100 (or alternative desired # of trials)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment