chenpangpang / transformers
"GRUB2/git@developer.sourcefind.cn:dadigang/Ventoy.git" did not exist on "5d0fe69b2592f3c19fd12529ad94cb64ec0ed7d1"
Fix Trainer tests in a multiGPU env (#7458)

Commit 8546dc55 (unverified), authored Sep 29, 2020 by Sylvain Gugger, committed via GitHub on Sep 29, 2020
Parent: d0fd7154
Showing 1 changed file with 9 additions and 6 deletions.

tests/test_trainer.py  (+9, -6)
@@ -109,12 +109,15 @@ if is_torch_available():
             loss = torch.nn.functional.mse_loss(y, labels)
             return (loss, y, y) if self.double_output else (loss, y)

-    def get_regression_trainer(a=0, b=0, double_output=False, train_len=64, eval_len=64, **kwargs):
+    def get_regression_trainer(a=0, b=0, double_output=False, train_len=64, eval_len=64, pretrained=True, **kwargs):
         label_names = kwargs.get("label_names", None)
         train_dataset = RegressionDataset(length=train_len, label_names=label_names)
         eval_dataset = RegressionDataset(length=eval_len, label_names=label_names)
-        config = RegressionModelConfig(a=a, b=b, double_output=double_output)
-        model = RegressionPreTrainedModel(config)
+        if pretrained:
+            config = RegressionModelConfig(a=a, b=b, double_output=double_output)
+            model = RegressionPreTrainedModel(config)
+        else:
+            model = RegressionModel(a=a, b=b, double_output=double_output)
         compute_metrics = kwargs.pop("compute_metrics", None)
         data_collator = kwargs.pop("data_collator", None)
         optimizers = kwargs.pop("optimizers", (None, None))
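For orientation, here is a condensed before/after sketch of the calling pattern the new pretrained flag enables. This is an illustrative snippet, not code from the commit: it assumes the helpers defined in tests/test_trainer.py are in scope and that tmpdir is a temporary output directory, so it is not runnable on its own.

# Before: build the trainer around a RegressionPreTrainedModel, then swap in a
# plain RegressionModel afterwards. In a GPU run the replacement presumably stays
# on the CPU, since the Trainer only placed the original model on its device.
trainer = get_regression_trainer(output_dir=tmpdir, save_steps=5)
trainer.model = RegressionModel()

# After: ask the helper for the plain model up front, so the Trainer itself
# handles device placement (and multi-GPU wrapping) for the model it trains.
trainer = get_regression_trainer(output_dir=tmpdir, save_steps=5, pretrained=False)
trainer.train()

The hunks further down apply exactly this swap in the checkpointing and best-model tests.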
@@ -178,6 +181,7 @@ class TrainerIntegrationTest(unittest.TestCase):
         best_model = RegressionModel()
         state_dict = torch.load(os.path.join(checkpoint, WEIGHTS_NAME))
         best_model.load_state_dict(state_dict)
+        best_model.to(trainer.args.device)
         self.assertTrue(torch.allclose(best_model.a, trainer.model.a))
         self.assertTrue(torch.allclose(best_model.b, trainer.model.b))
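The added best_model.to(trainer.args.device) line matters because torch.allclose cannot compare tensors that live on different devices. Below is my own minimal, self-contained illustration of that failure mode, not code from this commit; the cross-device error only surfaces when at least one CUDA device is available.

import torch

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

reloaded = torch.tensor([1.0])                # stands in for best_model.a after torch.load (CPU)
trained = torch.tensor([1.0], device=device)  # stands in for trainer.model.a on trainer.args.device

try:
    torch.allclose(reloaded, trained)         # RuntimeError when device is cuda:0
except RuntimeError as err:
    print("cross-device comparison fails:", err)

# Moving the reloaded side to the same device first (what the new line does) makes the check valid.
print(torch.allclose(reloaded.to(device), trained))  # True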
@@ -360,8 +364,7 @@ class TrainerIntegrationTest(unittest.TestCase):
         # With a regular model that is not a PreTrainedModel
         with tempfile.TemporaryDirectory() as tmpdir:
-            trainer = get_regression_trainer(output_dir=tmpdir, save_steps=5)
-            trainer.model = RegressionModel()
+            trainer = get_regression_trainer(output_dir=tmpdir, save_steps=5, pretrained=False)
             trainer.train()
             self.check_saved_checkpoints(tmpdir, 5, int(self.n_epochs * 64 / self.batch_size), False)
@@ -426,8 +429,8 @@ class TrainerIntegrationTest(unittest.TestCase):
                 eval_steps=5,
                 evaluation_strategy="steps",
                 load_best_model_at_end=True,
+                pretrained=False,
             )
-            trainer.model = RegressionModel(a=1.5, b=2.5)
             self.assertFalse(trainer.args.greater_is_better)
             trainer.train()
             self.check_saved_checkpoints(tmpdir, 5, total, is_pretrained=False)
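The is_pretrained=False argument passed to check_saved_checkpoints in the hunks above reflects that a plain nn.Module has no config.json to write, so its checkpoints contain only a weights file. A rough, self-contained illustration of that distinction follows; this is my own sketch, not the Trainer's or the test helper's actual code, and only WEIGHTS_NAME mirrors the constant the test imports from transformers.

import os
import tempfile

import torch
from torch import nn

WEIGHTS_NAME = "pytorch_model.bin"  # same value as the transformers constant

# Saving just the state dict, as one would for a model that is not a
# PreTrainedModel: the resulting checkpoint directory holds the weights file
# and nothing resembling a config.json.
with tempfile.TemporaryDirectory() as checkpoint:
    torch.save(nn.Linear(1, 1).state_dict(), os.path.join(checkpoint, WEIGHTS_NAME))
    print(sorted(os.listdir(checkpoint)))  # ['pytorch_model.bin']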