Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
c7d49329
Commit
c7d49329
authored
Sep 08, 2022
by
LuGY
Committed by
Frank Lee
Sep 08, 2022
Browse files
[NFC] polish colossalai/utils/tensor_detector/tensor_detector.py code style (#1566)
parent
0c4c9aa6
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
32 additions
and
36 deletions
+32
-36
colossalai/utils/tensor_detector/tensor_detector.py
colossalai/utils/tensor_detector/tensor_detector.py
+32
-36
No files found.
colossalai/utils/tensor_detector/tensor_detector.py
View file @
c7d49329
...
@@ -5,18 +5,17 @@ import torch.nn as nn
...
@@ -5,18 +5,17 @@ import torch.nn as nn
from
typing
import
Optional
from
typing
import
Optional
from
collections
import
defaultdict
from
collections
import
defaultdict
LINE_WIDTH
=
108
LINE_WIDTH
=
108
LINE
=
'-'
*
LINE_WIDTH
+
'
\n
'
LINE
=
'-'
*
LINE_WIDTH
+
'
\n
'
class
TensorDetector
():
class
TensorDetector
():
def
__init__
(
self
,
def
__init__
(
self
,
show_info
:
bool
=
True
,
show_info
:
bool
=
True
,
log
:
str
=
None
,
log
:
str
=
None
,
include_cpu
:
bool
=
False
,
include_cpu
:
bool
=
False
,
module
:
Optional
[
nn
.
Module
]
=
None
module
:
Optional
[
nn
.
Module
]
=
None
):
):
"""This class is a detector to detect tensor on different devices.
"""This class is a detector to detect tensor on different devices.
Args:
Args:
...
@@ -57,12 +56,12 @@ class TensorDetector():
...
@@ -57,12 +56,12 @@ class TensorDetector():
def
mem_format
(
self
,
real_memory_size
):
def
mem_format
(
self
,
real_memory_size
):
# format the tensor memory into a reasonal magnitude
# format the tensor memory into a reasonal magnitude
if
real_memory_size
>=
2
**
30
:
if
real_memory_size
>=
2
**
30
:
return
str
(
real_memory_size
/
(
2
**
30
))
+
' GB'
return
str
(
real_memory_size
/
(
2
**
30
))
+
' GB'
if
real_memory_size
>=
2
**
20
:
if
real_memory_size
>=
2
**
20
:
return
str
(
real_memory_size
/
(
2
**
20
))
+
' MB'
return
str
(
real_memory_size
/
(
2
**
20
))
+
' MB'
if
real_memory_size
>=
2
**
10
:
if
real_memory_size
>=
2
**
10
:
return
str
(
real_memory_size
/
(
2
**
10
))
+
' KB'
return
str
(
real_memory_size
/
(
2
**
10
))
+
' KB'
return
str
(
real_memory_size
)
+
' B'
return
str
(
real_memory_size
)
+
' B'
def
collect_tensors_state
(
self
):
def
collect_tensors_state
(
self
):
...
@@ -125,8 +124,7 @@ class TensorDetector():
...
@@ -125,8 +124,7 @@ class TensorDetector():
minus
=
outdated
+
minus
minus
=
outdated
+
minus
if
len
(
self
.
order
)
>
0
:
if
len
(
self
.
order
)
>
0
:
for
tensor_id
in
self
.
order
:
for
tensor_id
in
self
.
order
:
self
.
info
+=
template_format
.
format
(
'+'
,
self
.
info
+=
template_format
.
format
(
'+'
,
str
(
self
.
tensor_info
[
tensor_id
][
0
]),
str
(
self
.
tensor_info
[
tensor_id
][
0
]),
str
(
self
.
tensor_info
[
tensor_id
][
1
]),
str
(
self
.
tensor_info
[
tensor_id
][
1
]),
str
(
tuple
(
self
.
tensor_info
[
tensor_id
][
2
])),
str
(
tuple
(
self
.
tensor_info
[
tensor_id
][
2
])),
str
(
self
.
tensor_info
[
tensor_id
][
3
]),
str
(
self
.
tensor_info
[
tensor_id
][
3
]),
...
@@ -137,8 +135,7 @@ class TensorDetector():
...
@@ -137,8 +135,7 @@ class TensorDetector():
self
.
info
+=
'
\n
'
self
.
info
+=
'
\n
'
if
len
(
minus
)
>
0
:
if
len
(
minus
)
>
0
:
for
tensor_id
in
minus
:
for
tensor_id
in
minus
:
self
.
info
+=
template_format
.
format
(
'-'
,
self
.
info
+=
template_format
.
format
(
'-'
,
str
(
self
.
saved_tensor_info
[
tensor_id
][
0
]),
str
(
self
.
saved_tensor_info
[
tensor_id
][
0
]),
str
(
self
.
saved_tensor_info
[
tensor_id
][
1
]),
str
(
self
.
saved_tensor_info
[
tensor_id
][
1
]),
str
(
tuple
(
self
.
saved_tensor_info
[
tensor_id
][
2
])),
str
(
tuple
(
self
.
saved_tensor_info
[
tensor_id
][
2
])),
str
(
self
.
saved_tensor_info
[
tensor_id
][
3
]),
str
(
self
.
saved_tensor_info
[
tensor_id
][
3
]),
...
@@ -148,7 +145,6 @@ class TensorDetector():
...
@@ -148,7 +145,6 @@ class TensorDetector():
# deleted the updated tensor
# deleted the updated tensor
self
.
saved_tensor_info
.
pop
(
tensor_id
)
self
.
saved_tensor_info
.
pop
(
tensor_id
)
# trace where is the detect()
# trace where is the detect()
locate_info
=
inspect
.
stack
()[
2
]
locate_info
=
inspect
.
stack
()[
2
]
locate_msg
=
'"'
+
locate_info
.
filename
+
'" line '
+
str
(
locate_info
.
lineno
)
locate_msg
=
'"'
+
locate_info
.
filename
+
'" line '
+
str
(
locate_info
.
lineno
)
...
@@ -168,7 +164,7 @@ class TensorDetector():
...
@@ -168,7 +164,7 @@ class TensorDetector():
with
open
(
self
.
log
+
'.log'
,
'a'
)
as
f
:
with
open
(
self
.
log
+
'.log'
,
'a'
)
as
f
:
f
.
write
(
self
.
info
)
f
.
write
(
self
.
info
)
def
detect
(
self
,
include_cpu
=
False
):
def
detect
(
self
,
include_cpu
=
False
):
self
.
include_cpu
=
include_cpu
self
.
include_cpu
=
include_cpu
self
.
collect_tensors_state
()
self
.
collect_tensors_state
()
self
.
print_tensors_state
()
self
.
print_tensors_state
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment