# Cell-classification example: fine-tune and validate a Geneformer cell
# classifier that predicts the "disease" label of cardiomyocytes, using a
# fixed train / eval / test split keyed on the "individual" attribute.
import datetime
import os

from geneformer import Classifier

# Timestamps used to name this run's outputs.
current_date = datetime.datetime.now()
datestamp = current_date.strftime("%y%m%d%H%M%S")  # YYMMDDhhmmss
datestamp_min = current_date.strftime("%y%m%d")  # YYMMDD

output_prefix = "cm_classifier_test"
output_dir = f"/path/to/output_dir/{datestamp}"
# The notebook version of this example ran `!mkdir $output_dir`; as a plain
# script the directory has to be created explicitly before it is handed to
# the calls below as their output_directory.
os.makedirs(output_dir, exist_ok=True)

# Restrict the input data to the cardiomyocyte cell types.
filter_data_dict = {
    "cell_type": ["Cardiomyocyte1", "Cardiomyocyte2", "Cardiomyocyte3"],
}

training_args = {
    "num_train_epochs": 0.9,
    "learning_rate": 0.000804,
    "lr_scheduler_type": "polynomial",
    "warmup_steps": 1812,
    "weight_decay": 0.258828,
    "per_device_train_batch_size": 12,
    "seed": 73,
}

cc = Classifier(
    classifier="cell",
    cell_state_dict={"state_key": "disease", "states": "all"},
    filter_data=filter_data_dict,
    training_args=training_args,
    max_ncells=None,
    freeze_layers=2,
    num_crossval_splits=1,
    forward_batch_size=200,
    nproc=16,
)

# previously balanced splits with prepare_data and validate functions
# argument attr_to_split set to "individual" and attr_to_balance set to
# ["disease","lvef","age","sex","length"]
train_ids = [
    "1447", "1600", "1462", "1558", "1300", "1508", "1358", "1678",
    "1561", "1304", "1610", "1430", "1472", "1707", "1726", "1504",
    "1425", "1617", "1631", "1735", "1582", "1722", "1622", "1630",
    "1290", "1479", "1371", "1549", "1515",
]
eval_ids = ["1422", "1510", "1539", "1606", "1702"]
test_ids = ["1437", "1516", "1602", "1685", "1718"]

# First split: train + eval individuals together versus the held-out test
# individuals.
train_test_id_split_dict = {
    "attr_key": "individual",
    "train": train_ids + eval_ids,
    "test": test_ids,
}

# Example input_data_file: https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/example_input_files/cell_classification/disease_classification/human_dcm_hcm_nf.dataset
# NOTE(review): the placeholder path below names a different dataset than the
# example linked above -- confirm it points at a dataset that carries the
# "cell_type", "disease" and "individual" attributes used in this script.
cc.prepare_data(
    input_data_file="/path/to/Genecorpus-30M/genecorpus_30M_2048.dataset",
    output_directory=output_dir,
    output_prefix=output_prefix,
    split_id_dict=train_test_id_split_dict,
)

# Second split: within the prepared training data, separate the training
# individuals from the evaluation individuals.
train_valid_id_split_dict = {
    "attr_key": "individual",
    "train": train_ids,
    "eval": eval_ids,
}

# 6 layer Geneformer: https://huggingface.co/ctheodoris/Geneformer/blob/main/model.safetensors
all_metrics = cc.validate(
    model_directory="/home/Geneformer",
    prepared_input_data_file=f"{output_dir}/{output_prefix}_labeled_train.dataset",
    id_class_dict_file=f"{output_dir}/{output_prefix}_id_class_dict.pkl",
    output_directory=output_dir,
    output_prefix=output_prefix,
    split_id_dict=train_valid_id_split_dict,
)
# to optimize hyperparameters, set n_hyperopt_trials=100 (or alternative desired # of trials)