ModelZoo / Geneformer_pytorch

Commit 9a5601d8, authored Aug 19, 2024 by wangsen
change train.py
Parent: de076fe9

Showing 3 changed files with 371 additions and 8 deletions:
- README.md (+107, -8)
- test_cell_classifier.py (+199, -0)
- train.py (+65, -0)
README.md

@@ -11,8 +11,8 @@ sudo apt-get install git-lfs
```
#git clone https://hf-mirror.com/datasets/ctheodoris/Genecorpus-30M
mkdir cell_type_train_data.dataset
cd cell_type_train_data.dataset
wget https://hf-mirror.com/datasets/ctheodoris/Genecorpus-30M/resolve/main/example_input_files/cell_classification/cell_type_annotation/cell_type_train_data.dataset/dataset.arrow
wget https://hf-mirror.com/datasets/ctheodoris/Genecorpus-30M/resolve/main/example_input_files/cell_classification/cell_type_annotation/cell_type_train_data.dataset/dataset_info.json
wget https://hf-mirror.com/datasets/ctheodoris/Genecorpus-30M/resolve/main/example_input_files/cell_classification/cell_type_annotation/cell_type_train_data.dataset/state.json
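Once those three files are in place, the directory is a saved Hugging Face dataset and can be opened with `datasets.load_from_disk`. A minimal sanity check, assuming you run it from one directory above `cell_type_train_data.dataset`:

```
# Sanity-check sketch: verify the downloaded dataset loads; the relative
# path assumes the working directory is the parent of the dataset folder.
from datasets import load_from_disk

ds = load_from_disk("cell_type_train_data.dataset")
print(ds)      # number of rows and column names
print(ds[0])   # first cell record (e.g. input_ids, cell_type, ...)
```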
@@ -36,15 +36,109 @@ conda create -n geneformer python=3.10
conda activate geneformer
pip install torch  # the DCU build of torch
pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
```

## Environment after deployment

```
accelerate 0.33.0
accumulation_tree 0.6.2
aiohappyeyeballs 2.3.6
aiohttp 3.10.3
aiosignal 1.3.1
anndata 0.10.8
array_api_compat 1.8
async-timeout 4.0.3
attrs 24.2.0
certifi 2024.7.4
charset-normalizer 3.3.2
click 8.1.7
cloudpickle 3.0.0
contourpy 1.2.1
cycler 0.12.1
datasets 2.21.0
dill 0.3.8
exceptiongroup 1.2.2
filelock 3.15.4
fonttools 4.53.1
frozenlist 1.4.1
fsspec 2024.6.1
future 1.0.0
geneformer 0.1.0
h5py 3.11.0
huggingface-hub 0.24.5
hyperopt 0.2.7
idna 3.7
Jinja2 3.1.4
joblib 1.4.2
jsonschema 4.23.0
jsonschema-specifications 2023.12.1
kiwisolver 1.4.5
legacy-api-wrap 1.4
llvmlite 0.43.0
loompy 3.0.7
MarkupSafe 2.1.5
matplotlib 3.9.2
mpmath 1.3.0
msgpack 1.0.8
multidict 6.0.5
multiprocess 0.70.16
natsort 8.4.0
networkx 3.3
numba 0.60.0
numpy 1.26.4
numpy-groupies 0.11.2
packaging 24.1
pandas 2.2.2
patsy 0.5.6
pillow 10.4.0
pip 24.2
protobuf 5.27.3
psutil 6.0.0
py4j 0.10.9.7
pyarrow 17.0.0
pynndescent 0.5.13
pyparsing 3.1.2
python-dateutil 2.9.0.post0
pytz 2024.1
pyudorandom 1.0.0
PyYAML 6.0.2
ray 2.34.0
referencing 0.35.1
regex 2024.7.24
requests 2.32.3
rpds-py 0.20.0
safetensors 0.4.4
scanpy 1.10.2
scikit-learn 1.5.1
scipy 1.14.0
seaborn 0.13.2
session_info 1.0.0
setuptools 72.1.0
six 1.16.0
statsmodels 0.14.2
stdlib-list 0.10.0
sympy 1.13.2
tdigest 0.5.2.2
threadpoolctl 3.5.0
tokenizers 0.19.1
torch 2.1.0+git540102b.abi0.dtk2404
tqdm 4.66.5
transformers 4.44.0
typing_extensions 4.12.2
tzdata 2024.1
umap-learn 0.5.6
urllib3 2.2.2
wheel 0.43.0
xxhash 3.4.1
yarl 1.9.4
```

# Model training

```
# single-GPU run
python geneformer/classifier.py \
--classifier="cell" \
--cell_state_dict={"state_key": "disease", "states": "all"} \
--filter_data=filter_data_dict \
...
@@ -56,15 +150,20 @@ python classifier.py \
--nproc=1
# For details, see Geneformer/examples/cell_classification.ipynb
```

# Or run

```
python test_cell_classifier.py  # replace the dataset path inside the .py file
```
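Note that `--filter_data=filter_data_dict` names a Python object, so these flags really mirror the keyword arguments of geneformer's `Classifier` class. A minimal Python equivalent of the single-GPU command, with the `filter_data_dict` value copied from train.py further down in this commit:

```
# Sketch: the Classifier call that the flags above correspond to;
# argument values are taken from train.py below, nproc=1 from --nproc=1.
from geneformer import Classifier

filter_data_dict = {"cell_type": ["Cardiomyocyte1", "Cardiomyocyte2", "Cardiomyocyte3"]}
cc = Classifier(classifier="cell",
                cell_state_dict={"state_key": "disease", "states": "all"},
                filter_data=filter_data_dict,
                nproc=1)
```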
# Model inference

```
python geneformer/classifier.py --classifier="cell" --cell_state_dict={"state_key": "disease", "states": "all"} --forward_batch_size=200 --nproc=1
# Running this directly will raise an error; for details, see Geneformer/examples/cell_classification.ipynb
```
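Since the bare command errors out, one workable route is the same one test_cell_classifier.py takes after training: reload the fine-tuned checkpoint and call `Trainer.predict`. A sketch, where both paths are placeholders and the eval dataset is assumed to be tokenized with `label` and `length` columns:

```
# Inference sketch: batch-predict with a fine-tuned checkpoint saved by
# test_cell_classifier.py. Both paths here are assumptions.
from datasets import load_from_disk
from transformers import BertForSequenceClassification, Trainer
from transformers.training_args import TrainingArguments
from geneformer import DataCollatorForCellClassification

model = BertForSequenceClassification.from_pretrained("/path/to/finetuned_model").to("cuda")
eval_ds = load_from_disk("/path/to/labeled_eval.dataset")

trainer = Trainer(model=model,
                  args=TrainingArguments(output_dir="/tmp/eval", per_device_eval_batch_size=200),
                  data_collator=DataCollatorForCellClassification())
preds = trainer.predict(eval_ds).predictions.argmax(-1)  # predicted label ids
```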
test_cell_classifier.py (new file, 0 → 100644)
```
# For details, see:
# https://gitee.com/hf-models/Geneformer/blob/main/examples/cell_classification.ipynb
import os

GPU_NUMBER = [0]
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(s) for s in GPU_NUMBER])
os.environ["NCCL_DEBUG"] = "INFO"

# imports
from collections import Counter
import datetime
import pickle
import subprocess
import seaborn as sns; sns.set()
from datasets import load_from_disk
from sklearn.metrics import accuracy_score, f1_score
from transformers import BertForSequenceClassification
from transformers import Trainer
from transformers.training_args import TrainingArguments
from geneformer import DataCollatorForCellClassification

# load cell type dataset (includes all tissues)
train_dataset = load_from_disk("/genecorpus_30M_2048.dataset")  # change this dataset path

dataset_list = []
evalset_list = []
organ_list = []
target_dict_list = []

for organ in Counter(train_dataset["organ_major"]).keys():
    # collect list of tissues for fine-tuning (immune and bone marrow are included together)
    if organ in ["bone_marrow"]:
        continue
    elif organ == "immune":
        organ_ids = ["immune", "bone_marrow"]
        organ_list += ["immune"]
    else:
        organ_ids = [organ]
        organ_list += [organ]

    print(organ)

    # filter datasets for given organ
    def if_organ(example):
        return example["organ_major"] in organ_ids
    trainset_organ = train_dataset.filter(if_organ, num_proc=16)

    # per scDeepsort published method, drop cell types representing <0.5% of cells
    celltype_counter = Counter(trainset_organ["cell_type"])
    total_cells = sum(celltype_counter.values())
    cells_to_keep = [k for k, v in celltype_counter.items() if v > (0.005 * total_cells)]
    def if_not_rare_celltype(example):
        return example["cell_type"] in cells_to_keep
    trainset_organ_subset = trainset_organ.filter(if_not_rare_celltype, num_proc=16)

    # shuffle datasets and rename columns
    trainset_organ_shuffled = trainset_organ_subset.shuffle(seed=42)
    trainset_organ_shuffled = trainset_organ_shuffled.rename_column("cell_type", "label")
    trainset_organ_shuffled = trainset_organ_shuffled.remove_columns("organ_major")

    # create dictionary of cell types : label ids
    target_names = list(Counter(trainset_organ_shuffled["label"]).keys())
    target_name_id_dict = dict(zip(target_names, [i for i in range(len(target_names))]))
    target_dict_list += [target_name_id_dict]

    # change labels to numerical ids
    def classes_to_ids(example):
        example["label"] = target_name_id_dict[example["label"]]
        return example
    labeled_trainset = trainset_organ_shuffled.map(classes_to_ids, num_proc=16)

    # create 80/20 train/eval splits
    labeled_train_split = labeled_trainset.select([i for i in range(0, round(len(labeled_trainset) * 0.8))])
    labeled_eval_split = labeled_trainset.select([i for i in range(round(len(labeled_trainset) * 0.8), len(labeled_trainset))])

    # filter dataset for cell types in corresponding training set
    trained_labels = list(Counter(labeled_train_split["label"]).keys())
    def if_trained_label(example):
        return example["label"] in trained_labels
    labeled_eval_split_subset = labeled_eval_split.filter(if_trained_label, num_proc=16)

    dataset_list += [labeled_train_split]
    evalset_list += [labeled_eval_split_subset]

trainset_dict = dict(zip(organ_list, dataset_list))
traintargetdict_dict = dict(zip(organ_list, target_dict_list))
evalset_dict = dict(zip(organ_list, evalset_list))

def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    # calculate accuracy and macro f1 using sklearn's function
    acc = accuracy_score(labels, preds)
    macro_f1 = f1_score(labels, preds, average='macro')
    return {'accuracy': acc, 'macro_f1': macro_f1}

max_input_size = 2 ** 11  # 2048

# set training hyperparameters
# max learning rate
max_lr = 5e-5
# how many pretrained layers to freeze
freeze_layers = 0
# number gpus
num_gpus = 1
# number cpu cores
num_proc = 16
# batch size for training and eval
geneformer_batch_size = 12
# learning schedule
lr_schedule_fn = "linear"
# warmup steps
warmup_steps = 500
# number of epochs
epochs = 10
# optimizer
optimizer = "adamw"

for organ in organ_list:
    print(organ)
    organ_trainset = trainset_dict[organ]
    organ_evalset = evalset_dict[organ]
    organ_label_dict = traintargetdict_dict[organ]

    # set logging steps
    logging_steps = round(len(organ_trainset) / geneformer_batch_size / 10)

    # reload pretrained model  # change to your local Geneformer path
    model = BertForSequenceClassification.from_pretrained(
        "/home/Geneformer",
        num_labels=len(organ_label_dict.keys()),
        output_attentions=False,
        output_hidden_states=False,
    ).to("cuda")

    # define output directory path
    current_date = datetime.datetime.now()
    datestamp = f"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}"
    output_dir = f"/path/to/models/{datestamp}_geneformer_CellClassifier_{organ}_L{max_input_size}_B{geneformer_batch_size}_LR{max_lr}_LS{lr_schedule_fn}_WU{warmup_steps}_E{epochs}_O{optimizer}_F{freeze_layers}/"

    # ensure not overwriting previously saved model
    saved_model_test = os.path.join(output_dir, "pytorch_model.bin")
    if os.path.isfile(saved_model_test):
        raise Exception("Model already saved to this directory.")

    # make output directory
    subprocess.call(f'mkdir {output_dir}', shell=True)

    # set training arguments
    training_args = {
        "learning_rate": max_lr,
        "do_train": True,
        "do_eval": True,
        "evaluation_strategy": "epoch",
        "save_strategy": "epoch",
        "logging_steps": logging_steps,
        "group_by_length": True,
        "length_column_name": "length",
        "disable_tqdm": False,
        "lr_scheduler_type": lr_schedule_fn,
        "warmup_steps": warmup_steps,
        "weight_decay": 0.001,
        "per_device_train_batch_size": geneformer_batch_size,
        "per_device_eval_batch_size": geneformer_batch_size,
        "num_train_epochs": epochs,
        "load_best_model_at_end": True,
        "output_dir": output_dir,
    }
    training_args_init = TrainingArguments(**training_args)

    # create the trainer
    trainer = Trainer(
        model=model,
        args=training_args_init,
        data_collator=DataCollatorForCellClassification(),
        train_dataset=organ_trainset,
        eval_dataset=organ_evalset,
        compute_metrics=compute_metrics,
    )

    # train the cell type classifier
    trainer.train()
    predictions = trainer.predict(organ_evalset)
    with open(f"{output_dir}predictions.pickle", "wb") as fp:
        pickle.dump(predictions, fp)
    trainer.save_metrics("eval", predictions.metrics)
    trainer.save_model(output_dir)
```
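Each organ's run pickles the full `PredictionOutput` next to the saved model. A small follow-up sketch for inspecting those results offline (the run directory is a placeholder):

```
# Sketch: reload a pickled PredictionOutput written by the loop above
# and recompute accuracy / macro F1 offline.
import pickle
from sklearn.metrics import accuracy_score, f1_score

with open("/path/to/models/run_dir/predictions.pickle", "rb") as fp:
    predictions = pickle.load(fp)

preds = predictions.predictions.argmax(-1)
print(accuracy_score(predictions.label_ids, preds))
print(f1_score(predictions.label_ids, preds, average="macro"))
```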
train.py (new file, 0 → 100644)
```
import datetime
from geneformer import Classifier

current_date = datetime.datetime.now()
datestamp = f"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}{current_date.hour:02d}{current_date.minute:02d}{current_date.second:02d}"
datestamp_min = f"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}"

output_prefix = "cm_classifier_test"
output_dir = f"/path/to/output_dir/{datestamp}"
# !mkdir $output_dir

filter_data_dict = {"cell_type": ["Cardiomyocyte1", "Cardiomyocyte2", "Cardiomyocyte3"]}
training_args = {
    "num_train_epochs": 0.9,
    "learning_rate": 0.000804,
    "lr_scheduler_type": "polynomial",
    "warmup_steps": 1812,
    "weight_decay": 0.258828,
    "per_device_train_batch_size": 12,
    "seed": 73,
}
cc = Classifier(classifier="cell",
                cell_state_dict={"state_key": "disease", "states": "all"},
                filter_data=filter_data_dict,
                training_args=training_args,
                max_ncells=None,
                freeze_layers=2,
                num_crossval_splits=1,
                forward_batch_size=200,
                nproc=16)

# previously balanced splits with prepare_data and validate functions
# argument attr_to_split set to "individual" and attr_to_balance set to ["disease","lvef","age","sex","length"]
train_ids = ["1447", "1600", "1462", "1558", "1300", "1508", "1358", "1678",
             "1561", "1304", "1610", "1430", "1472", "1707", "1726", "1504",
             "1425", "1617", "1631", "1735", "1582", "1722", "1622", "1630",
             "1290", "1479", "1371", "1549", "1515"]
eval_ids = ["1422", "1510", "1539", "1606", "1702"]
test_ids = ["1437", "1516", "1602", "1685", "1718"]

train_test_id_split_dict = {"attr_key": "individual",
                            "train": train_ids + eval_ids,
                            "test": test_ids}

# Example input_data_file: https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/example_input_files/cell_classification/disease_classification/human_dcm_hcm_nf.dataset
cc.prepare_data(input_data_file="/path/to/Genecorpus-30M/genecorpus_30M_2048.dataset",
                output_directory=output_dir,
                output_prefix=output_prefix,
                split_id_dict=train_test_id_split_dict)

train_valid_id_split_dict = {"attr_key": "individual",
                             "train": train_ids,
                             "eval": eval_ids}

# 6 layer Geneformer: https://huggingface.co/ctheodoris/Geneformer/blob/main/model.safetensors
all_metrics = cc.validate(model_directory="/home/Geneformer",
                          prepared_input_data_file=f"{output_dir}/{output_prefix}_labeled_train.dataset",
                          id_class_dict_file=f"{output_dir}/{output_prefix}_id_class_dict.pkl",
                          output_directory=output_dir,
                          output_prefix=output_prefix,
                          split_id_dict=train_valid_id_split_dict)
# to optimize hyperparameters, set n_hyperopt_trials=100 (or alternative desired # of trials)
```
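As the closing comment notes, hyperparameter search only needs `n_hyperopt_trials` added to the `validate` call; everything else stays as above. A sketch of that variant:

```
# Sketch: hyperparameter-optimization variant of the validate call above;
# only n_hyperopt_trials is new, all other arguments are unchanged.
all_metrics = cc.validate(model_directory="/home/Geneformer",
                          prepared_input_data_file=f"{output_dir}/{output_prefix}_labeled_train.dataset",
                          id_class_dict_file=f"{output_dir}/{output_prefix}_id_class_dict.pkl",
                          output_directory=output_dir,
                          output_prefix=output_prefix,
                          split_id_dict=train_valid_id_split_dict,
                          n_hyperopt_trials=100)
```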