Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
d6186284
Unverified
Commit
d6186284
authored
Jul 31, 2020
by
chicm-ms
Committed by
GitHub
Jul 31, 2020
Browse files
update gradient_selector dataloader iterator import (#2690)
parent
717877d0
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
21 additions
and
28 deletions
+21
-28
src/sdk/pynni/nni/feature_engineering/gradient_selector/fginitialize.py
...nni/feature_engineering/gradient_selector/fginitialize.py
+14
-26
src/sdk/pynni/nni/feature_engineering/gradient_selector/learnability.py
...nni/feature_engineering/gradient_selector/learnability.py
+7
-2
No files found.
src/sdk/pynni/nni/feature_engineering/gradient_selector/fginitialize.py
View file @
d6186284
...
@@ -31,7 +31,7 @@ from sklearn.datasets import load_svmlight_file
...
@@ -31,7 +31,7 @@ from sklearn.datasets import load_svmlight_file
import
torch
import
torch
from
torch.utils.data
import
DataLoader
,
Dataset
from
torch.utils.data
import
DataLoader
,
Dataset
# pylint: disable=E0611
# pylint: disable=E0611
from
torch.utils.data.dataloader
import
_DataLoaderIter
,
_utils
from
torch.utils.data.dataloader
import
_
SingleProcessDataLoaderIter
,
_MultiProcessing
DataLoaderIter
,
_utils
from
.
import
constants
from
.
import
constants
from
.
import
syssettings
from
.
import
syssettings
...
@@ -585,39 +585,27 @@ class ChunkDataLoader(DataLoader):
...
@@ -585,39 +585,27 @@ class ChunkDataLoader(DataLoader):
return
_ChunkDataLoaderIter
(
self
)
return
_ChunkDataLoaderIter
(
self
)
class
_ChunkDataLoaderIter
(
_DataLoaderIter
)
:
class
_ChunkDataLoaderIter
:
"""
"""
DataLoaderIter class used to more quickly load a batch of indices at once.
DataLoaderIter class used to more quickly load a batch of indices at once.
"""
"""
def
__init__
(
self
,
dataloader
):
if
dataloader
.
num_workers
==
0
:
self
.
iter
=
_SingleProcessDataLoaderIter
(
dataloader
)
else
:
self
.
iter
=
_MultiProcessingDataLoaderIter
(
dataloader
)
def
__next__
(
self
):
def
__next__
(
self
):
# only chunk that is edited from base
# only chunk that is edited from base
if
self
.
num_workers
==
0
:
# same-process loading
if
self
.
iter
.
_
num_workers
==
0
:
# same-process loading
indices
=
next
(
self
.
sample_iter
)
# may raise StopIteration
indices
=
next
(
self
.
iter
.
_
sample
r
_iter
)
# may raise StopIteration
if
len
(
indices
)
>
1
:
if
len
(
indices
)
>
1
:
batch
=
self
.
dataset
[
np
.
array
(
indices
)]
batch
=
self
.
iter
.
_
dataset
[
np
.
array
(
indices
)]
else
:
else
:
batch
=
self
.
collate_fn
([
self
.
dataset
[
i
]
for
i
in
indices
])
batch
=
self
.
iter
.
_
collate_fn
([
self
.
iter
.
_
dataset
[
i
]
for
i
in
indices
])
if
self
.
pin_memory
:
if
self
.
iter
.
_
pin_memory
:
batch
=
_utils
.
pin_memory
.
pin_memory_batch
(
batch
)
batch
=
_utils
.
pin_memory
.
pin_memory_batch
(
batch
)
return
batch
return
batch
else
:
# check if the next sample has already been generated
return
next
(
self
.
iter
)
if
self
.
rcvd_idx
in
self
.
reorder_dict
:
batch
=
self
.
reorder_dict
.
pop
(
self
.
rcvd_idx
)
return
self
.
_process_next_batch
(
batch
)
if
self
.
batches_outstanding
==
0
:
self
.
_shutdown_workers
()
raise
StopIteration
while
True
:
assert
(
not
self
.
shutdown
and
self
.
batches_outstanding
>
0
)
idx
,
batch
=
self
.
_get_batch
()
self
.
batches_outstanding
-=
1
if
idx
!=
self
.
rcvd_idx
:
# store out-of-order samples
self
.
reorder_dict
[
idx
]
=
batch
continue
return
self
.
_process_next_batch
(
batch
)
src/sdk/pynni/nni/feature_engineering/gradient_selector/learnability.py
View file @
d6186284
...
@@ -287,6 +287,11 @@ class Solver(nn.Module):
...
@@ -287,6 +287,11 @@ class Solver(nn.Module):
else
:
else
:
pin_memory
=
False
pin_memory
=
False
if
num_workers
==
0
:
timeout
=
0
else
:
timeout
=
60
self
.
ds_train
=
ChunkDataLoader
(
self
.
ds_train
=
ChunkDataLoader
(
PreparedData
,
PreparedData
,
batch_size
=
self
.
Nminibatch
,
batch_size
=
self
.
Nminibatch
,
...
@@ -294,7 +299,7 @@ class Solver(nn.Module):
...
@@ -294,7 +299,7 @@ class Solver(nn.Module):
drop_last
=
True
,
drop_last
=
True
,
num_workers
=
num_workers
,
num_workers
=
num_workers
,
pin_memory
=
pin_memory
,
pin_memory
=
pin_memory
,
timeout
=
60
)
timeout
=
timeout
)
self
.
f_train
=
LearnabilityMB
(
self
.
Nminibatch
,
self
.
D
,
self
.
f_train
=
LearnabilityMB
(
self
.
Nminibatch
,
self
.
D
,
constants
.
Coefficients
.
SLE
[
order
],
constants
.
Coefficients
.
SLE
[
order
],
self
.
groups
,
self
.
groups
,
...
@@ -338,7 +343,7 @@ class Solver(nn.Module):
...
@@ -338,7 +343,7 @@ class Solver(nn.Module):
Completes the forward operation and computes gradients for learnability and penalty.
Completes the forward operation and computes gradients for learnability and penalty.
"""
"""
f_train
=
self
.
f_train
(
s
,
xsub
,
ysub
)
f_train
=
self
.
f_train
(
s
,
xsub
,
ysub
)
pen
=
self
.
penalty
(
s
)
pen
=
self
.
penalty
(
s
)
.
unsqueeze
(
0
).
unsqueeze
(
0
)
# pylint: disable=E1102
# pylint: disable=E1102
grad_outputs
=
torch
.
tensor
([[
1
]],
dtype
=
torch
.
get_default_dtype
(),
grad_outputs
=
torch
.
tensor
([[
1
]],
dtype
=
torch
.
get_default_dtype
(),
device
=
self
.
device
)
device
=
self
.
device
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment