Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
1c343826
Unverified
Commit
1c343826
authored
Apr 26, 2022
by
Frank Lee
Committed by
GitHub
Apr 26, 2022
Browse files
[doc] improved assertion messages in trainer (#873)
parent
7a64fae3
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
46 additions
and
69 deletions
+46
-69
colossalai/trainer/_trainer.py
colossalai/trainer/_trainer.py
+46
-69
No files found.
colossalai/trainer/_trainer.py
View file @
1c343826
from
typing
import
Union
,
List
from
colossalai.context.parallel_mode
import
ParallelMode
from
typing
import
Union
,
List
,
Any
import
torch
from
torch
import
Tensor
from
torch.utils.data
import
DataLoader
from
tqdm
import
tqdm
from
colossalai.core
import
global_context
as
gpc
from
colossalai.engine
import
Engine
from
colossalai.logging
import
DistributedLogger
from
colossalai.utils
import
MultiTimer
...
...
@@ -53,6 +49,7 @@ class Trainer:
`Training with engine and trainer <https://www.colossalai.org/docs/basics/engine_trainer>`_
and `ColossalAI-Examples <https://github.com/hpcaitech/ColossalAI-Examples/tree/main>`_.
"""
def
__init__
(
self
,
engine
:
Engine
,
...
...
@@ -154,8 +151,7 @@ class Trainer:
@
staticmethod
def
_should_display_progress
(
display_progress
:
bool
):
"""Only display progress on DP rank 0, TP rank 0 and PP last rank"""
return
(
display_progress
and
is_dp_rank_0
()
and
is_tp_rank_0
()
and
is_no_pp_or_last_stage
())
return
(
display_progress
and
is_dp_rank_0
()
and
is_tp_rank_0
()
and
is_no_pp_or_last_stage
())
def
_train_epoch
(
self
,
...
...
@@ -189,9 +185,7 @@ class Trainer:
return_output_label
=
return_output_label
,
)
self
.
engine
.
step
()
self
.
_call_timer
(
action
=
"stop"
,
item
=
"Train-step"
,
keep_in_history
=
True
)
self
.
_call_timer
(
action
=
"stop"
,
item
=
"Train-step"
,
keep_in_history
=
True
)
self
.
_call_hooks
(
"after_train_iter"
,
output
=
(
logits
,
label
,
loss
))
self
.
_cur_step
+=
1
...
...
@@ -204,9 +198,7 @@ class Trainer:
if
self
.
_exceed_max_step
():
break
self
.
_call_timer
(
action
=
"stop"
,
item
=
"Train-epoch"
,
keep_in_history
=
True
)
self
.
_call_timer
(
action
=
"stop"
,
item
=
"Train-epoch"
,
keep_in_history
=
True
)
self
.
_call_hooks
(
"after_train_epoch"
)
self
.
_call_timer
(
action
=
"reset"
,
item
=
"Train-epoch"
)
...
...
@@ -244,19 +236,14 @@ class Trainer:
return_loss
=
True
,
return_output_label
=
return_output_label
,
)
self
.
_call_timer
(
action
=
"stop"
,
item
=
"Test-step"
,
keep_in_history
=
True
)
self
.
_call_hooks
(
"after_test_iter"
,
output
=
(
logits
,
label
,
loss
))
self
.
_call_timer
(
action
=
"stop"
,
item
=
"Test-step"
,
keep_in_history
=
True
)
self
.
_call_hooks
(
"after_test_iter"
,
output
=
(
logits
,
label
,
loss
))
if
display_progress
:
if
"step_metrics"
in
self
.
states
:
progress
.
set_postfix
(
**
self
.
states
[
"step_metrics"
])
self
.
_call_timer
(
action
=
"stop"
,
item
=
"Test-epoch"
,
keep_in_history
=
True
)
self
.
_call_timer
(
action
=
"stop"
,
item
=
"Test-epoch"
,
keep_in_history
=
True
)
self
.
_call_hooks
(
"after_test_epoch"
)
self
.
_call_hooks
(
"after_test"
)
self
.
_call_timer
(
action
=
"reset"
,
item
=
"Test-step"
)
...
...
@@ -303,9 +290,11 @@ class Trainer:
# reset hooks
self
.
_reset_states
()
if
hooks
is
not
None
:
assert
isinstance
(
hooks
,
list
),
f
"expected argument hooks be to list, but got
{
type
(
hooks
)
}
"
assert
isinstance
(
hooks
,
list
),
f
"expected argument hooks be to list, but got
{
type
(
hooks
)
}
"
for
hook
in
hooks
:
assert
isinstance
(
hook
,
BaseHook
),
\
f
'expected the hook to be of type BaseHook, but got
{
type
(
hook
)
}
'
else
:
hooks
=
[]
self
.
hooks
=
hooks
...
...
@@ -316,9 +305,7 @@ class Trainer:
f
"Using
{
hook
.
__class__
.
__name__
}
for training, priority =
{
hook
.
priority
}
"
,
ranks
=
[
0
],
)
self
.
_logger
.
info
(
"Lower value means higher priority for calling hook function"
,
ranks
=
[
0
])
self
.
_logger
.
info
(
"Lower value means higher priority for calling hook function"
,
ranks
=
[
0
])
self
.
_call_hooks
(
"after_hook_is_attached"
)
self
.
_engine
.
train
()
...
...
@@ -381,9 +368,7 @@ class Trainer:
# reset hooks
self
.
_reset_states
()
if
hooks
is
not
None
:
assert
isinstance
(
hooks
,
list
),
f
"expected argument hooks be to list, but got
{
type
(
hooks
)
}
"
assert
isinstance
(
hooks
,
list
),
f
"expected argument hooks be to list, but got
{
type
(
hooks
)
}
"
else
:
hooks
=
[]
self
.
hooks
=
hooks
...
...
@@ -394,9 +379,7 @@ class Trainer:
f
"Using
{
hook
.
__class__
.
__name__
}
for training, priority =
{
hook
.
priority
}
"
,
ranks
=
[
0
],
)
self
.
_logger
.
info
(
"Lower value means higher priority for calling hook function"
,
ranks
=
[
0
])
self
.
_logger
.
info
(
"Lower value means higher priority for calling hook function"
,
ranks
=
[
0
])
self
.
_call_hooks
(
"after_hook_is_attached"
)
# eval
...
...
@@ -406,7 +389,7 @@ class Trainer:
return_output_label
=
return_output_label
,
)
def
predict
(
self
,
data
:
Union
[
Tensor
,
List
[
Tensor
]]):
def
predict
(
self
,
data
:
Union
[
Any
,
List
[
Any
]]):
"""Uses trained model to make a prediction for a tensor or a tensor list.
Args:
...
...
@@ -416,17 +399,11 @@ class Trainer:
:class:`torch.tensor`: The output of model as the prediction
"""
# predict without labels
if
isinstance
(
data
,
(
list
,
tuple
)):
assert
isinstance
(
data
[
0
],
Tensor
)
else
:
assert
isinstance
(
data
,
Tensor
)
self
.
_engine
.
eval
()
# prepare a list of (data, label) to make it iterable
# for compatibility with schedule
simple_dataloader
=
[(
data
,
None
)]
data_iter
=
iter
(
simple_dataloader
)
output
,
_
,
_
=
self
.
engine
.
execute_schedule
(
data_iter
,
forward_only
=
True
,
return_loss
=
False
)
output
,
_
,
_
=
self
.
engine
.
execute_schedule
(
data_iter
,
forward_only
=
True
,
return_loss
=
False
)
return
output
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment