ModelZoo / ShuffleNetV2_pytorch / Commits / b8d3ff26

Commit b8d3ff26, authored Mar 05, 2025 by Sugon_ldc (parent b7a31755)

add multi.py

Showing 1 changed file with 150 additions and 0 deletions.

ShuffleNetV2Driver_multi.py (new file, 0 → 100644): +150 −0
import os
from fitlog import FitLog
from ShuffleNet.model import ShuffleNetV2
from iframe_feeder import iframe_feeder
from torch.utils.data import DataLoader
import torch
from tqdm import tqdm, trange
from torch.optim.lr_scheduler import MultiStepLR
import torch.nn as nn


class ShuffleNetV2Driver:
    def __init__(self, featd, feati, normd, normi, lab):
        # Training hyperparameters
        self.batchsize = 256
        self.lr = 0.01
        self.momentum = 0.9
        self.decay = 4e-5
        self.gamma = 0.1
        self.schedule = [200, 300]
        # Rank information injected by the distributed launcher (-1 when not set)
        self.local_rank = int(os.getenv("LOCAL_RANK", -1))
        self.RANK = int(os.getenv("RANK", -1))
        self.feeder = iframe_feeder(featd, feati, lab, normd, normi)
        self.feeder.set_mode('train')
        # Bind each process to its own GPU when running distributed
        if self.local_rank >= 0:
            self.device = torch.device('cuda', self.local_rank)
        else:
            self.device = torch.device('cuda')
        # Distributed runs shard the dataset with a DistributedSampler;
        # single-GPU runs fall back to an ordinary shuffled DataLoader
        if self.local_rank >= 0:
            self.sampler = torch.utils.data.distributed.DistributedSampler(self.feeder)
            self.loader = DataLoader(self.feeder, batch_size=self.batchsize,
                                     sampler=self.sampler, shuffle=False)
        else:
            self.loader = DataLoader(self.feeder, batch_size=self.batchsize,
                                     shuffle=True, num_workers=0)
        # self.feeder = iframe_feeder(featd, feati, lab, normd, normi)
        # self.feeder.set_mode('train')
        self.model = ShuffleNetV2(num_classes=2, scale=0.5, SE=True, residual=True)  # torchvision.models.ShuffleNetV2(num_classes=2)
        self.fitlog = FitLog()
        self.detail_log = FitLog(prefix='dt_')
        # self.device = self.get_device()  # torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print("device info: " + str(self.device))
        self.optimizer = torch.optim.SGD(self.model.parameters(), self.lr,
                                         momentum=self.momentum, weight_decay=self.decay,
                                         nesterov=True)
        self.scheduler = MultiStepLR(self.optimizer, self.schedule, self.gamma)
        self.criterion = torch.nn.CrossEntropyLoss()
        # self.loader = DataLoader(self.feeder, batch_size=self.batchsize, shuffle=True, num_workers=0)
        self.print_interval = 25
        self.n_epoch = 200
        print('self_device:', self.device)
        print('self_local_rank:', self.local_rank)
        self.model.to(self.device)
        # Wrap the model for synchronized multi-GPU training
        self.model = nn.parallel.DistributedDataParallel(self.model,
                                                         device_ids=[self.local_rank],
                                                         output_device=self.local_rank,
                                                         find_unused_parameters=True)

    def get_device(self):
        # Prefer MLU, then CUDA, then CPU; unused here since __init__ picks the device directly
        device_mlu = None
        device_gpu = None
        try:
            # device_mlu = torch.device('mlu')
            device_gpu = torch.device('cuda')
        except Exception as err:
            print(err)
        if device_mlu:
            self.fitlog.append('mlu', True, True)
            return device_mlu
        elif device_gpu:
            self.fitlog.append('cuda', True, True)
            return device_gpu
        else:
            self.fitlog.append('cpu', True, True)
            return torch.device('cpu')

    def train(self, epoch):
        self.feeder.set_mode('train')
        self.model.train()
        nbatch = len(self.loader)
        print('nbatch:', nbatch)
        for batch_idx, (feats, labs) in enumerate(tqdm(self.loader)):
            feats = feats.to(self.device)
            labs = labs.to(self.device)
            self.optimizer.zero_grad()
            res = self.model(feats)
            loss = self.criterion(res, labs)
            loss.backward()
            self.optimizer.step()
            if batch_idx % self.print_interval == 0:
                xstr = "Train: epoch: {} batch: {}/{}, loss: {:.6f}".format(epoch, batch_idx, nbatch, loss)
                tqdm.write(xstr)
                self.fitlog.append(xstr, True, True)

    def validate(self):
        self.feeder.set_mode('test')
        self.model.eval()
        loss_val = 0
        n_correct = 0
        n_total = 0
        for batch_idx, (feats, labs) in enumerate(tqdm(self.loader)):
            feats = feats.to(self.device)
            labs = labs.to(self.device)
            with torch.no_grad():
                res = self.model(feats)
            loss_val += self.criterion(res, labs).item()
            _, pred = res.max(1)
            n_correct += pred.eq(labs).sum().item()
            n_total += labs.shape[0]
        loss_val = loss_val / len(self.loader)
        acc = n_correct / n_total * 100
        self.detail_log.append(str(labs.tolist()))
        self.detail_log.append(str(pred.tolist()))
        xstr = "Validation: avg loss: {:.4f}, avg acc: {:.4f}%".format(loss_val, acc)
        tqdm.write(xstr)
        self.fitlog.append(xstr, True, True)

    def finish(self):
        self.feeder.finish()
        self.fitlog.close()
        self.detail_log.close()

    def run(self):
        for i in range(1, self.n_epoch + 1):
            self.train(i)
            self.validate()


if __name__ == "__main__":
    datafolder = "data/"
    # One process per GPU; NCCL backend for inter-GPU communication
    torch.distributed.init_process_group(backend="nccl")
    driver = ShuffleNetV2Driver(datafolder + "s2_ftimgd",
                                datafolder + "s2_ftimgi",
                                datafolder + "s2_normd",
                                datafolder + "s2_normi",
                                datafolder + "s2_label.json")
    driver.run()
    driver.finish()
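The two project-local imports, fitlog.FitLog and iframe_feeder.iframe_feeder, are not part of this commit. The sketch below is a hypothetical stand-in inferred only from how the driver calls them (constructor arguments, set_mode, finish, append, close); the file name written by FitLog and the dummy sample returned by the dataset are assumptions for illustration, not the project's actual implementation.

# Hypothetical stand-ins for the project-local modules imported by the driver.
# They only mirror the call signatures used above; the real implementations
# are not included in this commit.
import torch
from torch.utils.data import Dataset

class FitLog:
    """Minimal logger sketch: append(msg, to_stdout=False, flush=False) and close()."""
    def __init__(self, prefix=''):
        # Assumed log file name; the real FitLog may write elsewhere.
        self.fp = open(prefix + 'fit.log', 'w')

    def append(self, msg, to_stdout=False, flush=False):
        self.fp.write(str(msg) + '\n')
        if flush:
            self.fp.flush()
        if to_stdout:
            print(msg)

    def close(self):
        self.fp.close()

class iframe_feeder(Dataset):
    """Dataset sketch yielding (feature, label) pairs with a train/test mode switch."""
    def __init__(self, featd, feati, lab, normd, normi):
        self.mode = 'train'
        # The real feeder loads the feature, normalization, and label files;
        # a single dummy sample stands in here.
        self.data = [(torch.zeros(3, 224, 224), 0)]

    def set_mode(self, mode):
        self.mode = mode

    def finish(self):
        pass

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

Because the script reads LOCAL_RANK and RANK from the environment and initializes the "nccl" process group, it is meant to be started through a distributed launcher such as torchrun (for example, torchrun --nproc_per_node=<num_gpus> ShuffleNetV2Driver_multi.py), which sets those variables for each worker process.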