ModelZoo / Uni-Fold_pytorch / Commits

Commit a1c29028
Authored Apr 17, 2023 by zhangqha
Commit message: update uni-fold
Pipeline #183: canceled with stages

Showing 20 changed files with 2378 additions and 0 deletions (+2378, -0)
Changed files:
  Uni-Core-main/unicore/data/lru_cache_dataset.py                          +22   -0
  Uni-Core-main/unicore/data/mask_tokens_dataset.py                        +132  -0
  Uni-Core-main/unicore/data/nested_dictionary_dataset.py                  +111  -0
  Uni-Core-main/unicore/data/num_samples_dataset.py                        +18   -0
  Uni-Core-main/unicore/data/numel_dataset.py                              +32   -0
  Uni-Core-main/unicore/data/pad_dataset.py                                +38   -0
  Uni-Core-main/unicore/data/prepend_token_dataset.py                      +25   -0
  Uni-Core-main/unicore/data/raw_dataset.py                                +64   -0
  Uni-Core-main/unicore/data/sort_dataset.py                               +43   -0
  Uni-Core-main/unicore/data/tokenize_dataset.py                           +29   -0
  Uni-Core-main/unicore/data/unicore_dataset.py                            +91   -0
  Uni-Core-main/unicore/distributed/__init__.py                            +12   -0
  Uni-Core-main/unicore/distributed/legacy_distributed_data_parallel.py    +168  -0
  Uni-Core-main/unicore/distributed/module_proxy_wrapper.py                +56   -0
  Uni-Core-main/unicore/distributed/utils.py                               +553  -0
  Uni-Core-main/unicore/logging/__init__.py                                +0    -0
  Uni-Core-main/unicore/logging/meters.py                                  +292  -0
  Uni-Core-main/unicore/logging/metrics.py                                 +288  -0
  Uni-Core-main/unicore/logging/progress_bar.py                            +370  -0
  Uni-Core-main/unicore/losses/__init__.py                                 +34   -0
Uni-Core-main/unicore/data/lru_cache_dataset.py — new file (0 → 100644)

# Copyright (c) DP Technology.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from functools import lru_cache

from . import BaseWrapperDataset


class LRUCacheDataset(BaseWrapperDataset):
    def __init__(self, dataset, token=None):
        super().__init__(dataset)

    @lru_cache(maxsize=16)
    def __getitem__(self, index):
        return self.dataset[index]

    @lru_cache(maxsize=16)
    def collater(self, samples):
        return self.dataset.collater(samples)
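A minimal usage sketch (editor's addition, not part of the commit): lru_cache on the methods keys the cache on (self, index), so repeated lookups of the same index are served from memory. ListDataset below is a hypothetical stand-in for any indexable dataset, and the import path assumes LRUCacheDataset is exported from unicore.data as in Uni-Core.

    # Sketch only: wrap a toy dataset and hit the functools cache on repeats.
    from unicore.data import LRUCacheDataset  # assumed export path

    class ListDataset:
        def __init__(self, items):
            self.items = items
        def __getitem__(self, index):
            return self.items[index]
        def __len__(self):
            return len(self.items)

    ds = LRUCacheDataset(ListDataset([10, 20, 30]))
    assert ds[1] == 20   # computed once ...
    assert ds[1] == 20   # ... then served from the lru_cache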
Uni-Core-main/unicore/data/mask_tokens_dataset.py — new file (0 → 100644)

# Copyright (c) DP Technology.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from functools import lru_cache

import numpy as np
import torch
from unicore.data import Dictionary, data_utils

from . import BaseWrapperDataset, LRUCacheDataset


class MaskTokensDataset(BaseWrapperDataset):
    @classmethod
    def apply_mask(cls, dataset: torch.utils.data.Dataset, *args, **kwargs):
        """Return the source and target datasets for masked LM training."""
        dataset = LRUCacheDataset(dataset)
        return (
            LRUCacheDataset(cls(dataset, *args, **kwargs, return_masked_tokens=False)),
            LRUCacheDataset(cls(dataset, *args, **kwargs, return_masked_tokens=True)),
        )

    def __init__(
        self,
        dataset: torch.utils.data.Dataset,
        vocab: Dictionary,
        pad_idx: int,
        mask_idx: int,
        return_masked_tokens: bool = False,
        seed: int = 1,
        mask_prob: float = 0.15,
        leave_unmasked_prob: float = 0.1,
        random_token_prob: float = 0.1,
    ):
        assert 0.0 < mask_prob < 1.0
        assert 0.0 <= random_token_prob <= 1.0
        assert 0.0 <= leave_unmasked_prob <= 1.0
        assert random_token_prob + leave_unmasked_prob <= 1.0

        self.dataset = dataset
        self.vocab = vocab
        self.pad_idx = pad_idx
        self.mask_idx = mask_idx
        self.return_masked_tokens = return_masked_tokens
        self.seed = seed
        self.mask_prob = mask_prob
        self.leave_unmasked_prob = leave_unmasked_prob
        self.random_token_prob = random_token_prob

        if random_token_prob > 0.0:
            weights = np.ones(len(self.vocab))
            weights[vocab.special_index()] = 0
            self.weights = weights / weights.sum()

        self.epoch = None

    @property
    def can_reuse_epoch_itr_across_epochs(self):
        return True  # only the noise changes, not item sizes

    def set_epoch(self, epoch, **unused):
        super().set_epoch(epoch)
        self.epoch = epoch

    def __getitem__(self, index: int):
        return self.__getitem_cached__(self.epoch, index)

    @lru_cache(maxsize=16)
    def __getitem_cached__(self, epoch: int, index: int):
        with data_utils.numpy_seed(self.seed, epoch, index):
            item = self.dataset[index]
            sz = len(item)
            # don't allow empty sequence
            assert sz > 2
            assert (
                self.mask_idx not in item
            ), "Dataset contains mask_idx (={}), this is not expected!".format(
                self.mask_idx,
            )

            # decide elements to mask
            mask = np.full(sz, False)
            num_mask = int(
                # add a random number for probabilistic rounding
                self.mask_prob * (sz - 2)
                + np.random.rand()
            )
            # don't mask first and last position
            mask_idc = np.random.choice(sz - 2, num_mask, replace=False) + 1
            mask[mask_idc] = True

            if self.return_masked_tokens:
                new_item = np.full(len(mask), self.pad_idx)
                new_item[mask] = item[torch.from_numpy(mask.astype(np.uint8)) == 1]
                return torch.from_numpy(new_item)

            # decide unmasking and random replacement
            rand_or_unmask_prob = self.random_token_prob + self.leave_unmasked_prob
            if rand_or_unmask_prob > 0.0:
                rand_or_unmask = mask & (np.random.rand(sz) < rand_or_unmask_prob)
                if self.random_token_prob == 0.0:
                    unmask = rand_or_unmask
                    rand_mask = None
                elif self.leave_unmasked_prob == 0.0:
                    unmask = None
                    rand_mask = rand_or_unmask
                else:
                    unmask_prob = self.leave_unmasked_prob / rand_or_unmask_prob
                    decision = np.random.rand(sz) < unmask_prob
                    unmask = rand_or_unmask & decision
                    rand_mask = rand_or_unmask & (~decision)
            else:
                unmask = rand_mask = None

            if unmask is not None:
                mask = mask ^ unmask

            new_item = np.copy(item)
            new_item[mask] = self.mask_idx

            if rand_mask is not None:
                num_rand = rand_mask.sum()
                if num_rand > 0:
                    new_item[rand_mask] = np.random.choice(
                        len(self.vocab),
                        num_rand,
                        p=self.weights,
                    )

            return torch.from_numpy(new_item)
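A usage sketch for apply_mask (editor's addition, not part of the commit), assuming a dataset of 1-D LongTensors with begin/end tokens and a Dictionary that exposes pad and mask indices (those accessor names are not shown in this commit):

    # Sketch only: build the masked-LM source/target pair.
    src_dataset, tgt_dataset = MaskTokensDataset.apply_mask(
        token_ds,                 # assumed dataset of 1-D LongTensors (with bos/eos)
        vocab=vocab,              # assumed unicore.data.Dictionary
        pad_idx=vocab.pad(),      # assumed accessor for the pad index
        mask_idx=vocab.mask(),    # assumed accessor for the mask index
        seed=1,
        mask_prob=0.15,
        leave_unmasked_prob=0.1,
        random_token_prob=0.1,
    )
    # src_dataset[i]: ~15% of inner positions replaced by mask_idx (or left
    # unchanged / swapped for a random token); tgt_dataset[i]: pad_idx everywhere
    # except the masked positions, which keep the original tokens.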
Uni-Core-main/unicore/data/nested_dictionary_dataset.py — new file (0 → 100644)

# Copyright (c) DP Technology.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from collections import OrderedDict

import torch
from torch.utils.data.dataloader import default_collate

from . import UnicoreDataset


def _flatten(dico, prefix=None):
    """Flatten a nested dictionary."""
    new_dico = OrderedDict()
    if isinstance(dico, dict):
        prefix = prefix + "." if prefix is not None else ""
        for k, v in dico.items():
            if v is None:
                continue
            new_dico.update(_flatten(v, prefix + k))
    elif isinstance(dico, list):
        for i, v in enumerate(dico):
            new_dico.update(_flatten(v, prefix + ".[" + str(i) + "]"))
    else:
        new_dico = OrderedDict({prefix: dico})
    return new_dico


def _unflatten(dico):
    """Unflatten a flattened dictionary into a nested dictionary."""
    new_dico = OrderedDict()
    for full_k, v in dico.items():
        full_k = full_k.split(".")
        node = new_dico
        for k in full_k[:-1]:
            if k.startswith("[") and k.endswith("]"):
                k = int(k[1:-1])
            if k not in node:
                node[k] = OrderedDict()
            node = node[k]
        node[full_k[-1]] = v
    return new_dico


class NestedDictionaryDataset(UnicoreDataset):
    def __init__(self, defn):
        super().__init__()
        self.defn = _flatten(defn)

        first = None
        for v in self.defn.values():
            if not isinstance(
                v,
                (
                    UnicoreDataset,
                    torch.utils.data.Dataset,
                ),
            ):
                raise ValueError("Expected Dataset but found: {}".format(v.__class__))

            first = first or v
            if len(v) > 0:
                assert len(v) == len(first), "dataset lengths must match"

        self._len = len(first)

    def __getitem__(self, index):
        return OrderedDict((k, ds[index]) for k, ds in self.defn.items())

    def __len__(self):
        return self._len

    def collater(self, samples):
        """Merge a list of samples to form a mini-batch.

        Args:
            samples (List[dict]): samples to collate

        Returns:
            dict: a mini-batch suitable for forwarding with a Model
        """
        if len(samples) == 0:
            return {}
        sample = OrderedDict()
        for k, ds in self.defn.items():
            try:
                sample[k] = ds.collater([s[k] for s in samples])
            except NotImplementedError:
                sample[k] = default_collate([s[k] for s in samples])
        return _unflatten(sample)

    @property
    def supports_prefetch(self):
        """Whether this dataset supports prefetching."""
        return any(ds.supports_prefetch for ds in self.defn.values())

    def prefetch(self, indices):
        """Prefetch the data required for this epoch."""
        for ds in self.defn.values():
            if getattr(ds, "supports_prefetch", False):
                ds.prefetch(indices)

    @property
    def can_reuse_epoch_itr_across_epochs(self):
        return all(ds.can_reuse_epoch_itr_across_epochs for ds in self.defn.values())

    def set_epoch(self, epoch):
        super().set_epoch(epoch)
        for ds in self.defn.values():
            ds.set_epoch(epoch)
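A short sketch of the flatten/unflatten behaviour (editor's addition, not part of the commit), assuming src_ds and tgt_ds are UnicoreDatasets of equal length:

    # Sketch only: nested keys are joined with "." (list entries get ".[i]"),
    # and collater regroups them via _unflatten.
    nested = NestedDictionaryDataset(
        {
            "net_input": {"src_tokens": src_ds},   # assumed datasets
            "target": tgt_ds,
        }
    )
    # internally: nested.defn keys are "net_input.src_tokens" and "target"
    batch = nested.collater([nested[0], nested[1]])
    # batch comes back nested: {"net_input": {"src_tokens": ...}, "target": ...}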
Uni-Core-main/unicore/data/num_samples_dataset.py — new file (0 → 100644)

# Copyright (c) DP Technology.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from . import UnicoreDataset


class NumSamplesDataset(UnicoreDataset):
    def __getitem__(self, index):
        return 1

    def __len__(self):
        return 0

    def collater(self, samples):
        return sum(samples)
Uni-Core-main/unicore/data/numel_dataset.py — new file (0 → 100644)

# Copyright (c) DP Technology.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import torch

from . import BaseWrapperDataset


class NumelDataset(BaseWrapperDataset):
    def __init__(self, dataset, reduce=False):
        super().__init__(dataset)
        self.reduce = reduce

    def __getitem__(self, index):
        item = self.dataset[index]
        if torch.is_tensor(item):
            return torch.numel(item)
        else:
            return np.size(item)

    def __len__(self):
        return len(self.dataset)

    def collater(self, samples):
        if self.reduce:
            return sum(samples)
        else:
            return torch.tensor(samples)
Uni-Core-main/unicore/data/pad_dataset.py — new file (0 → 100644)

# Copyright (c) DP Technology.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from unicore.data import data_utils

from . import BaseWrapperDataset


class PadDataset(BaseWrapperDataset):
    def __init__(self, dataset, pad_idx, left_pad):
        super().__init__(dataset)
        self.pad_idx = pad_idx
        self.left_pad = left_pad

    def collater(self, samples):
        return data_utils.collate_tokens(
            samples, self.pad_idx, left_pad=self.left_pad, pad_to_multiple=8
        )


class LeftPadDataset(PadDataset):
    def __init__(self, dataset, pad_idx):
        super().__init__(dataset, pad_idx, left_pad=True)


class RightPadDataset(PadDataset):
    def __init__(self, dataset, pad_idx):
        super().__init__(dataset, pad_idx, left_pad=False)


class RightPadDataset2D(BaseWrapperDataset):
    def __init__(self, dataset, pad_idx, left_pad=False):
        super().__init__(dataset)
        self.pad_idx = pad_idx
        self.left_pad = left_pad

    def collater(self, samples):
        return data_utils.collate_tokens_2d(
            samples, self.pad_idx, left_pad=self.left_pad, pad_to_multiple=8
        )
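A collation sketch (editor's addition, not part of the commit): because the collaters pass pad_to_multiple=8, batch width is rounded up to a multiple of 8, which tends to help tensor-core kernels. token_ds and pad_idx are assumed to exist.

    # Sketch only: right-pad two variable-length sequences into one batch.
    padded = RightPadDataset(token_ds, pad_idx=pad_idx)
    batch = padded.collater([token_ds[0], token_ds[1]])   # e.g. lengths 5 and 11
    # batch.shape[1] == 16 in that case: max length 11, rounded up to 16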
Uni-Core-main/unicore/data/prepend_token_dataset.py — new file (0 → 100644)

# Copyright (c) DP Technology.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import torch
from functools import lru_cache

from . import BaseWrapperDataset


class PrependTokenDataset(BaseWrapperDataset):
    def __init__(self, dataset, token=None):
        super().__init__(dataset)
        self.token = token

    @lru_cache(maxsize=16)
    def __getitem__(self, idx):
        item = self.dataset[idx]
        if self.token is not None:
            item = torch.cat(
                [torch.full_like(item[0], self.token).unsqueeze(0), item], dim=0
            )
        return item
Uni-Core-main/unicore/data/raw_dataset.py — new file (0 → 100644)

# Copyright (c) DP Technology.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
from torch.utils.data.dataloader import default_collate
from functools import lru_cache

from . import UnicoreDataset


class RawLabelDataset(UnicoreDataset):
    def __init__(self, labels):
        super().__init__()
        self.labels = labels

    @lru_cache(maxsize=16)
    def __getitem__(self, index):
        return self.labels[index]

    def __len__(self):
        return len(self.labels)

    def collater(self, samples):
        return torch.tensor(samples)


class RawArrayDataset(UnicoreDataset):
    def __init__(self, dataset):
        super().__init__()
        self.dataset = dataset

    @lru_cache(maxsize=16)
    def __getitem__(self, index):
        return self.dataset[index]

    def __len__(self):
        return len(self.dataset)

    def collater(self, samples):
        if hasattr(self.dataset, 'collater'):
            return self.dataset.collater(samples)
        else:
            return default_collate(samples)


class RawNumpyDataset(UnicoreDataset):
    def __init__(self, dataset):
        super().__init__()
        self.dataset = dataset

    @lru_cache(maxsize=16)
    def __getitem__(self, index):
        return torch.from_numpy(self.dataset[index])

    def __len__(self):
        return len(self.dataset)

    def collater(self, samples):
        if hasattr(self.dataset, 'collater'):
            return self.dataset.collater(samples)
        else:
            return default_collate(samples)
Uni-Core-main/unicore/data/sort_dataset.py — new file (0 → 100644)

# Copyright (c) DP Technology.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np

from . import BaseWrapperDataset, data_utils


class SortDataset(BaseWrapperDataset):
    def __init__(self, dataset, sort_order):
        super().__init__(dataset)
        if not isinstance(sort_order, (list, tuple)):
            sort_order = [sort_order]
        self.sort_order = sort_order

        assert all(len(so) == len(dataset) for so in sort_order)

    def ordered_indices(self):
        return np.lexsort(self.sort_order)


class EpochShuffleDataset(BaseWrapperDataset):
    def __init__(self, dataset, size, seed):
        super().__init__(dataset)
        self.size = size
        self.seed = seed
        self.set_epoch(1)

    def set_epoch(self, epoch):
        super().set_epoch(epoch)
        with data_utils.numpy_seed(self.seed + epoch - 1):
            self.sort_order = np.random.permutation(self.size)

    def ordered_indices(self):
        return self.sort_order

    @property
    def can_reuse_epoch_itr_across_epochs(self):
        return False
Uni-Core-main/unicore/data/tokenize_dataset.py — new file (0 → 100644)

# Copyright (c) DP Technology.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from functools import lru_cache

import torch
from unicore.data import Dictionary
from functools import lru_cache

from . import BaseWrapperDataset


class TokenizeDataset(BaseWrapperDataset):
    def __init__(
        self,
        dataset: torch.utils.data.Dataset,
        dictionary: Dictionary,
        max_seq_len: int = 512,
    ):
        self.dataset = dataset
        self.dictionary = dictionary
        self.max_seq_len = max_seq_len

    @lru_cache(maxsize=16)
    def __getitem__(self, index: int):
        raw_data = self.dataset[index]
        assert len(raw_data) < self.max_seq_len and len(raw_data) > 0
        return torch.from_numpy(self.dictionary.vec_index(raw_data)).long()
Uni-Core-main/unicore/data/unicore_dataset.py — new file (0 → 100644)

# Copyright (c) DP Technology.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging

import numpy as np
import torch.utils.data

logger = logging.getLogger(__name__)


class EpochListening:
    """Mixin for receiving updates whenever the epoch increments."""

    @property
    def can_reuse_epoch_itr_across_epochs(self):
        """
        Whether we can reuse the :class:`unicore.data.EpochBatchIterator` for
        this dataset across epochs.

        This needs to return ``False`` if the sample sizes can change across
        epochs, in which case we may need to regenerate batches at each epoch.
        If your dataset relies in ``set_epoch`` then you should consider setting
        this to ``False``.
        """
        return True

    def set_epoch(self, epoch):
        """Will receive the updated epoch number at the beginning of the epoch."""
        pass


class UnicoreDataset(torch.utils.data.Dataset, EpochListening):
    """A dataset that provides helpers for batching."""

    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

    def collater(self, samples):
        """Merge a list of samples to form a mini-batch.

        Args:
            samples (List[dict]): samples to collate

        Returns:
            dict: a mini-batch suitable for forwarding with a Model
        """
        raise NotImplementedError

    def ordered_indices(self):
        """Return an ordered list of indices. Batches will be constructed based
        on this order."""
        return np.arange(len(self), dtype=np.int64)

    @property
    def supports_prefetch(self):
        """Whether this dataset supports prefetching."""
        return False

    def attr(self, attr: str, index: int):
        return getattr(self, attr, None)

    def prefetch(self, indices):
        """Prefetch the data required for this epoch."""
        raise NotImplementedError

    def batch_by_size(
        self,
        indices,
        batch_size=None,
        required_batch_size_multiple=1,
    ):
        """
        Given an ordered set of indices
        """
        from unicore.data import data_utils

        return data_utils.batch_by_size(
            indices,
            batch_size=batch_size,
            required_batch_size_multiple=required_batch_size_multiple,
        )

    @property
    def supports_fetch_outside_dataloader(self):
        """Whether this dataset supports fetching outside the workers of the dataloader."""
        return True
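A minimal subclass sketch (editor's addition, not part of the commit): a concrete dataset only has to provide __getitem__, __len__ and collater; ordered_indices, batch_by_size and the epoch hooks come from the base class. The import path is assumed to match Uni-Core's exports.

    # Sketch only: smallest useful UnicoreDataset.
    import torch
    from unicore.data import UnicoreDataset   # assumed export path

    class TensorListDataset(UnicoreDataset):
        def __init__(self, tensors):
            super().__init__()
            self.tensors = tensors

        def __getitem__(self, index):
            return self.tensors[index]

        def __len__(self):
            return len(self.tensors)

        def collater(self, samples):
            # all samples assumed to share a shape, so a plain stack suffices
            return torch.stack(samples, dim=0)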
Uni-Core-main/unicore/distributed/__init__.py — new file (0 → 100644)

# Copyright (c) DP Technology.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from .module_proxy_wrapper import ModuleProxyWrapper
from .legacy_distributed_data_parallel import LegacyDistributedDataParallel

__all__ = [
    "ModuleProxyWrapper",
]
Uni-Core-main/unicore/distributed/legacy_distributed_data_parallel.py — new file (0 → 100644)

# Copyright (c) DP Technology.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
A modified version of the legacy DistributedDataParallel module that uses c10d
communication primitives. This version is simpler than the latest PyTorch
version and is useful for debugging. Notably it does not overlap gradient
communication with the backward pass, which makes it slower but more robust
than the PyTorch version.

This version also supports the *no_sync* context manager, which allows faster
training with `--update-freq`.
"""

from collections import OrderedDict
from contextlib import contextmanager

import torch
from torch import nn

from unicore.distributed import utils


class LegacyDistributedDataParallel(nn.Module):
    """Implements distributed data parallelism at the module level.

    A simplified version of :class:`torch.nn.parallel.DistributedDataParallel`.
    This version uses a c10d process group for communication and does not
    broadcast buffers.

    Args:
        module (~torch.nn.Module): module to be parallelized
        process_group: the c10d process group to be used for distributed data
            parallel all-reduction.
        buffer_size (int, optional): number of elements to buffer before
            performing all-reduce (default: 256M).
    """

    def __init__(self, module, process_group, buffer_size=2 ** 28):
        super().__init__()

        self.module = module
        self.process_group = process_group
        self.world_size = utils.get_world_size(self.process_group)

        # Never use a bigger buffer than the number of model params
        self.buffer_size = min(buffer_size, sum(p.numel() for p in module.parameters()))
        self.buffer = None

        # We can also forcibly accumulate grads locally and only do the
        # all-reduce at some later time
        self.accumulate_grads = False

        # make per-device lists of parameters
        paramlists = OrderedDict()
        for param in self.module.parameters():
            device = param.device
            if paramlists.get(device) is None:
                paramlists[device] = []
            paramlists[device] += [param]
        self.per_device_params = list(paramlists.values())

    @contextmanager
    def no_sync(self):
        """A context manager to disable gradient synchronization."""
        old_accumulate_grads = self.accumulate_grads
        self.accumulate_grads = True
        yield
        self.accumulate_grads = old_accumulate_grads

    def forward(self, *inputs, **kwargs):
        return self.module(*inputs, **kwargs)

    def all_reduce_params(self, params):
        if self.accumulate_grads:
            return
        buffer = self.buffer
        nonzero_buffer = False
        if len(params) > 1:
            offset = 0
            for p in params:
                sz = p.numel()
                if p.grad is not None:
                    buffer[offset : offset + sz].copy_(p.grad.data.view(-1))
                    nonzero_buffer = True
                else:
                    buffer[offset : offset + sz].zero_()
                offset += sz
        else:
            # we only have a single grad to all-reduce
            p = params[0]
            if p.grad is not None:
                buffer = p.grad.data
                nonzero_buffer = True
            elif p.numel() <= self.buffer.numel():
                buffer = buffer[: p.numel()]
                buffer.zero_()
            else:
                buffer = torch.zeros_like(p)

        if nonzero_buffer:
            buffer.div_(self.world_size)

        utils.all_reduce(buffer, self.process_group)

        # copy all-reduced grads back into their original place
        offset = 0
        for p in params:
            sz = p.numel()
            if p.grad is not None:
                p.grad.data.copy_(buffer[offset : offset + sz].view_as(p))
            else:
                p.grad = buffer[offset : offset + sz].view_as(p).clone()
            offset += sz

    def all_reduce_grads(self):
        """
        This function must be called explicitly after backward to reduce
        gradients. There is no automatic hook like c10d.
        """

        def reduction_fn():
            # This function only needs to be called once
            if self.accumulate_grads:
                return
            if self.buffer is None:
                self.buffer = next(self.module.parameters()).new(self.buffer_size)

            for params in self.per_device_params:
                # All-reduce the gradients in buckets
                offset = 0
                buffered_params = []
                for param in params:
                    if not param.requires_grad:
                        continue
                    if param.grad is None:
                        param.grad = torch.zeros_like(param)

                    if hasattr(param, "expert"):
                        # Skip gradient sync for unshared parameters
                        continue

                    if param.grad.requires_grad:
                        raise RuntimeError(
                            "DistributedDataParallel only works "
                            "with gradients that don't require "
                            "grad"
                        )
                    sz = param.numel()
                    if sz > self.buffer.numel():
                        # all-reduce big params directly
                        self.all_reduce_params([param])
                    else:
                        if offset + sz > self.buffer.numel():
                            self.all_reduce_params(buffered_params)
                            offset = 0
                            buffered_params.clear()
                        buffered_params.append(param)
                        offset += sz

                if len(buffered_params) > 0:
                    self.all_reduce_params(buffered_params)

        reduction_fn()
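A sketch of the calling convention (editor's addition, not part of the commit): unlike torch's DDP there is no backward hook, so all_reduce_grads() has to be called explicitly before the optimizer step, and no_sync() covers the gradient-accumulation steps. model, group, batches, update_freq and optimizer are assumed, and the forward is assumed to return a scalar loss.

    # Sketch only: explicit gradient synchronization with LegacyDistributedDataParallel.
    ddp_model = LegacyDistributedDataParallel(model, process_group=group)

    for step, batch in enumerate(batches):
        if (step + 1) % update_freq != 0:
            with ddp_model.no_sync():          # accumulate locally, skip all-reduce
                ddp_model(**batch).backward()
        else:
            ddp_model(**batch).backward()
            ddp_model.all_reduce_grads()       # bucketed all-reduce of .grad
            optimizer.step()
            optimizer.zero_grad()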
Uni-Core-main/unicore/distributed/module_proxy_wrapper.py — new file (0 → 100644)

# Copyright (c) DP Technology.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from torch import nn


class ModuleProxyWrapper(nn.Module):
    """
    Wrap a DistributedDataParallel module and forward requests for missing
    attributes to the module wrapped by DDP (the twice-wrapped module).
    Also forward calls to :func:`state_dict` and :func:`load_state_dict`.

    Usage::

        module.xyz = "hello world"
        wrapped_module = DistributedDataParallel(module, **ddp_args)
        wrapped_module = ModuleProxyWrapper(wrapped_module)
        assert wrapped_module.xyz == "hello world"
        assert wrapped_module.state_dict().keys() == module.state_dict().keys()

    Args:
        module (nn.Module): module to wrap
    """

    def __init__(self, module: nn.Module):
        super().__init__()
        assert hasattr(module, "module"), \
            "ModuleProxyWrapper expects input to wrap another module"
        self.module = module

    def __getattr__(self, name):
        """Forward missing attributes to twice-wrapped module."""
        try:
            # defer to nn.Module's logic
            return super().__getattr__(name)
        except AttributeError:
            try:
                # forward to the once-wrapped module
                return getattr(self.module, name)
            except AttributeError:
                # forward to the twice-wrapped module
                return getattr(self.module.module, name)

    def state_dict(self, *args, **kwargs):
        """Forward to the twice-wrapped module."""
        return self.module.module.state_dict(*args, **kwargs)

    def load_state_dict(self, *args, **kwargs):
        """Forward to the twice-wrapped module."""
        return self.module.module.load_state_dict(*args, **kwargs)

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)
Uni-Core-main/unicore/distributed/utils.py — new file (0 → 100644)

# Copyright (c) DP Technology.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import datetime
import io
import logging
import os
import pickle
import random
import socket
import struct
import subprocess
import warnings
from collections import OrderedDict
from typing import Any, Dict, List, Mapping, Optional
from dataclasses import dataclass

import torch
import torch.distributed as dist

logger = logging.getLogger(__name__)


def is_master(args):
    return args.distributed_rank == 0


def infer_init_method(args, force_distributed=False):
    if args.distributed_init_method is not None:
        return

    if all(
        key in os.environ
        for key in ["MASTER_ADDR", "MASTER_PORT", "WORLD_SIZE", "RANK"]
    ):
        # support torch.distributed.launch
        _infer_torch_distributed_launch_init(args)
    elif args.distributed_port > 0:
        # we can determine the init method automatically for Slurm
        _infer_slurm_init(args)
    elif args.distributed_world_size > 1 or force_distributed:
        # fallback for single node with multiple GPUs
        _infer_single_node_init(args)
    elif not args.distributed_no_spawn:
        args.distributed_num_procs = min(
            torch.cuda.device_count(), args.distributed_world_size
        )


def _infer_torch_distributed_launch_init(args):
    args.distributed_init_method = "env://"
    args.distributed_world_size = int(os.environ["WORLD_SIZE"])
    args.distributed_rank = int(os.environ["RANK"])
    # processes are created by torch.distributed.launch
    args.distributed_no_spawn = True


def _infer_slurm_init(args):
    node_list = os.environ.get("SLURM_STEP_NODELIST")
    if node_list is None:
        node_list = os.environ.get("SLURM_JOB_NODELIST")
    if node_list is not None:
        try:
            hostnames = subprocess.check_output(
                ["scontrol", "show", "hostnames", node_list]
            )
            args.distributed_init_method = "tcp://{host}:{port}".format(
                host=hostnames.split()[0].decode("utf-8"),
                port=args.distributed_port,
            )
            nnodes = int(os.environ.get("SLURM_NNODES"))
            ntasks_per_node = os.environ.get("SLURM_NTASKS_PER_NODE")
            if ntasks_per_node is not None:
                ntasks_per_node = int(ntasks_per_node)
            else:
                ntasks = int(os.environ.get("SLURM_NTASKS"))
                nnodes = int(os.environ.get("SLURM_NNODES"))
                assert ntasks % nnodes == 0
                ntasks_per_node = int(ntasks / nnodes)
            if ntasks_per_node == 1:
                gpus_per_node = torch.cuda.device_count()
                node_id = int(os.environ.get("SLURM_NODEID"))
                args.distributed_rank = node_id * gpus_per_node
                args.distributed_world_size = nnodes * gpus_per_node
            else:
                assert ntasks_per_node == args.distributed_world_size // nnodes
                args.distributed_no_spawn = True
                args.distributed_rank = int(os.environ.get("SLURM_PROCID"))
                args.device_id = int(os.environ.get("SLURM_LOCALID"))
        except subprocess.CalledProcessError as e:  # scontrol failed
            raise e
        except FileNotFoundError:  # Slurm is not installed
            pass


def _infer_single_node_init(args):
    assert (
        args.distributed_world_size <= torch.cuda.device_count()
    ), f"world size is {args.distributed_world_size} but have {torch.cuda.device_count()} available devices"
    port = random.randint(10000, 20000)
    args.distributed_init_method = "tcp://localhost:{port}".format(port=port)


def distributed_init(args):
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        warnings.warn("Distributed is already initialized, cannot initialize twice!")
    else:
        logger.info(
            "distributed init (rank {}): {}".format(
                args.distributed_rank,
                args.distributed_init_method,
            )
        )
        dist.init_process_group(
            backend=args.distributed_backend,
            init_method=args.distributed_init_method,
            world_size=args.distributed_world_size,
            rank=args.distributed_rank,
            timeout=datetime.timedelta(seconds=30),
        )
        logger.info(
            "initialized host {} as rank {}".format(
                socket.gethostname(),
                args.distributed_rank,
            )
        )

        # perform a dummy all-reduce to initialize the NCCL communicator
        if torch.cuda.is_available():
            dist.all_reduce(torch.zeros(1).cuda())

    args.distributed_rank = torch.distributed.get_rank()

    if is_master(args):
        logging.getLogger().setLevel(logging.INFO)
    else:
        logging.getLogger().setLevel(logging.WARNING)

    return args.distributed_rank


def distributed_main(i, main, args, kwargs):
    args.device_id = i
    if torch.cuda.is_available() and not args.cpu:
        torch.cuda.set_device(args.device_id)
    if args.distributed_rank is None:  # torch.multiprocessing.spawn
        args.distributed_rank = kwargs.pop("start_rank", 0) + i

    args.distributed_rank = distributed_init(args)

    after_distributed_init_fn = kwargs.pop("after_distributed_init_fn", None)
    if after_distributed_init_fn:
        args = after_distributed_init_fn(args)

    main(args, **kwargs)

    if torch.distributed.is_initialized():
        torch.distributed.barrier(get_global_group())


def call_main(args, main, **kwargs):
    if args.distributed_init_method is None:
        infer_init_method(args)

    if args.distributed_init_method is not None:
        # distributed training
        if not args.distributed_no_spawn:
            start_rank = args.distributed_rank
            args.distributed_rank = None  # assign automatically
            kwargs["start_rank"] = start_rank
            torch.multiprocessing.spawn(
                fn=distributed_main,
                args=(main, args, kwargs),
                nprocs=min(
                    torch.cuda.device_count(),
                    args.distributed_world_size,
                ),
                join=True,
            )
        else:
            distributed_main(args.device_id, main, args, kwargs)
    else:
        # single GPU main
        main(args, **kwargs)


def new_groups(grouped_ranks: List[List[int]]):
    groups = [dist.new_group(g) for g in grouped_ranks]
    my_group_idx = _find_my_group_index(grouped_ranks)
    return groups[my_group_idx]


def _find_my_group_index(grouped_ranks):
    my_rank = get_global_rank()
    for i, group in enumerate(grouped_ranks):
        if my_rank in group:
            return i
    raise RuntimeError


def _find_my_group(grouped_ranks):
    index = _find_my_group_index(grouped_ranks)
    return grouped_ranks[index]


def get_rank(group):
    return dist.get_rank(group=group)


def get_world_size(group):
    if torch.distributed.is_initialized():
        return dist.get_world_size(group=group)
    else:
        return 1


def get_global_group():
    if torch.distributed.is_initialized():
        if not hasattr(get_global_group, "_global_group"):
            # ideally we could use torch.distributed.group.WORLD, but it seems
            # to cause random NCCL hangs in some cases
            get_global_group._global_group = dist.new_group()
        return get_global_group._global_group
    else:
        return None


def get_global_rank():
    if torch.distributed.is_initialized():
        return torch.distributed.get_rank()
    else:
        return 0


def get_global_world_size():
    if torch.distributed.is_initialized():
        return torch.distributed.get_world_size()
    else:
        return 1


def get_data_parallel_group():
    """Get the data parallel group the caller rank belongs to."""
    return get_global_group()


def get_data_parallel_rank():
    """Return my rank for the data parallel group."""
    return get_rank(get_data_parallel_group())


def get_data_parallel_world_size():
    """Return world size for the data parallel group."""
    return get_world_size(get_data_parallel_group())


def all_reduce(tensor, group, op="sum"):
    if op == "sum":
        op = dist.ReduceOp.SUM
    elif op == "max":
        op = dist.ReduceOp.MAX
    else:
        raise NotImplementedError
    dist.all_reduce(tensor, op=op, group=group)
    return tensor


def broadcast(tensor, src, group):
    dist.broadcast(tensor, src=src, group=group)


def all_to_all(tensor, group):
    """Perform an all-to-all operation on a 1D Tensor."""
    assert tensor.dim() == 1
    split_count = get_world_size(group=group)
    assert tensor.numel() % split_count == 0
    output = torch.zeros_like(tensor)
    dist.all_to_all_single(output, tensor, group=group)
    return output


def all_gather(tensor, group, return_tensor=False):
    """Perform an all-gather operation."""
    world_size = get_world_size(group=group)
    rank = get_rank(group=group)
    tensor_list = [
        tensor if i == rank else torch.empty_like(tensor) for i in range(world_size)
    ]
    dist.all_gather(tensor_list, tensor, group=group)
    if return_tensor:
        return torch.stack(tensor_list, dim=0)
    else:
        return tensor_list


def all_gather_list(data, group=None, max_size=16384):
    """Gathers arbitrary data from all nodes into a list.

    Similar to :func:`~torch.distributed.all_gather` but for arbitrary Python
    data. Note that *data* must be picklable and any CUDA tensors will be moved
    to CPU and returned on CPU as well.

    Args:
        data (Any): data from the local worker to be gathered on other workers
        group: group of the collective
        max_size (int, optional): maximum size of the data to be gathered
            across workers
    """
    from unicore import utils

    if group is None:
        group = get_global_group()
    rank = get_rank(group=group)
    world_size = get_world_size(group=group)

    buffer_size = max_size * world_size
    if (
        not hasattr(all_gather_list, "_buffer")
        or all_gather_list._buffer.numel() < buffer_size
    ):
        all_gather_list._buffer = torch.cuda.ByteTensor(buffer_size)
        all_gather_list._cpu_buffer = torch.ByteTensor(max_size).pin_memory()
    buffer = all_gather_list._buffer
    buffer.zero_()
    cpu_buffer = all_gather_list._cpu_buffer

    data = utils.move_to_cpu(data)
    enc = pickle.dumps(data)
    enc_size = len(enc)
    header_size = 4  # size of header that contains the length of the encoded data
    size = header_size + enc_size
    if size > max_size:
        raise ValueError(
            "encoded data size ({}) exceeds max_size ({})".format(size, max_size)
        )

    header = struct.pack(">I", enc_size)
    cpu_buffer[:size] = torch.ByteTensor(list(header + enc))
    start = rank * max_size
    buffer[start : start + size].copy_(cpu_buffer[:size])

    all_reduce(buffer, group=group)

    buffer = buffer.cpu()
    try:
        result = []
        for i in range(world_size):
            out_buffer = buffer[i * max_size : (i + 1) * max_size]
            (enc_size,) = struct.unpack(">I", bytes(out_buffer[:header_size].tolist()))
            if enc_size > 0:
                result.append(
                    pickle.loads(
                        bytes(out_buffer[header_size : header_size + enc_size].tolist())
                    )
                )
        return result
    except pickle.UnpicklingError:
        raise Exception(
            "Unable to unpickle data from other workers. all_gather_list requires all "
            "workers to enter the function together, so this error usually indicates "
            "that the workers have fallen out of sync somehow. Workers can fall out of "
            "sync if one of them runs out of memory, or if there are other conditions "
            "in your training script that can cause one worker to finish an epoch "
            "while other workers are still iterating over their portions of the data. "
            "Try rerunning with --ddp-backend=legacy_ddp and see if that helps."
        )


def all_reduce_dict(data: Mapping[str, Any], device, group) -> Dict[str, Any]:
    """
    AllReduce a dictionary of values across workers. We separately
    reduce items that are already on the device and items on CPU for
    better performance.

    Args:
        data (Mapping[str, Any]): dictionary of data to all-reduce, but
            cannot be a nested dictionary
        device (torch.device): device for the reduction
        group: group of the collective
    """
    data_keys = list(data.keys())

    # We want to separately reduce items that are already on the
    # device and items on CPU for performance reasons.
    cpu_data = OrderedDict()
    device_data = OrderedDict()
    for k in data_keys:
        t = data[k]
        if not torch.is_tensor(t):
            cpu_data[k] = torch.tensor(t, dtype=torch.double)
        elif t.device.type != device.type:
            cpu_data[k] = t.to(dtype=torch.double)
        else:
            device_data[k] = t.to(dtype=torch.double)

    def _all_reduce_dict(data: OrderedDict):
        if len(data) == 0:
            return data
        buf = torch.cat([t.view(-1) for t in data.values()]).to(device=device)
        all_reduce(buf, group=group)
        split_buf = torch.split(buf, [t.numel() for t in data.values()])
        reduced_data = [t.view_as(orig) for t, orig in zip(split_buf, data.values())]
        return OrderedDict(zip(data.keys(), reduced_data))

    cpu_data = _all_reduce_dict(cpu_data)
    device_data = _all_reduce_dict(device_data)

    def get_from_stack(key):
        if key in cpu_data:
            return cpu_data[key]
        elif key in device_data:
            return device_data[key]
        raise KeyError

    return OrderedDict([(key, get_from_stack(key)) for key in data_keys])


@dataclass
class _TensorPlaceholder:
    index: int


def broadcast_tensors(
    tensors: Optional[List[torch.Tensor]],
    src_rank: int,
    group: object,
    dist_device: Optional[torch.device] = None,
) -> List[torch.Tensor]:
    """
    Broadcasts a list of tensors without other (non-src) ranks needing to know
    the dtypes/shapes of the tensors.
    """
    if dist_device is None:
        if torch.distributed.get_backend(group) == "nccl":
            dist_device = torch.device("cuda")
        else:
            dist_device = torch.device("cpu")

    # share metadata first to simplify transfer
    is_src_rank = get_rank(group) == src_rank
    if is_src_rank:
        metadata = [
            {"size": t.size(), "dtype": t.dtype, "device": t.device} for t in tensors
        ]
        metadata = _broadcast_object_slow(metadata, src_rank, group, dist_device)
    else:
        metadata = _broadcast_object_slow(None, src_rank, group, dist_device)

    out_tensors = []
    for i, meta in enumerate(metadata):
        if is_src_rank:
            tensor = tensors[i]
            broadcast(tensors[i].to(dist_device), src=src_rank, group=group)
        else:
            tensor = torch.zeros(
                [meta["size"].numel()], dtype=meta["dtype"], device=dist_device
            )
            broadcast(tensor, src=src_rank, group=group)
        tensor = tensor.view(meta["size"]).to(meta["device"])
        out_tensors.append(tensor)
    return out_tensors


def broadcast_object(
    obj: Any,
    src_rank: int,
    group: object,
    dist_device: Optional[torch.device] = None,
) -> Any:
    """Broadcast an arbitrary Python object to other workers."""
    if dist_device is None:
        if torch.distributed.get_backend(group) == "nccl":
            dist_device = torch.device("cuda")
        else:
            dist_device = torch.device("cpu")

    if get_rank(group) == src_rank:
        # split the tensors from the non-tensors so we can broadcast them
        # directly, avoiding unnecessary serialization/deserialization
        tensors = []
        obj = _split_tensors_from_obj(obj, tensors)
        obj = _broadcast_object_slow(obj, src_rank, group, dist_device)
        tensors = broadcast_tensors(tensors, src_rank, group, dist_device)
    else:
        obj = _broadcast_object_slow(None, src_rank, group, dist_device)
        tensors = broadcast_tensors(None, src_rank, group, dist_device)
    return _put_tensors_in_obj(obj, tensors)


def _broadcast_object_slow(
    obj: Any,
    src_rank: int,
    group: object,
    dist_device: torch.device,
) -> Any:
    if get_rank(group) == src_rank:
        # Emit data
        buffer = io.BytesIO()
        torch.save(obj, buffer)
        buffer = torch.ByteTensor(buffer.getbuffer()).to(dist_device)
        length = torch.LongTensor([len(buffer)]).to(dist_device)
        broadcast(length, src=src_rank, group=group)
        broadcast(buffer, src=src_rank, group=group)
    else:
        # Fetch from the source
        length = torch.LongTensor([0]).to(dist_device)
        broadcast(length, src=src_rank, group=group)
        buffer = torch.ByteTensor(int(length.item())).to(dist_device)
        broadcast(buffer, src=src_rank, group=group)
        buffer = io.BytesIO(buffer.cpu().numpy())
        obj = torch.load(buffer, map_location="cpu")
    return obj


def _split_tensors_from_obj(obj: Any, tensors: List[torch.Tensor]) -> Any:
    if torch.is_tensor(obj):
        placeholder = _TensorPlaceholder(index=len(tensors))
        tensors.append(obj)
        return placeholder
    elif isinstance(obj, dict):
        return {k: _split_tensors_from_obj(v, tensors) for k, v in obj.items()}
    elif isinstance(obj, list):
        return [_split_tensors_from_obj(v, tensors) for v in obj]
    elif isinstance(obj, tuple):
        return tuple(_split_tensors_from_obj(v, tensors) for v in obj)
    elif isinstance(obj, set):
        return {_split_tensors_from_obj(v, tensors) for v in obj}
    else:
        return obj


def _put_tensors_in_obj(obj: Any, tensors: List[torch.Tensor]) -> Any:
    if isinstance(obj, _TensorPlaceholder):
        return tensors[obj.index]
    elif isinstance(obj, dict):
        return {k: _put_tensors_in_obj(v, tensors) for k, v in obj.items()}
    elif isinstance(obj, list):
        return [_put_tensors_in_obj(v, tensors) for v in obj]
    elif isinstance(obj, tuple):
        return tuple(_put_tensors_in_obj(v, tensors) for v in obj)
    elif isinstance(obj, set):
        return {_put_tensors_in_obj(v, tensors) for v in obj}
    else:
        return obj
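A usage sketch for the collectives (editor's addition, not part of the commit), assuming the process group has already been initialized (for example via distributed_init) and the same code runs on every rank:

    # Sketch only: gather per-rank Python stats and sum a dict of scalars.
    group = get_global_group()

    stats = {"loss": 1.23, "ntokens": 4096}
    gathered = all_gather_list(stats, group=group)   # list with one dict per rank

    reduced = all_reduce_dict(
        {"loss": torch.tensor(1.23), "ntokens": torch.tensor(4096.0)},
        device=torch.device("cuda"),
        group=group,
    )
    # element-wise sum across ranks, returned as double-precision tensors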
Uni-Core-main/unicore/logging/__init__.py — new file (0 → 100644), empty
Uni-Core-main/unicore/logging/meters.py — new file (0 → 100644)

# Copyright (c) DP Technology.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import bisect
import time
from collections import OrderedDict
from typing import Dict, Optional

try:
    import torch

    def type_as(a, b):
        if torch.is_tensor(a) and torch.is_tensor(b):
            return a.to(b)
        else:
            return a

except ImportError:
    torch = None

    def type_as(a, b):
        return a


try:
    import numpy as np
except ImportError:
    np = None


class Meter(object):
    """Base class for Meters."""

    def __init__(self):
        pass

    def state_dict(self):
        return {}

    def load_state_dict(self, state_dict):
        pass

    def reset(self):
        raise NotImplementedError

    @property
    def smoothed_value(self) -> float:
        """Smoothed value used for logging."""
        raise NotImplementedError


def safe_round(number, ndigits):
    if hasattr(number, "__round__"):
        return round(number, ndigits)
    elif torch is not None and torch.is_tensor(number) and number.numel() == 1:
        return safe_round(number.item(), ndigits)
    elif np is not None and np.ndim(number) == 0 and hasattr(number, "item"):
        return safe_round(number.item(), ndigits)
    else:
        return number


class AverageMeter(Meter):
    """Computes and stores the average and current value"""

    def __init__(self, round: Optional[int] = None):
        self.round = round
        self.reset()

    def reset(self):
        self.val = None  # most recent update
        self.sum = 0  # sum from all updates
        self.count = 0  # total n from all updates

    def update(self, val, n=1):
        if val is not None:
            self.val = val
            if n > 0:
                self.sum = type_as(self.sum, val) + (val * n)
                self.count = type_as(self.count, n) + n

    def state_dict(self):
        return {
            "val": self.val,
            "sum": self.sum,
            "count": self.count,
            "round": self.round,
        }

    def load_state_dict(self, state_dict):
        self.val = state_dict["val"]
        self.sum = state_dict["sum"]
        self.count = state_dict["count"]
        self.round = state_dict.get("round", None)

    @property
    def avg(self):
        return self.sum / self.count if self.count > 0 else self.val

    @property
    def smoothed_value(self) -> float:
        val = self.avg
        if self.round is not None and val is not None:
            val = safe_round(val, self.round)
        return val


class TimeMeter(Meter):
    """Computes the average occurrence of some event per second"""

    def __init__(
        self,
        init: int = 0,
        n: int = 0,
        round: Optional[int] = None,
    ):
        self.round = round
        self.reset(init, n)

    def reset(self, init=0, n=0):
        self.init = init
        self.start = time.perf_counter()
        self.n = n
        self.i = 0

    def update(self, val=1):
        self.n = type_as(self.n, val) + val
        self.i += 1

    def state_dict(self):
        return {
            "init": self.elapsed_time,
            "n": self.n,
            "round": self.round,
        }

    def load_state_dict(self, state_dict):
        if "start" in state_dict:
            # backwards compatibility for old state_dicts
            self.reset(init=state_dict["init"])
        else:
            self.reset(init=state_dict["init"], n=state_dict["n"])
            self.round = state_dict.get("round", None)

    @property
    def avg(self):
        return self.n / self.elapsed_time

    @property
    def elapsed_time(self):
        return self.init + (time.perf_counter() - self.start)

    @property
    def smoothed_value(self) -> float:
        val = self.avg
        if self.round is not None and val is not None:
            val = safe_round(val, self.round)
        return val


class StopwatchMeter(Meter):
    """Computes the sum/avg duration of some event in seconds"""

    def __init__(self, round: Optional[int] = None):
        self.round = round
        self.sum = 0
        self.n = 0
        self.start_time = None

    def start(self):
        self.start_time = time.perf_counter()

    def stop(self, n=1, prehook=None):
        if self.start_time is not None:
            if prehook is not None:
                prehook()
            delta = time.perf_counter() - self.start_time
            self.sum = self.sum + delta
            self.n = type_as(self.n, n) + n

    def reset(self):
        self.sum = 0  # cumulative time during which stopwatch was active
        self.n = 0  # total n across all start/stop
        self.start()

    def state_dict(self):
        return {
            "sum": self.sum,
            "n": self.n,
            "round": self.round,
        }

    def load_state_dict(self, state_dict):
        self.sum = state_dict["sum"]
        self.n = state_dict["n"]
        self.start_time = None
        self.round = state_dict.get("round", None)

    @property
    def avg(self):
        return self.sum / self.n if self.n > 0 else self.sum

    @property
    def elapsed_time(self):
        if self.start_time is None:
            return 0.0
        return time.perf_counter() - self.start_time

    @property
    def smoothed_value(self) -> float:
        val = self.avg if self.sum > 0 else self.elapsed_time
        if self.round is not None and val is not None:
            val = safe_round(val, self.round)
        return val


class MetersDict(OrderedDict):
    """A sorted dictionary of :class:`Meters`.

    Meters are sorted according to a priority that is given when the
    meter is first added to the dictionary.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.priorities = []

    def __setitem__(self, key, value):
        assert key not in self, "MetersDict doesn't support reassignment"
        priority, value = value
        bisect.insort(self.priorities, (priority, len(self.priorities), key))
        super().__setitem__(key, value)
        for _, _, key in self.priorities:  # reorder dict to match priorities
            self.move_to_end(key)

    def add_meter(self, key, meter, priority):
        self.__setitem__(key, (priority, meter))

    def state_dict(self):
        return [
            (pri, key, self[key].__class__.__name__, self[key].state_dict())
            for pri, _, key in self.priorities
            # can't serialize DerivedMeter instances
            if not isinstance(self[key], MetersDict._DerivedMeter)
        ]

    def load_state_dict(self, state_dict):
        self.clear()
        self.priorities.clear()
        for pri, key, meter_cls, meter_state in state_dict:
            meter = globals()[meter_cls]()
            meter.load_state_dict(meter_state)
            self.add_meter(key, meter, pri)

    def get_smoothed_value(self, key: str) -> float:
        """Get a single smoothed value."""
        meter = self[key]
        if isinstance(meter, MetersDict._DerivedMeter):
            return meter.fn(self)
        else:
            return meter.smoothed_value

    def get_smoothed_values(self) -> Dict[str, float]:
        """Get all smoothed values."""
        return OrderedDict(
            [
                (key, self.get_smoothed_value(key))
                for key in self.keys()
                if not key.startswith("_")
            ]
        )

    def reset(self):
        """Reset Meter instances."""
        for meter in self.values():
            if isinstance(meter, MetersDict._DerivedMeter):
                continue
            meter.reset()

    class _DerivedMeter(Meter):
        """A Meter whose values are derived from other Meters."""

        def __init__(self, fn):
            self.fn = fn

        def reset(self):
            pass
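A small usage sketch (editor's addition, not part of the commit): the meters are plain accumulators, and MetersDict keeps them ordered by the priority passed to add_meter.

    # Sketch only: track a loss average and a throughput meter.
    meters = MetersDict()
    meters.add_meter("loss", AverageMeter(round=3), priority=10)
    meters.add_meter("wps", TimeMeter(round=1), priority=20)

    meters["loss"].update(0.75, n=32)    # weighted by batch size
    meters["wps"].update(4096)           # tokens processed since reset()
    print(meters.get_smoothed_values())  # OrderedDict([('loss', 0.75), ('wps', ...)])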
Uni-Core-main/unicore/logging/metrics.py
0 → 100644
View file @
a1c29028
# Copyright (c) DP Technology.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
A standalone module for aggregating metrics.
Metrics can be logged from anywhere using the `log_*` functions defined
in this module. The logged values will be aggregated dynamically based
on the aggregation context in which the logging occurs. See the
:func:`aggregate` context manager for more details.
"""
import
contextlib
import
uuid
from
collections
import
OrderedDict
,
defaultdict
from
typing
import
Callable
,
Dict
,
List
,
Optional
from
.meters
import
*
# Aggregation contexts are considered "active" when inside the scope
# created by the :func:`aggregate` context manager.
_aggregators
=
OrderedDict
()
_active_aggregators
=
OrderedDict
()
_active_aggregators_cnt
=
defaultdict
(
lambda
:
0
)
def
reset
()
->
None
:
"""Reset all metrics aggregators."""
_aggregators
.
clear
()
_active_aggregators
.
clear
()
_active_aggregators_cnt
.
clear
()
# The "default" aggregator observes all logged values.
_aggregators
[
"default"
]
=
MetersDict
()
_active_aggregators
[
"default"
]
=
_aggregators
[
"default"
]
_active_aggregators_cnt
[
"default"
]
=
1
reset
()
@
contextlib
.
contextmanager
def
aggregate
(
name
:
Optional
[
str
]
=
None
,
new_root
:
bool
=
False
):
"""Context manager to aggregate metrics under a given name.
Aggregations can be nested. If *new_root* is ``False``, then logged
metrics will be recorded along the entire stack of nested
aggregators, including a global "default" aggregator. If *new_root*
is ``True``, then this aggregator will be the root of a new
aggregation stack, thus bypassing any parent aggregators.
Note that aggregation contexts are uniquely identified by their
*name* (e.g., train, valid). Creating a context with an existing
name will reuse the corresponding :class:`MetersDict` instance.
If no name is given, then a temporary aggregator will be created.
Usage::
with metrics.aggregate("train"):
for step, batch in enumerate(epoch):
with metrics.aggregate("train_inner") as agg:
metrics.log_scalar("loss", get_loss(batch))
if step % log_interval == 0:
print(agg.get_smoothed_value("loss"))
agg.reset()
print(metrics.get_smoothed_values("train")["loss"])
Args:
name (str): name of the aggregation. Defaults to a
random/temporary name if not given explicitly.
new_root (bool): make this aggregation the root of a new
aggregation stack.
"""
if
name
is
None
:
# generate a temporary name
name
=
str
(
uuid
.
uuid4
())
assert
name
not
in
_aggregators
agg
=
MetersDict
()
else
:
assert
name
!=
"default"
agg
=
_aggregators
.
setdefault
(
name
,
MetersDict
())
if
new_root
:
backup_aggregators
=
_active_aggregators
.
copy
()
_active_aggregators
.
clear
()
backup_aggregators_cnt
=
_active_aggregators_cnt
.
copy
()
_active_aggregators_cnt
.
clear
()
_active_aggregators
[
name
]
=
agg
_active_aggregators_cnt
[
name
]
+=
1
yield
agg
_active_aggregators_cnt
[
name
]
-=
1
if
_active_aggregators_cnt
[
name
]
==
0
and
name
in
_active_aggregators
:
del
_active_aggregators
[
name
]
if
new_root
:
_active_aggregators
.
clear
()
_active_aggregators
.
update
(
backup_aggregators
)
_active_aggregators_cnt
.
clear
()
_active_aggregators_cnt
.
update
(
backup_aggregators_cnt
)
def
get_active_aggregators
()
->
List
[
MetersDict
]:
return
list
(
_active_aggregators
.
values
())
def
log_scalar
(
key
:
str
,
value
:
float
,
weight
:
float
=
1
,
priority
:
int
=
10
,
round
:
Optional
[
int
]
=
None
,
):
"""Log a scalar value.
Args:
key (str): name of the field to log
value (float): value to log
weight (float): weight that this value contributes to the average.
A weight of 0 will always log the latest value.
priority (int): smaller values are logged earlier in the output
round (Optional[int]): number of digits to round to when displaying
"""
for
agg
in
get_active_aggregators
():
if
key
not
in
agg
:
agg
.
add_meter
(
key
,
AverageMeter
(
round
=
round
),
priority
)
agg
[
key
].
update
(
value
,
weight
)
def
log_derived
(
key
:
str
,
fn
:
Callable
[[
MetersDict
],
float
],
priority
:
int
=
20
):
"""Log a scalar value derived from other meters.
Args:
key (str): name of the field to log
fn (Callable[[MetersDict], float]): function that takes a single
argument *meters* and returns the derived value
priority (int): smaller values are logged earlier in the output
"""
for
agg
in
get_active_aggregators
():
if
key
not
in
agg
:
agg
.
add_meter
(
key
,
MetersDict
.
_DerivedMeter
(
fn
),
priority
)
def
log_speed
(
key
:
str
,
value
:
float
,
priority
:
int
=
30
,
round
:
Optional
[
int
]
=
None
,
):
"""Log the rate of some quantity per second.
Args:
key (str): name of the field to log
value (float): value to log
priority (int): smaller values are logged earlier in the output
round (Optional[int]): number of digits to round to when displaying
"""
for
agg
in
get_active_aggregators
():
if
key
not
in
agg
:
agg
.
add_meter
(
key
,
TimeMeter
(
round
=
round
),
priority
)
agg
[
key
].
reset
()
# reset meter on the first call
else
:
agg
[
key
].
update
(
value
)
def log_start_time(key: str, priority: int = 40, round: Optional[int] = None):
    """Log the duration of some event in seconds.

    The duration will be computed once :func:`log_stop_time` is called.

    Args:
        key (str): name of the field to log
        priority (int): smaller values are logged earlier in the output
        round (Optional[int]): number of digits to round to when displaying
    """
    for agg in get_active_aggregators():
        if key not in agg:
            agg.add_meter(key, StopwatchMeter(round=round), priority)
        agg[key].start()


def log_stop_time(key: str, weight: float = 0.0, prehook=None):
    """Log the duration of some event in seconds.

    The duration will be computed since :func:`log_start_time` was called.
    Set weight > 0 to report the average time instead of the sum.

    Args:
        key (str): name of the field to log
        weight (float): weight that this time contributes to the average
        prehook (function, no arguments): will be called before the timer
            is stopped. For example, use prehook=torch.cuda.synchronize to
            make sure all gpu operations are done before timer is stopped.
    """
    for agg in get_active_aggregators():
        if key in agg:
            agg[key].stop(weight, prehook)
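
# Usage sketch (not part of the original file; `run_validation` is a
# hypothetical callable). Wrapping a block between log_start_time and
# log_stop_time records its duration under the given key; passing
# prehook=torch.cuda.synchronize, as the docstring suggests, makes the timing
# include any pending GPU work.
def _example_timed_block(run_validation):
    import torch

    log_start_time("valid_wall_example", priority=45)
    run_validation()
    log_stop_time("valid_wall_example", prehook=torch.cuda.synchronize)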
def log_custom(
    new_meter_fn: Callable[[], Meter],
    key: str,
    *args,
    priority: int = 50,
    **kwargs,
):
    """Log using a custom Meter.

    Any extra *args* or *kwargs* will be passed through to the Meter's
    *update* method.

    Args:
        new_meter_fn (Callable[[], Meter]): function that returns a new
            Meter instance
        key (str): name of the field to log
        priority (int): smaller values are logged earlier in the output
    """
    for agg in get_active_aggregators():
        if key not in agg:
            agg.add_meter(key, new_meter_fn(), priority)
        agg[key].update(*args, **kwargs)
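
# Usage sketch (not part of the original file; the gradient-norm value is
# hypothetical). log_custom attaches any Meter implementation; here an
# AverageMeter tracks the statistic and the extra positional argument is
# forwarded to its update() method.
def _example_log_custom(gnorm_value):
    log_custom(AverageMeter, "gnorm_example", gnorm_value, priority=55)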
def reset_meter(name: str, key: str) -> None:
    """Reset Meter instance aggregated under a given *name* and *key*."""
    meter = get_meter(name, key)
    if meter is not None:
        meter.reset()


def reset_meters(name: str) -> None:
    """Reset Meter instances aggregated under a given *name*."""
    meters = get_meters(name)
    if meters is not None:
        meters.reset()


def get_meter(name: str, key: str) -> Meter:
    """Get a single Meter instance aggregated under *name* and *key*.

    Returns:
        Meter or None if no metrics have been logged under *name* and *key*.
    """
    if name not in _aggregators:
        return None
    return _aggregators[name].get(key, None)


def get_meters(name: str) -> MetersDict:
    """Get Meter instances aggregated under a given *name*.

    Returns:
        MetersDict or None if no metrics have been logged under *name*.
    """
    return _aggregators.get(name, None)


def get_smoothed_value(name: str, key: str) -> float:
    """Get a single smoothed value.

    Raises:
        KeyError: if no metrics have been logged under *name* and *key*.
    """
    return _aggregators[name].get_smoothed_value(key)


def get_smoothed_values(name: str) -> Dict[str, float]:
    """Get smoothed values aggregated under a given *name*.

    Raises:
        KeyError: if no metrics have been logged under *name*.
    """
    return _aggregators[name].get_smoothed_values()
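
# Usage sketch (not part of the original file; "train" is whatever name the
# caller passed to aggregate()). After an epoch, the smoothed values logged
# under a named aggregator can be pulled out as a plain dict for printing or
# bookkeeping; get_smoothed_values raises KeyError if the name was never used.
def _example_read_back():
    stats = get_smoothed_values("train")
    return {k: format(v, ".3f") if isinstance(v, float) else v for k, v in stats.items()}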
def state_dict():
    return OrderedDict([(name, agg.state_dict()) for name, agg in _aggregators.items()])


def load_state_dict(state_dict):
    for name, agg_state in state_dict.items():
        _aggregators[name] = MetersDict()
        _aggregators[name].load_state_dict(agg_state)
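
# Usage sketch (not part of the original file; the "metrics" checkpoint key is
# hypothetical). The module-level state_dict / load_state_dict pair makes the
# aggregators checkpointable alongside model state.
def _example_save_metrics(checkpoint: dict) -> dict:
    checkpoint["metrics"] = state_dict()
    return checkpoint


def _example_restore_metrics(checkpoint: dict) -> None:
    load_state_dict(checkpoint["metrics"])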
Uni-Core-main/unicore/logging/progress_bar.py
0 → 100644
View file @
a1c29028
# Copyright (c) DP Technology.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Wrapper around various loggers and progress bars (e.g., tqdm).
"""
import atexit
import json
import logging
import os
import sys
from collections import OrderedDict
from contextlib import contextmanager
from numbers import Number
from typing import Optional

import torch

from .meters import AverageMeter, StopwatchMeter, TimeMeter

logger = logging.getLogger(__name__)
def progress_bar(
    iterator,
    log_format: Optional[str] = None,
    log_interval: int = 100,
    epoch: Optional[int] = None,
    prefix: Optional[str] = None,
    tensorboard_logdir: Optional[str] = None,
    default_log_format: str = "tqdm",
):
    if log_format is None:
        log_format = default_log_format
    if log_format == "tqdm" and not sys.stderr.isatty():
        log_format = "simple"

    if log_format == "json":
        bar = JsonProgressBar(iterator, epoch, prefix, log_interval)
    elif log_format == "none":
        bar = NoopProgressBar(iterator, epoch, prefix)
    elif log_format == "simple":
        bar = SimpleProgressBar(iterator, epoch, prefix, log_interval)
    elif log_format == "tqdm":
        bar = TqdmProgressBar(iterator, epoch, prefix)
    else:
        raise ValueError("Unknown log format: {}".format(log_format))

    if tensorboard_logdir:
        try:
            # [FB only] custom wrapper for TensorBoard
            import palaas  # noqa
            from .fb_tbmf_wrapper import FbTbmfWrapper

            bar = FbTbmfWrapper(bar, log_interval)
        except ImportError:
            bar = TensorboardProgressBarWrapper(bar, tensorboard_logdir)

    return bar
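
# Usage sketch (not part of the original file; the epoch iterator and stats
# are hypothetical). The factory picks a concrete bar from log_format, falls
# back to "simple" when stderr is not a TTY, and optionally wraps the result
# for TensorBoard; callers drive whichever bar comes back the same way.
def _example_progress_bar(epoch_itr):
    bar = progress_bar(
        epoch_itr,
        log_format="simple",
        log_interval=10,
        epoch=1,
        tensorboard_logdir=None,
    )
    for i, sample in enumerate(bar):
        bar.log({"loss": 1.23}, tag="train_example", step=i)
    bar.print({"loss": 1.23}, tag="train_example")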
def build_progress_bar(
    args,
    iterator,
    epoch: Optional[int] = None,
    prefix: Optional[str] = None,
    default: str = "tqdm",
    no_progress_bar: str = "none",
):
    """Legacy wrapper that takes an argparse.Namespace."""
    if getattr(args, "no_progress_bar", False):
        default = no_progress_bar
    if getattr(args, "distributed_rank", 0) == 0:
        tensorboard_logdir = getattr(args, "tensorboard_logdir", None)
    else:
        tensorboard_logdir = None
    return progress_bar(
        iterator,
        log_format=args.log_format,
        log_interval=args.log_interval,
        epoch=epoch,
        prefix=prefix,
        tensorboard_logdir=tensorboard_logdir,
        default_log_format=default,
    )
def format_stat(stat):
    if isinstance(stat, Number):
        stat = "{:g}".format(stat)
    elif isinstance(stat, AverageMeter):
        stat = "{:.3f}".format(stat.avg)
    elif isinstance(stat, TimeMeter):
        stat = "{:g}".format(round(stat.avg))
    elif isinstance(stat, StopwatchMeter):
        stat = "{:g}".format(round(stat.sum))
    elif torch.is_tensor(stat):
        stat = stat.tolist()
    return stat
class BaseProgressBar(object):
    """Abstract class for progress bars."""

    def __init__(self, iterable, epoch=None, prefix=None):
        self.iterable = iterable
        self.n = getattr(iterable, "n", 0)
        self.epoch = epoch
        self.prefix = ""
        if epoch is not None:
            self.prefix += "epoch {:03d}".format(epoch)
        if prefix is not None:
            self.prefix += (" | " if self.prefix != "" else "") + prefix

    def __len__(self):
        return len(self.iterable)

    def __enter__(self):
        return self

    def __exit__(self, *exc):
        return False

    def __iter__(self):
        raise NotImplementedError

    def log(self, stats, tag=None, step=None):
        """Log intermediate stats according to log_interval."""
        raise NotImplementedError

    def print(self, stats, tag=None, step=None):
        """Print end-of-epoch stats."""
        raise NotImplementedError

    def update_config(self, config):
        """Log latest configuration."""
        pass

    def _str_commas(self, stats):
        return ", ".join(key + "=" + stats[key].strip() for key in stats.keys())

    def _str_pipes(self, stats):
        return " | ".join(key + " " + stats[key].strip() for key in stats.keys())

    def _format_stats(self, stats):
        postfix = OrderedDict(stats)
        # Preprocess stats according to datatype
        for key in postfix.keys():
            postfix[key] = str(format_stat(postfix[key]))
        return postfix
@contextmanager
def rename_logger(logger, new_name):
    old_name = logger.name
    if new_name is not None:
        logger.name = new_name
    yield logger
    logger.name = old_name
class JsonProgressBar(BaseProgressBar):
    """Log output in JSON format."""

    def __init__(self, iterable, epoch=None, prefix=None, log_interval=1000):
        super().__init__(iterable, epoch, prefix)
        self.log_interval = log_interval
        self.i = None
        self.size = None

    def __iter__(self):
        self.size = len(self.iterable)
        for i, obj in enumerate(self.iterable, start=self.n):
            self.i = i
            yield obj

    def log(self, stats, tag=None, step=None):
        """Log intermediate stats according to log_interval."""
        step = step or self.i or 0
        if step > 0 and self.log_interval is not None and step % self.log_interval == 0:
            update = (
                self.epoch - 1 + (self.i + 1) / float(self.size)
                if self.epoch is not None
                else None
            )
            stats = self._format_stats(stats, epoch=self.epoch, update=update)
            with rename_logger(logger, tag):
                logger.info(json.dumps(stats))

    def print(self, stats, tag=None, step=None):
        """Print end-of-epoch stats."""
        self.stats = stats
        if tag is not None:
            self.stats = OrderedDict(
                [(tag + "_" + k, v) for k, v in self.stats.items()]
            )
        stats = self._format_stats(self.stats, epoch=self.epoch)
        with rename_logger(logger, tag):
            logger.info(json.dumps(stats))

    def _format_stats(self, stats, epoch=None, update=None):
        postfix = OrderedDict()
        if epoch is not None:
            postfix["epoch"] = epoch
        if update is not None:
            postfix["update"] = round(update, 3)
        # Preprocess stats according to datatype
        for key in stats.keys():
            postfix[key] = format_stat(stats[key])
        return postfix
class NoopProgressBar(BaseProgressBar):
    """No logging."""

    def __init__(self, iterable, epoch=None, prefix=None):
        super().__init__(iterable, epoch, prefix)

    def __iter__(self):
        for obj in self.iterable:
            yield obj

    def log(self, stats, tag=None, step=None):
        """Log intermediate stats according to log_interval."""
        pass

    def print(self, stats, tag=None, step=None):
        """Print end-of-epoch stats."""
        pass
class SimpleProgressBar(BaseProgressBar):
    """A minimal logger for non-TTY environments."""

    def __init__(self, iterable, epoch=None, prefix=None, log_interval=1000):
        super().__init__(iterable, epoch, prefix)
        self.log_interval = log_interval
        self.i = None
        self.size = None

    def __iter__(self):
        self.size = len(self.iterable)
        for i, obj in enumerate(self.iterable, start=self.n):
            self.i = i
            yield obj

    def log(self, stats, tag=None, step=None):
        """Log intermediate stats according to log_interval."""
        step = step or self.i or 0
        if step > 0 and self.log_interval is not None and step % self.log_interval == 0:
            stats = self._format_stats(stats)
            postfix = self._str_commas(stats)
            with rename_logger(logger, tag):
                logger.info(
                    "{}: {:5d} / {:d} {}".format(
                        self.prefix, self.i + 1, self.size, postfix
                    )
                )

    def print(self, stats, tag=None, step=None):
        """Print end-of-epoch stats."""
        postfix = self._str_pipes(self._format_stats(stats))
        with rename_logger(logger, tag):
            logger.info("{} | {}".format(self.prefix, postfix))
class TqdmProgressBar(BaseProgressBar):
    """Log to tqdm."""

    def __init__(self, iterable, epoch=None, prefix=None):
        super().__init__(iterable, epoch, prefix)
        from tqdm import tqdm

        self.tqdm = tqdm(
            iterable,
            self.prefix,
            leave=False,
            disable=(logger.getEffectiveLevel() > logging.INFO),
        )

    def __iter__(self):
        return iter(self.tqdm)

    def log(self, stats, tag=None, step=None):
        """Log intermediate stats according to log_interval."""
        self.tqdm.set_postfix(self._format_stats(stats), refresh=False)

    def print(self, stats, tag=None, step=None):
        """Print end-of-epoch stats."""
        postfix = self._str_pipes(self._format_stats(stats))
        with rename_logger(logger, tag):
            logger.info("{} | {}".format(self.prefix, postfix))
try:
    _tensorboard_writers = {}
    from torch.utils.tensorboard import SummaryWriter
except ImportError:
    try:
        from tensorboardX import SummaryWriter
    except ImportError:
        SummaryWriter = None


def _close_writers():
    for w in _tensorboard_writers.values():
        w.close()


atexit.register(_close_writers)
class TensorboardProgressBarWrapper(BaseProgressBar):
    """Log to tensorboard."""

    def __init__(self, wrapped_bar, tensorboard_logdir):
        self.wrapped_bar = wrapped_bar
        self.tensorboard_logdir = tensorboard_logdir

        if SummaryWriter is None:
            logger.warning(
                "tensorboard not found, please install with: pip install tensorboard"
            )

    def _writer(self, key):
        if SummaryWriter is None:
            return None
        _writers = _tensorboard_writers
        if key not in _writers:
            _writers[key] = SummaryWriter(os.path.join(self.tensorboard_logdir, key))
            _writers[key].add_text("sys.argv", " ".join(sys.argv))
        return _writers[key]

    def __iter__(self):
        return iter(self.wrapped_bar)

    def log(self, stats, tag=None, step=None):
        """Log intermediate stats to tensorboard."""
        self._log_to_tensorboard(stats, tag, step)
        self.wrapped_bar.log(stats, tag=tag, step=step)

    def print(self, stats, tag=None, step=None):
        """Print end-of-epoch stats."""
        self._log_to_tensorboard(stats, tag, step)
        self.wrapped_bar.print(stats, tag=tag, step=step)

    def update_config(self, config):
        """Log latest configuration."""
        # TODO add hparams to Tensorboard
        self.wrapped_bar.update_config(config)

    def _log_to_tensorboard(self, stats, tag=None, step=None):
        writer = self._writer(tag or "")
        if writer is None:
            return
        if step is None:
            step = stats["num_updates"]
        for key in stats.keys() - {"num_updates"}:
            if isinstance(stats[key], AverageMeter):
                writer.add_scalar(key, stats[key].val, step)
            elif isinstance(stats[key], Number):
                writer.add_scalar(key, stats[key], step)
            elif torch.is_tensor(stats[key]) and stats[key].numel() == 1:
                writer.add_scalar(key, stats[key].item(), step)
        writer.flush()
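
# Usage sketch (not part of the original file; the log directory and stats are
# hypothetical). Any BaseProgressBar can be wrapped so numeric stats are
# mirrored to TensorBoard under a per-tag subdirectory while console logging
# still goes through the wrapped bar; passing an explicit step avoids the
# fallback to stats["num_updates"].
def _example_tensorboard_wrapper(epoch_itr, logdir="/tmp/tb_example"):
    bar = TensorboardProgressBarWrapper(
        SimpleProgressBar(epoch_itr, epoch=1, log_interval=10),
        logdir,
    )
    for i, sample in enumerate(bar):
        bar.log({"loss": 0.5}, tag="train_example", step=i)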
Uni-Core-main/unicore/losses/__init__.py
0 → 100644
View file @
a1c29028
# Copyright (c) DP Technology.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""isort:skip_file"""
import importlib
import os

from unicore import registry
from unicore.losses.unicore_loss import (  # noqa
    UnicoreLoss,
)

(
    build_loss_,
    register_loss,
    CRITERION_REGISTRY,
) = registry.setup_registry("--loss", base_class=UnicoreLoss, default="cross_entropy")


def build_loss(args, task):
    return build_loss_(args, task)


# automatically import any Python files in the losses/ directory
for file in os.listdir(os.path.dirname(__file__)):
    if file.endswith(".py") and not file.startswith("_"):
        file_name = file[: file.find(".py")]
        importlib.import_module("unicore.losses." + file_name)
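
# Usage sketch (not part of the original file): the contents of a hypothetical
# unicore/losses/my_example_loss.py. The automatic import loop above picks it
# up by filename, and registration relies only on the register_loss decorator
# created by setup_registry; the forward() signature mirrors the built-in
# losses (treat that as an assumption) and the body is a placeholder.
from unicore.losses import UnicoreLoss, register_loss


@register_loss("my_example_loss")
class MyExampleLoss(UnicoreLoss):
    def forward(self, model, sample, reduce=True):
        raise NotImplementedError  # compute loss, sample size and logging outputs here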