Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
Geneformer_pytorch
Commits
bce4d8a1
"driver/src/driver.cpp" did not exist on "05e046654c9a226444091806a418a77fe0e4a4c2"
Commit
bce4d8a1
authored
Aug 19, 2024
by
wangsen
Browse files
readme.md
parent
9a5601d8
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
19 additions
and
56 deletions
+19
-56
README.md
README.md
+1
-10
test_cell_classifier.py
test_cell_classifier.py
+1
-2
train.py
train.py
+17
-44
No files found.
README.md
View file @
bce4d8a1
...
@@ -138,16 +138,7 @@ yarl 1.9.4
...
@@ -138,16 +138,7 @@ yarl 1.9.4
```
```
#单卡运行
#单卡运行
python geneformer/classifier.py \
python train.py
--Classifierclassifier="cell"\
--cell_state_dict = {"state_key": "disease", "states": "all"}\
--filter_data=filter_data_dict\
--training_args=training_args\
--max_ncells=None\
--freeze_layers = 2 \
--num_crossval_splits = 1\
--forward_batch_size=200\
--nproc=1
#详情请参考 Geneformer/examples/cell_classification.ipynb
#详情请参考 Geneformer/examples/cell_classification.ipynb
...
...
test_cell_classifier.py
View file @
bce4d8a1
...
@@ -196,4 +196,3 @@ for organ in organ_list:
...
@@ -196,4 +196,3 @@ for organ in organ_list:
pickle
.
dump
(
predictions
,
fp
)
pickle
.
dump
(
predictions
,
fp
)
trainer
.
save_metrics
(
"eval"
,
predictions
.
metrics
)
trainer
.
save_metrics
(
"eval"
,
predictions
.
metrics
)
trainer
.
save_model
(
output_dir
)
trainer
.
save_model
(
output_dir
)
train.py
View file @
bce4d8a1
import
datetime
import
datetime
import
pickle
from
geneformer
import
Classifier
from
geneformer
import
Classifier
import
os
current_date
=
datetime
.
datetime
.
now
()
current_date
=
datetime
.
datetime
.
now
()
datestamp
=
f
"
{
str
(
current_date
.
year
)[
-
2
:]
}{
current_date
.
month
:
02
d
}{
current_date
.
day
:
02
d
}{
current_date
.
hour
:
02
d
}{
current_date
.
minute
:
02
d
}{
current_date
.
second
:
02
d
}
"
datestamp
=
f
"
{
str
(
current_date
.
year
)[
-
2
:]
}{
current_date
.
month
:
02
d
}{
current_date
.
day
:
02
d
}{
current_date
.
hour
:
02
d
}{
current_date
.
minute
:
02
d
}{
current_date
.
second
:
02
d
}
"
datestamp_min
=
f
"
{
str
(
current_date
.
year
)[
-
2
:]
}{
current_date
.
month
:
02
d
}{
current_date
.
day
:
02
d
}
"
datestamp_min
=
f
"
{
str
(
current_date
.
year
)[
-
2
:]
}{
current_date
.
month
:
02
d
}{
current_date
.
day
:
02
d
}
"
output_prefix
=
"
cm_classifier
_test"
output_prefix
=
"
tf_dosage_sens
_test"
output_dir
=
f
"/path/to/output_dir/
{
datestamp
}
"
output_dir
=
f
"/path/to/output_dir/
{
datestamp
}
"
# !mk
dir
$
output_dir
os
.
make
dir
s
(
output_dir
)
with
open
(
"/path/to/Genecorpus-30M/dosage_sensitivity_TFs.pickle"
,
"rb"
)
as
fp
:
gene_class_dict
=
pickle
.
load
(
fp
)
filter_data_dict
=
{
"cell_type"
:[
"Cardiomyocyte1"
,
"Cardiomyocyte2"
,
"Cardiomyocyte3"
]}
cc
=
Classifier
(
classifier
=
"gene"
,
training_args
=
{
gene_class_dict
=
gene_class_dict
,
"num_train_epochs"
:
0.9
,
max_ncells
=
10_000
,
"learning_rate"
:
0.000804
,
freeze_layers
=
4
,
"lr_scheduler_type"
:
"polynomial"
,
num_crossval_splits
=
5
,
"warmup_steps"
:
1812
,
"weight_decay"
:
0.258828
,
"per_device_train_batch_size"
:
12
,
"seed"
:
73
,
}
cc
=
Classifier
(
classifier
=
"cell"
,
cell_state_dict
=
{
"state_key"
:
"disease"
,
"states"
:
"all"
},
filter_data
=
filter_data_dict
,
training_args
=
training_args
,
max_ncells
=
None
,
freeze_layers
=
2
,
num_crossval_splits
=
1
,
forward_batch_size
=
200
,
forward_batch_size
=
200
,
nproc
=
16
)
nproc
=
16
)
# previously balanced splits with prepare_data and validate functions
cc
.
prepare_data
(
input_data_file
=
"/path/to/Genecorpus-30M/dosage_sensitive_tfs"
,
# argument attr_to_split set to "individual" and attr_to_balance set to ["disease","lvef","age","sex","length"]
train_ids
=
[
"1447"
,
"1600"
,
"1462"
,
"1558"
,
"1300"
,
"1508"
,
"1358"
,
"1678"
,
"1561"
,
"1304"
,
"1610"
,
"1430"
,
"1472"
,
"1707"
,
"1726"
,
"1504"
,
"1425"
,
"1617"
,
"1631"
,
"1735"
,
"1582"
,
"1722"
,
"1622"
,
"1630"
,
"1290"
,
"1479"
,
"1371"
,
"1549"
,
"1515"
]
eval_ids
=
[
"1422"
,
"1510"
,
"1539"
,
"1606"
,
"1702"
]
test_ids
=
[
"1437"
,
"1516"
,
"1602"
,
"1685"
,
"1718"
]
train_test_id_split_dict
=
{
"attr_key"
:
"individual"
,
"train"
:
train_ids
+
eval_ids
,
"test"
:
test_ids
}
# Example input_data_file: https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/example_input_files/cell_classification/disease_classification/human_dcm_hcm_nf.dataset
cc
.
prepare_data
(
input_data_file
=
"/path/to/Genecorpus-30M/genecorpus_30M_2048.dataset"
,
output_directory
=
output_dir
,
output_directory
=
output_dir
,
output_prefix
=
output_prefix
,
output_prefix
=
output_prefix
)
split_id_dict
=
train_test_id_split_dict
)
train_valid_id_split_dict
=
{
"attr_key"
:
"individual"
,
"train"
:
train_ids
,
"eval"
:
eval_ids
}
# 6 layer Geneformer: https://huggingface.co/ctheodoris/Geneformer/blob/main/model.safetensors
all_metrics
=
cc
.
validate
(
model_directory
=
"/home/Geneformer"
,
all_metrics
=
cc
.
validate
(
model_directory
=
"/home/Geneformer"
,
prepared_input_data_file
=
f
"
{
output_dir
}
/
{
output_prefix
}
_labeled
_train
.dataset"
,
prepared_input_data_file
=
f
"
{
output_dir
}
/
{
output_prefix
}
_labeled.dataset"
,
id_class_dict_file
=
f
"
{
output_dir
}
/
{
output_prefix
}
_id_class_dict.pkl"
,
id_class_dict_file
=
f
"
{
output_dir
}
/
{
output_prefix
}
_id_class_dict.pkl"
,
output_directory
=
output_dir
,
output_directory
=
output_dir
,
output_prefix
=
output_prefix
,
output_prefix
=
output_prefix
)
split_id_dict
=
train_valid_id_split_dict
)
# to optimize hyperparameters, set n_hyperopt_trials=100 (or alternative desired # of trials)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment