Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
cfb41297
Commit
cfb41297
authored
Mar 31, 2022
by
yuxuan-lou
Committed by
binmakeswell
Apr 06, 2022
Browse files
'fix/format' (#573)
parent
b0f708df
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
6 additions
and
6 deletions
+6
-6
colossalai/engine/ophooks/_memtracer_ophook.py
colossalai/engine/ophooks/_memtracer_ophook.py
+2
-2
colossalai/engine/schedule/_base_schedule.py
colossalai/engine/schedule/_base_schedule.py
+4
-4
No files found.
colossalai/engine/ophooks/_memtracer_ophook.py
View file @
cfb41297
...
@@ -106,7 +106,7 @@ class MemTracerOpHook(BaseOpHook):
...
@@ -106,7 +106,7 @@ class MemTracerOpHook(BaseOpHook):
# output file info
# output file info
self
.
_logger
.
info
(
f
"dump a memory statistics as pickle to
{
self
.
_data_prefix
}
-
{
self
.
_rank
}
.pkl"
)
self
.
_logger
.
info
(
f
"dump a memory statistics as pickle to
{
self
.
_data_prefix
}
-
{
self
.
_rank
}
.pkl"
)
home_dir
=
Path
.
home
()
home_dir
=
Path
.
home
()
with
open
(
home_dir
.
joinpath
(
f
".cache/colossal/mem-
{
self
.
_rank
}
.pkl"
),
"wb"
)
as
f
:
with
open
(
home_dir
.
joinpath
(
f
".cache/colossal/mem-
{
self
.
_rank
}
.pkl"
),
"wb"
)
as
f
:
pickle
.
dump
(
self
.
async_mem_monitor
.
state_dict
,
f
)
pickle
.
dump
(
self
.
async_mem_monitor
.
state_dict
,
f
)
self
.
_count
+=
1
self
.
_count
+=
1
self
.
_logger
.
debug
(
f
"data file has been refreshed
{
self
.
_count
}
times"
)
self
.
_logger
.
debug
(
f
"data file has been refreshed
{
self
.
_count
}
times"
)
...
@@ -115,4 +115,4 @@ class MemTracerOpHook(BaseOpHook):
...
@@ -115,4 +115,4 @@ class MemTracerOpHook(BaseOpHook):
def
save_results
(
self
,
data_file
:
Union
[
str
,
Path
]):
def
save_results
(
self
,
data_file
:
Union
[
str
,
Path
]):
with
open
(
data_file
,
"w"
)
as
f
:
with
open
(
data_file
,
"w"
)
as
f
:
f
.
write
(
json
.
dumps
(
self
.
async_mem_monitor
.
state_dict
))
f
.
write
(
json
.
dumps
(
self
.
async_mem_monitor
.
state_dict
))
\ No newline at end of file
colossalai/engine/schedule/_base_schedule.py
View file @
cfb41297
...
@@ -85,8 +85,7 @@ class BaseSchedule(ABC):
...
@@ -85,8 +85,7 @@ class BaseSchedule(ABC):
data_iter
:
Iterable
,
data_iter
:
Iterable
,
forward_only
:
bool
,
forward_only
:
bool
,
return_loss
:
bool
=
True
,
return_loss
:
bool
=
True
,
return_output_label
:
bool
=
True
return_output_label
:
bool
=
True
):
):
"""The process function over a batch of dataset for training or evaluation.
"""The process function over a batch of dataset for training or evaluation.
Args:
Args:
...
@@ -107,8 +106,9 @@ class BaseSchedule(ABC):
...
@@ -107,8 +106,9 @@ class BaseSchedule(ABC):
@
staticmethod
@
staticmethod
def
_call_engine_criterion
(
engine
,
outputs
,
labels
):
def
_call_engine_criterion
(
engine
,
outputs
,
labels
):
assert
isinstance
(
outputs
,
(
torch
.
Tensor
,
list
,
tuple
)
assert
isinstance
(
),
f
'Expect output of model is (torch.Tensor, list, tuple), got
{
type
(
outputs
)
}
'
outputs
,
(
torch
.
Tensor
,
list
,
tuple
)),
f
'Expect output of model is (torch.Tensor, list, tuple), got
{
type
(
outputs
)
}
'
if
isinstance
(
outputs
,
torch
.
Tensor
):
if
isinstance
(
outputs
,
torch
.
Tensor
):
outputs
=
(
outputs
,)
outputs
=
(
outputs
,)
if
isinstance
(
labels
,
torch
.
Tensor
):
if
isinstance
(
labels
,
torch
.
Tensor
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment