Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
1c343826
Unverified
Commit
1c343826
authored
Apr 26, 2022
by
Frank Lee
Committed by
GitHub
Apr 26, 2022
Browse files
[doc] improved assertion messages in trainer (#873)
parent
7a64fae3
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
46 additions
and
69 deletions
+46
-69
colossalai/trainer/_trainer.py
colossalai/trainer/_trainer.py
+46
-69
No files found.
colossalai/trainer/_trainer.py
View file @
1c343826
from
typing
import
Union
,
List
from
typing
import
Union
,
List
,
Any
from
colossalai.context.parallel_mode
import
ParallelMode
import
torch
import
torch
from
torch
import
Tensor
from
torch.utils.data
import
DataLoader
from
torch.utils.data
import
DataLoader
from
tqdm
import
tqdm
from
tqdm
import
tqdm
from
colossalai.core
import
global_context
as
gpc
from
colossalai.engine
import
Engine
from
colossalai.engine
import
Engine
from
colossalai.logging
import
DistributedLogger
from
colossalai.logging
import
DistributedLogger
from
colossalai.utils
import
MultiTimer
from
colossalai.utils
import
MultiTimer
...
@@ -53,6 +49,7 @@ class Trainer:
...
@@ -53,6 +49,7 @@ class Trainer:
`Training with engine and trainer <https://www.colossalai.org/docs/basics/engine_trainer>`_
`Training with engine and trainer <https://www.colossalai.org/docs/basics/engine_trainer>`_
and `ColossalAI-Examples <https://github.com/hpcaitech/ColossalAI-Examples/tree/main>`_.
and `ColossalAI-Examples <https://github.com/hpcaitech/ColossalAI-Examples/tree/main>`_.
"""
"""
def
__init__
(
def
__init__
(
self
,
self
,
engine
:
Engine
,
engine
:
Engine
,
...
@@ -154,8 +151,7 @@ class Trainer:
...
@@ -154,8 +151,7 @@ class Trainer:
@
staticmethod
@
staticmethod
def
_should_display_progress
(
display_progress
:
bool
):
def
_should_display_progress
(
display_progress
:
bool
):
"""Only display progress on DP rank 0, TP rank 0 and PP last rank"""
"""Only display progress on DP rank 0, TP rank 0 and PP last rank"""
return
(
display_progress
and
is_dp_rank_0
()
and
is_tp_rank_0
()
return
(
display_progress
and
is_dp_rank_0
()
and
is_tp_rank_0
()
and
is_no_pp_or_last_stage
())
and
is_no_pp_or_last_stage
())
def
_train_epoch
(
def
_train_epoch
(
self
,
self
,
...
@@ -189,9 +185,7 @@ class Trainer:
...
@@ -189,9 +185,7 @@ class Trainer:
return_output_label
=
return_output_label
,
return_output_label
=
return_output_label
,
)
)
self
.
engine
.
step
()
self
.
engine
.
step
()
self
.
_call_timer
(
action
=
"stop"
,
self
.
_call_timer
(
action
=
"stop"
,
item
=
"Train-step"
,
keep_in_history
=
True
)
item
=
"Train-step"
,
keep_in_history
=
True
)
self
.
_call_hooks
(
"after_train_iter"
,
output
=
(
logits
,
label
,
loss
))
self
.
_call_hooks
(
"after_train_iter"
,
output
=
(
logits
,
label
,
loss
))
self
.
_cur_step
+=
1
self
.
_cur_step
+=
1
...
@@ -204,9 +198,7 @@ class Trainer:
...
@@ -204,9 +198,7 @@ class Trainer:
if
self
.
_exceed_max_step
():
if
self
.
_exceed_max_step
():
break
break
self
.
_call_timer
(
action
=
"stop"
,
self
.
_call_timer
(
action
=
"stop"
,
item
=
"Train-epoch"
,
keep_in_history
=
True
)
item
=
"Train-epoch"
,
keep_in_history
=
True
)
self
.
_call_hooks
(
"after_train_epoch"
)
self
.
_call_hooks
(
"after_train_epoch"
)
self
.
_call_timer
(
action
=
"reset"
,
item
=
"Train-epoch"
)
self
.
_call_timer
(
action
=
"reset"
,
item
=
"Train-epoch"
)
...
@@ -244,19 +236,14 @@ class Trainer:
...
@@ -244,19 +236,14 @@ class Trainer:
return_loss
=
True
,
return_loss
=
True
,
return_output_label
=
return_output_label
,
return_output_label
=
return_output_label
,
)
)
self
.
_call_timer
(
action
=
"stop"
,
self
.
_call_timer
(
action
=
"stop"
,
item
=
"Test-step"
,
keep_in_history
=
True
)
item
=
"Test-step"
,
self
.
_call_hooks
(
"after_test_iter"
,
output
=
(
logits
,
label
,
loss
))
keep_in_history
=
True
)
self
.
_call_hooks
(
"after_test_iter"
,
output
=
(
logits
,
label
,
loss
))
if
display_progress
:
if
display_progress
:
if
"step_metrics"
in
self
.
states
:
if
"step_metrics"
in
self
.
states
:
progress
.
set_postfix
(
**
self
.
states
[
"step_metrics"
])
progress
.
set_postfix
(
**
self
.
states
[
"step_metrics"
])
self
.
_call_timer
(
action
=
"stop"
,
self
.
_call_timer
(
action
=
"stop"
,
item
=
"Test-epoch"
,
keep_in_history
=
True
)
item
=
"Test-epoch"
,
keep_in_history
=
True
)
self
.
_call_hooks
(
"after_test_epoch"
)
self
.
_call_hooks
(
"after_test_epoch"
)
self
.
_call_hooks
(
"after_test"
)
self
.
_call_hooks
(
"after_test"
)
self
.
_call_timer
(
action
=
"reset"
,
item
=
"Test-step"
)
self
.
_call_timer
(
action
=
"reset"
,
item
=
"Test-step"
)
...
@@ -303,9 +290,11 @@ class Trainer:
...
@@ -303,9 +290,11 @@ class Trainer:
# reset hooks
# reset hooks
self
.
_reset_states
()
self
.
_reset_states
()
if
hooks
is
not
None
:
if
hooks
is
not
None
:
assert
isinstance
(
assert
isinstance
(
hooks
,
list
),
f
"expected argument hooks be to list, but got
{
type
(
hooks
)
}
"
hooks
,
list
),
f
"expected argument hooks be to list, but got
{
type
(
hooks
)
}
"
for
hook
in
hooks
:
assert
isinstance
(
hook
,
BaseHook
),
\
f
'expected the hook to be of type BaseHook, but got
{
type
(
hook
)
}
'
else
:
else
:
hooks
=
[]
hooks
=
[]
self
.
hooks
=
hooks
self
.
hooks
=
hooks
...
@@ -316,9 +305,7 @@ class Trainer:
...
@@ -316,9 +305,7 @@ class Trainer:
f
"Using
{
hook
.
__class__
.
__name__
}
for training, priority =
{
hook
.
priority
}
"
,
f
"Using
{
hook
.
__class__
.
__name__
}
for training, priority =
{
hook
.
priority
}
"
,
ranks
=
[
0
],
ranks
=
[
0
],
)
)
self
.
_logger
.
info
(
self
.
_logger
.
info
(
"Lower value means higher priority for calling hook function"
,
ranks
=
[
0
])
"Lower value means higher priority for calling hook function"
,
ranks
=
[
0
])
self
.
_call_hooks
(
"after_hook_is_attached"
)
self
.
_call_hooks
(
"after_hook_is_attached"
)
self
.
_engine
.
train
()
self
.
_engine
.
train
()
...
@@ -381,9 +368,7 @@ class Trainer:
...
@@ -381,9 +368,7 @@ class Trainer:
# reset hooks
# reset hooks
self
.
_reset_states
()
self
.
_reset_states
()
if
hooks
is
not
None
:
if
hooks
is
not
None
:
assert
isinstance
(
assert
isinstance
(
hooks
,
list
),
f
"expected argument hooks be to list, but got
{
type
(
hooks
)
}
"
hooks
,
list
),
f
"expected argument hooks be to list, but got
{
type
(
hooks
)
}
"
else
:
else
:
hooks
=
[]
hooks
=
[]
self
.
hooks
=
hooks
self
.
hooks
=
hooks
...
@@ -394,9 +379,7 @@ class Trainer:
...
@@ -394,9 +379,7 @@ class Trainer:
f
"Using
{
hook
.
__class__
.
__name__
}
for training, priority =
{
hook
.
priority
}
"
,
f
"Using
{
hook
.
__class__
.
__name__
}
for training, priority =
{
hook
.
priority
}
"
,
ranks
=
[
0
],
ranks
=
[
0
],
)
)
self
.
_logger
.
info
(
self
.
_logger
.
info
(
"Lower value means higher priority for calling hook function"
,
ranks
=
[
0
])
"Lower value means higher priority for calling hook function"
,
ranks
=
[
0
])
self
.
_call_hooks
(
"after_hook_is_attached"
)
self
.
_call_hooks
(
"after_hook_is_attached"
)
# eval
# eval
...
@@ -406,7 +389,7 @@ class Trainer:
...
@@ -406,7 +389,7 @@ class Trainer:
return_output_label
=
return_output_label
,
return_output_label
=
return_output_label
,
)
)
def
predict
(
self
,
data
:
Union
[
Tensor
,
List
[
Tensor
]]):
def
predict
(
self
,
data
:
Union
[
Any
,
List
[
Any
]]):
"""Uses trained model to make a prediction for a tensor or a tensor list.
"""Uses trained model to make a prediction for a tensor or a tensor list.
Args:
Args:
...
@@ -416,17 +399,11 @@ class Trainer:
...
@@ -416,17 +399,11 @@ class Trainer:
:class:`torch.tensor`: The output of model as the prediction
:class:`torch.tensor`: The output of model as the prediction
"""
"""
# predict without labels
# predict without labels
if
isinstance
(
data
,
(
list
,
tuple
)):
assert
isinstance
(
data
[
0
],
Tensor
)
else
:
assert
isinstance
(
data
,
Tensor
)
self
.
_engine
.
eval
()
self
.
_engine
.
eval
()
# prepare a list of (data, label) to make it iterable
# prepare a list of (data, label) to make it iterable
# for compatibility with schedule
# for compatibility with schedule
simple_dataloader
=
[(
data
,
None
)]
simple_dataloader
=
[(
data
,
None
)]
data_iter
=
iter
(
simple_dataloader
)
data_iter
=
iter
(
simple_dataloader
)
output
,
_
,
_
=
self
.
engine
.
execute_schedule
(
data_iter
,
output
,
_
,
_
=
self
.
engine
.
execute_schedule
(
data_iter
,
forward_only
=
True
,
return_loss
=
False
)
forward_only
=
True
,
return_loss
=
False
)
return
output
return
output
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment