chenpangpang / transformers / Commits / 6a025487

Unverified commit 6a025487, authored Dec 12, 2021 by Suraj Patil, committed by GitHub on Dec 12, 2021.
[Flax examples] remove dependency on PyTorch training args (#14636)

* use custom training arguments
* update tests
parent 027074f4
Showing 8 changed files with 457 additions and 17 deletions (+457 / -17)
examples/flax/language-modeling/run_clm_flax.py         +64 -2
examples/flax/language-modeling/run_mlm_flax.py         +64 -2
examples/flax/language-modeling/run_t5_mlm_flax.py      +66 -3
examples/flax/question-answering/run_qa.py              +65 -2
examples/flax/summarization/run_summarization_flax.py   +68 -2
examples/flax/test_examples.py                           +2 -2
examples/flax/token-classification/run_flax_ner.py      +64 -2
examples/flax/vision/run_image_classification.py        +64 -2
examples/flax/language-modeling/run_clm_flax.py

@@ -27,7 +27,8 @@ import math
 import os
 import sys
 import time
-from dataclasses import dataclass, field
+from dataclasses import asdict, dataclass, field
+from enum import Enum
 from itertools import chain
 from pathlib import Path
 from typing import Callable, Optional
@@ -53,7 +54,6 @@ from transformers import (
     AutoTokenizer,
     FlaxAutoModelForCausalLM,
     HfArgumentParser,
-    TrainingArguments,
     is_tensorboard_available,
     set_seed,
 )
@@ -67,6 +67,68 @@ MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_CAUSAL_LM_MAPPING.keys())
 MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)


+@dataclass
+class TrainingArguments:
+    output_dir: str = field(
+        metadata={"help": "The output directory where the model predictions and checkpoints will be written."},
+    )
+    overwrite_output_dir: bool = field(
+        default=False,
+        metadata={
+            "help": (
+                "Overwrite the content of the output directory. "
+                "Use this to continue training if output_dir points to a checkpoint directory."
+            )
+        },
+    )
+    do_train: bool = field(default=False, metadata={"help": "Whether to run training."})
+    do_eval: bool = field(default=False, metadata={"help": "Whether to run eval on the dev set."})
+    per_device_train_batch_size: int = field(
+        default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for training."}
+    )
+    per_device_eval_batch_size: int = field(
+        default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for evaluation."}
+    )
+    learning_rate: float = field(default=5e-5, metadata={"help": "The initial learning rate for AdamW."})
+    weight_decay: float = field(default=0.0, metadata={"help": "Weight decay for AdamW if we apply some."})
+    adam_beta1: float = field(default=0.9, metadata={"help": "Beta1 for AdamW optimizer"})
+    adam_beta2: float = field(default=0.999, metadata={"help": "Beta2 for AdamW optimizer"})
+    adam_epsilon: float = field(default=1e-8, metadata={"help": "Epsilon for AdamW optimizer."})
+    adafactor: bool = field(default=False, metadata={"help": "Whether or not to replace AdamW by Adafactor."})
+    num_train_epochs: float = field(default=3.0, metadata={"help": "Total number of training epochs to perform."})
+    warmup_steps: int = field(default=0, metadata={"help": "Linear warmup over warmup_steps."})
+    logging_steps: int = field(default=500, metadata={"help": "Log every X updates steps."})
+    save_steps: int = field(default=500, metadata={"help": "Save checkpoint every X updates steps."})
+    eval_steps: int = field(default=None, metadata={"help": "Run an evaluation every X steps."})
+    seed: int = field(default=42, metadata={"help": "Random seed that will be set at the beginning of training."})
+    push_to_hub: bool = field(
+        default=False, metadata={"help": "Whether or not to upload the trained model to the model hub after training."}
+    )
+    hub_model_id: str = field(
+        default=None, metadata={"help": "The name of the repository to keep in sync with the local `output_dir`."}
+    )
+    hub_token: str = field(default=None, metadata={"help": "The token to use to push to the Model Hub."})
+
+    def __post_init__(self):
+        if self.output_dir is not None:
+            self.output_dir = os.path.expanduser(self.output_dir)
+
+    def to_dict(self):
+        """
+        Serializes this instance while replace `Enum` by their values (for JSON serialization support). It obfuscates
+        the token values by removing their value.
+        """
+        d = asdict(self)
+        for k, v in d.items():
+            if isinstance(v, Enum):
+                d[k] = v.value
+            if isinstance(v, list) and len(v) > 0 and isinstance(v[0], Enum):
+                d[k] = [x.value for x in v]
+            if k.endswith("_token"):
+                d[k] = f"<{k.upper()}>"
+        return d
+
+
 @dataclass
 class ModelArguments:
     """
     ...
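For context, here is a minimal sketch of how a script such as run_clm_flax.py consumes this local dataclass: the Flax examples parse it with HfArgumentParser (alongside their ModelArguments and DataTrainingArguments) and can log training_args.to_dict() with the token fields masked. The trimmed field set and the command-line values below are illustrative assumptions, not lines copied from the diff:

# Minimal sketch (assumed usage): parse the script-local TrainingArguments with
# HfArgumentParser and print a sanitized dict of the parsed values.
import os
from dataclasses import asdict, dataclass, field
from enum import Enum

from transformers import HfArgumentParser


@dataclass
class TrainingArguments:  # trimmed stand-in for the class added by this commit
    output_dir: str = field(metadata={"help": "Where checkpoints are written."})
    learning_rate: float = field(default=5e-5, metadata={"help": "Initial learning rate for AdamW."})
    hub_token: str = field(default=None, metadata={"help": "Token used to push to the Model Hub."})

    def __post_init__(self):
        if self.output_dir is not None:
            self.output_dir = os.path.expanduser(self.output_dir)

    def to_dict(self):
        # Same idea as in the commit: flatten Enum values and mask *_token fields.
        d = asdict(self)
        for k, v in d.items():
            if isinstance(v, Enum):
                d[k] = v.value
            if k.endswith("_token"):
                d[k] = f"<{k.upper()}>"
        return d


if __name__ == "__main__":
    parser = HfArgumentParser(TrainingArguments)
    (training_args,) = parser.parse_args_into_dataclasses(
        ["--output_dir", "~/clm-output", "--learning_rate", "3e-4", "--hub_token", "hf_not_a_real_token"]
    )
    print(training_args.output_dir)  # "~" expanded by __post_init__
    print(training_args.to_dict())   # hub_token is shown as "<HUB_TOKEN>"

Since the class now lives in the script itself, the examples no longer need the PyTorch-backed transformers.TrainingArguments just to read hyperparameters from the command line.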
examples/flax/language-modeling/run_mlm_flax.py

@@ -26,7 +26,8 @@ import math
 import os
 import sys
 import time
-from dataclasses import dataclass, field
+from dataclasses import asdict, dataclass, field
+from enum import Enum
 from itertools import chain

 # You can also adapt this script on your own masked language modeling task. Pointers for this are left as comments.
@@ -54,7 +55,6 @@ from transformers import (
     HfArgumentParser,
     PreTrainedTokenizerBase,
     TensorType,
-    TrainingArguments,
     is_tensorboard_available,
     set_seed,
 )
@@ -65,6 +65,68 @@ MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_MASKED_LM_MAPPING.keys())
 MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)

+   [adds the same TrainingArguments dataclass (fields, __post_init__, to_dict) as in run_clm_flax.py above]

 @dataclass
 class ModelArguments:
     """
     ...
examples/flax/language-modeling/run_t5_mlm_flax.py

@@ -19,13 +19,15 @@ Pretraining the library models for T5-like span-masked language modeling on a te
 Here is the full list of checkpoints on the hub that can be pretrained by this script:
 https://huggingface.co/models?filter=t5
 """
-# You can also adapt this script on your own masked language modeling task. Pointers for this are left as comments.
 import json
 import logging
 import os
 import sys
 import time
-from dataclasses import dataclass, field
+from dataclasses import asdict, dataclass, field
+
+# You can also adapt this script on your own masked language modeling task. Pointers for this are left as comments.
+from enum import Enum
 from itertools import chain
 from pathlib import Path
 from typing import Dict, List, Optional
@@ -51,7 +53,6 @@ from transformers import (
     HfArgumentParser,
     PreTrainedTokenizerBase,
     T5Config,
-    TrainingArguments,
     is_tensorboard_available,
     set_seed,
 )
@@ -63,6 +64,68 @@ MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_MASKED_LM_MAPPING.keys())
 MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)

+   [adds the same TrainingArguments dataclass (fields, __post_init__, to_dict) as in run_clm_flax.py above]

 @dataclass
 class ModelArguments:
     """
     ...
examples/flax/question-answering/run_qa.py

@@ -24,7 +24,8 @@ import os
 import random
 import sys
 import time
-from dataclasses import dataclass, field
+from dataclasses import asdict, dataclass, field
+from enum import Enum
 from itertools import chain
 from pathlib import Path
 from typing import Any, Callable, Dict, Optional, Tuple
@@ -50,7 +51,6 @@ from transformers import (
     FlaxAutoModelForQuestionAnswering,
     HfArgumentParser,
     PreTrainedTokenizerFast,
-    TrainingArguments,
     is_tensorboard_available,
 )
 from transformers.file_utils import get_full_repo_name
@@ -69,6 +69,69 @@ PRNGKey = Any
 # region Arguments
+   [adds the same TrainingArguments dataclass as in run_clm_flax.py above, with one extra field after do_eval:]
+    do_predict: bool = field(default=False, metadata={"help": "Whether to run predictions on the test set."})

 @dataclass
 class ModelArguments:
     """
     ...
examples/flax/summarization/run_summarization_flax.py

@@ -23,7 +23,8 @@ import logging
 import os
 import sys
 import time
-from dataclasses import dataclass, field
+from dataclasses import asdict, dataclass, field
+from enum import Enum
 from functools import partial
 from pathlib import Path
 from typing import Callable, Optional
@@ -51,7 +52,6 @@ from transformers import (
     AutoTokenizer,
     FlaxAutoModelForSeq2SeqLM,
     HfArgumentParser,
-    TrainingArguments,
     is_tensorboard_available,
 )
 from transformers.file_utils import get_full_repo_name, is_offline_mode
@@ -74,6 +74,72 @@ MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.keys())
 MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)

+   [adds the same TrainingArguments dataclass as in run_clm_flax.py above, with two extra fields: do_predict after do_eval, and label_smoothing_factor after adam_epsilon:]
+    do_predict: bool = field(default=False, metadata={"help": "Whether to run predictions on the test set."})
+    label_smoothing_factor: float = field(
+        default=0.0, metadata={"help": "The label smoothing epsilon to apply (zero means no label smoothing)."}
+    )

 @dataclass
 class ModelArguments:
     """
     ...
examples/flax/test_examples.py

@@ -137,7 +137,7 @@ class ExamplesTests(TestCasePlus):
             --test_file tests/fixtures/tests_samples/xsum/sample.json
             --output_dir {tmp_dir}
             --overwrite_output_dir
-            --max_steps=50
+            --num_train_epochs=3
             --warmup_steps=8
             --do_train
             --do_eval
@@ -257,7 +257,7 @@ class ExamplesTests(TestCasePlus):
             --validation_file tests/fixtures/tests_samples/SQUAD/sample.json
             --output_dir {tmp_dir}
             --overwrite_output_dir
-            --max_steps=10
+            --num_train_epochs=3
             --warmup_steps=2
             --do_train
             --do_eval
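The test change above follows from the new dataclass: unlike transformers.TrainingArguments, the script-local class defines no max_steps field, so the tests cap training with --num_train_epochs instead. A small illustrative sketch of that behaviour (the trimmed dataclass is hypothetical; rejecting leftover flags is HfArgumentParser's default):

# Sketch: a flag that is not a field of the script-local dataclass (e.g. --max_steps)
# is left over after parsing, and HfArgumentParser raises for leftover arguments.
from dataclasses import dataclass, field

from transformers import HfArgumentParser


@dataclass
class TrainingArguments:  # trimmed, hypothetical stand-in
    output_dir: str = field(metadata={"help": "Output directory."})
    num_train_epochs: float = field(default=3.0, metadata={"help": "Total training epochs."})


parser = HfArgumentParser(TrainingArguments)

# Fine: both flags map to fields of the dataclass.
(args,) = parser.parse_args_into_dataclasses(["--output_dir", "out", "--num_train_epochs", "3"])

# Raises: --max_steps has no corresponding field, so it is reported as unused.
try:
    parser.parse_args_into_dataclasses(["--output_dir", "out", "--max_steps", "50"])
except ValueError as err:
    print(err)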
examples/flax/token-classification/run_flax_ner.py

@@ -20,7 +20,8 @@ import os
 import random
 import sys
 import time
-from dataclasses import dataclass, field
+from dataclasses import asdict, dataclass, field
+from enum import Enum
 from itertools import chain
 from pathlib import Path
 from typing import Any, Callable, Dict, Optional, Tuple
@@ -44,7 +45,6 @@ from transformers import (
     AutoTokenizer,
     FlaxAutoModelForTokenClassification,
     HfArgumentParser,
-    TrainingArguments,
     is_tensorboard_available,
 )
 from transformers.file_utils import get_full_repo_name
@@ -63,6 +63,68 @@ Dataset = datasets.arrow_dataset.Dataset
 PRNGKey = Any

+   [adds the same TrainingArguments dataclass (fields, __post_init__, to_dict) as in run_clm_flax.py above]

 @dataclass
 class ModelArguments:
     """
     ...
examples/flax/vision/run_image_classification.py

@@ -24,7 +24,8 @@ import logging
 import os
 import sys
 import time
-from dataclasses import dataclass, field
+from dataclasses import asdict, dataclass, field
+from enum import Enum
 from pathlib import Path
 from typing import Callable, Optional
@@ -49,7 +50,6 @@ from transformers import (
     AutoConfig,
     FlaxAutoModelForImageClassification,
     HfArgumentParser,
-    TrainingArguments,
     is_tensorboard_available,
     set_seed,
 )
@@ -63,6 +63,68 @@ MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING.keys())
 MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)

+   [adds the same TrainingArguments dataclass (fields, __post_init__, to_dict) as in run_clm_flax.py above]

 @dataclass
 class ModelArguments:
     """
     ...
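For reference, here is a sketch of how the optimizer-related fields of this dataclass (learning_rate, warmup_steps, adam_beta1/2, adam_epsilon, weight_decay, adafactor) are typically wired into an optax optimizer in these Flax examples. The helper name create_optimizer and the linear warmup-then-decay schedule below are illustrative assumptions, not lines from this diff:

# A sketch, not the scripts' verbatim code: turn the custom TrainingArguments
# fields into an optax optimizer with a linear warmup/decay schedule.
import optax


def create_optimizer(training_args, total_train_steps: int):
    # Linear warmup to the peak learning rate, then linear decay to zero.
    warmup_fn = optax.linear_schedule(
        init_value=0.0, end_value=training_args.learning_rate, transition_steps=training_args.warmup_steps
    )
    decay_fn = optax.linear_schedule(
        init_value=training_args.learning_rate,
        end_value=0.0,
        transition_steps=total_train_steps - training_args.warmup_steps,
    )
    schedule_fn = optax.join_schedules([warmup_fn, decay_fn], boundaries=[training_args.warmup_steps])

    if training_args.adafactor:
        # Adafactor keeps optimizer state small, which helps for large models on TPU.
        return optax.adafactor(learning_rate=schedule_fn)

    return optax.adamw(
        learning_rate=schedule_fn,
        b1=training_args.adam_beta1,
        b2=training_args.adam_beta2,
        eps=training_args.adam_epsilon,
        weight_decay=training_args.weight_decay,
    )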