Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
9823cbf2
Commit
9823cbf2
authored
Sep 08, 2022
by
Zangwei Zheng
Committed by
Frank Lee
Sep 08, 2022
Browse files
[NFC] polish colossalai/gemini/update/chunkv2.py code style (#1565)
parent
f586887a
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
20 additions
and
25 deletions
+20
-25
colossalai/gemini/update/chunkv2.py
colossalai/gemini/update/chunkv2.py
+20
-25
No files found.
colossalai/gemini/update/chunkv2.py
View file @
9823cbf2
...
...
@@ -9,6 +9,7 @@ from colossalai.gemini.chunk import TensorState, STATE_TRANS, TensorInfo, ChunkF
class
ChunkV2
:
def
__init__
(
self
,
chunk_size
:
int
,
process_group
:
ColoProcessGroup
,
...
...
@@ -49,9 +50,9 @@ class ChunkV2:
self
.
dtype
=
dtype
device
=
init_device
or
get_current_device
()
self
.
chunk_temp
=
torch
.
zeros
(
chunk_size
,
dtype
=
dtype
,
device
=
device
)
# keep all zero
self
.
chunk_total
=
None
# we force chunk_total located in CUDA
self
.
cuda_shard
=
None
# using two attributes for the better interpretation
self
.
chunk_temp
=
torch
.
zeros
(
chunk_size
,
dtype
=
dtype
,
device
=
device
)
# keep all zero
self
.
chunk_total
=
None
# we force chunk_total located in CUDA
self
.
cuda_shard
=
None
# using two attributes for the better interpretation
self
.
cpu_shard
=
None
self
.
is_gathered
=
True
...
...
@@ -71,7 +72,7 @@ class ChunkV2:
# so their computation patterns are the same as that of the parameters in DDP
self
.
keep_gathered
=
keep_gathered
if
self
.
keep_gathered
:
pin_memory
=
False
# since this chunk is gathered, it doesn't need to pin
pin_memory
=
False
# since this chunk is gathered, it doesn't need to pin
# if pin_memory is True, we allocate a piece of CPU pin-memory
# for it all the time
...
...
@@ -137,9 +138,9 @@ class ChunkV2:
if
new_utilized_size
>
self
.
chunk_size
:
raise
ChunkFullError
self
.
chunk_temp
[
self
.
utilized_size
:
new_utilized_size
].
copy_
(
tensor
.
data
.
flatten
())
self
.
chunk_temp
[
self
.
utilized_size
:
new_utilized_size
].
copy_
(
tensor
.
data
.
flatten
())
assert
type
(
self
.
chunk_temp
)
==
torch
.
Tensor
,
"copy_tensor_to_chunk_slice must use a torch tensor"
tensor
.
data
=
self
.
chunk_temp
[
self
.
utilized_size
:
new_utilized_size
].
view
(
tensor
.
shape
)
tensor
.
data
=
self
.
chunk_temp
[
self
.
utilized_size
:
new_utilized_size
].
view
(
tensor
.
shape
)
# record all the information about the tensor
self
.
num_tensors
+=
1
...
...
@@ -177,11 +178,9 @@ class ChunkV2:
shard_dev
=
torch
.
device
(
'cpu'
)
if
self
.
pin_memory
or
shard_dev
.
type
==
'cpu'
:
self
.
cpu_shard
=
torch
.
empty
(
self
.
shard_size
,
dtype
=
self
.
dtype
,
pin_memory
=
self
.
pin_memory
)
self
.
cpu_shard
=
torch
.
empty
(
self
.
shard_size
,
dtype
=
self
.
dtype
,
pin_memory
=
self
.
pin_memory
)
self
.
cpu_shard
.
copy_
(
self
.
cuda_shard
)
self
.
cpu_vis_flag
=
True
# cpu_shard has been visited
self
.
cpu_vis_flag
=
True
# cpu_shard has been visited
if
shard_dev
.
type
==
'cpu'
:
self
.
cuda_shard
=
None
...
...
@@ -260,8 +259,7 @@ class ChunkV2:
# we use all-reduce here
dist
.
all_reduce
(
self
.
chunk_total
,
group
=
self
.
torch_pg
)
else
:
self
.
cuda_shard
=
torch
.
empty
(
self
.
shard_size
,
dtype
=
self
.
dtype
,
device
=
get_current_device
())
self
.
cuda_shard
=
torch
.
empty
(
self
.
shard_size
,
dtype
=
self
.
dtype
,
device
=
get_current_device
())
input_list
=
list
(
torch
.
chunk
(
self
.
chunk_total
,
chunks
=
self
.
pg_size
,
dim
=
0
))
dist
.
reduce_scatter
(
self
.
cuda_shard
,
input_list
,
group
=
self
.
torch_pg
)
...
...
@@ -330,10 +328,10 @@ class ChunkV2:
Check if the chunk has inf or nan values in CUDA.
"""
if
self
.
is_gathered
:
valid_tensor
=
self
.
chunk_total
[:
self
.
utilized_size
]
valid_tensor
=
self
.
chunk_total
[:
self
.
utilized_size
]
else
:
assert
self
.
cuda_shard
is
not
None
# only check in CUDA
valid_tensor
=
self
.
cuda_shard
[:
self
.
valid_end
]
assert
self
.
cuda_shard
is
not
None
# only check in CUDA
valid_tensor
=
self
.
cuda_shard
[:
self
.
valid_end
]
return
torch
.
isinf
(
valid_tensor
).
any
().
item
()
|
torch
.
isnan
(
valid_tensor
).
any
().
item
()
...
...
@@ -346,8 +344,7 @@ class ChunkV2:
self
.
chunk_total
=
self
.
cuda_shard
else
:
alloc_storage
(
self
.
chunk_total
)
gather_list
=
list
(
torch
.
chunk
(
input
=
self
.
chunk_total
,
chunks
=
self
.
pg_size
,
dim
=
0
))
gather_list
=
list
(
torch
.
chunk
(
input
=
self
.
chunk_total
,
chunks
=
self
.
pg_size
,
dim
=
0
))
dist
.
all_gather
(
gather_list
,
self
.
cuda_shard
,
self
.
torch_pg
)
self
.
cuda_shard
=
None
...
...
@@ -361,11 +358,9 @@ class ChunkV2:
# sanity check
assert
self
.
cuda_shard
is
None
self
.
cuda_shard
=
torch
.
empty
(
self
.
shard_size
,
dtype
=
self
.
dtype
,
device
=
self
.
chunk_total
.
device
)
self
.
cuda_shard
=
torch
.
empty
(
self
.
shard_size
,
dtype
=
self
.
dtype
,
device
=
self
.
chunk_total
.
device
)
self
.
cuda_shard
.
copy_
(
self
.
chunk_total
[
self
.
shard_begin
:
self
.
shard_end
])
self
.
cuda_shard
.
copy_
(
self
.
chunk_total
[
self
.
shard_begin
:
self
.
shard_end
])
free_storage
(
self
.
chunk_total
)
self
.
is_gathered
=
False
...
...
@@ -412,15 +407,15 @@ class ChunkV2:
def
__repr__
(
self
,
detailed
:
bool
=
False
):
output
=
[
"AgChunk Information:
\n
"
,
"
\t
chunk size: {}, chunk dtype: {}, process group size: {}
\n
"
.
format
(
self
.
chunk_size
,
self
.
dtype
,
self
.
pg_size
),
"
\t
chunk size: {}, chunk dtype: {}, process group size: {}
\n
"
.
format
(
self
.
chunk_size
,
self
.
dtype
,
self
.
pg_size
),
"
\t
# of tensors: {}, utilized size: {}, utilized percentage: {:.2f}
\n
"
.
format
(
self
.
num_tensors
,
self
.
utilized_size
,
self
.
utilized_size
/
self
.
chunk_size
)
]
def
print_tensor
(
tensor
,
prefix
=
''
):
output
.
append
(
"{}shape: {}, dtype: {}, device: {}
\n
"
.
format
(
prefix
,
tensor
.
shape
,
tensor
.
dtype
,
tensor
.
device
))
output
.
append
(
"{}shape: {}, dtype: {}, device: {}
\n
"
.
format
(
prefix
,
tensor
.
shape
,
tensor
.
dtype
,
tensor
.
device
))
if
self
.
chunk_temp
is
not
None
:
output
.
append
(
"
\t
chunk temp:
\n
"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment