"""Train and validate a ShuffleNetV2 binary classifier, optionally with DDP.

When the ``LOCAL_RANK`` environment variable is set (as ``torchrun`` does),
the script initialises an NCCL process group and trains with
``DistributedDataParallel``; otherwise it runs as a single process.
"""
import os

import torch
import torch.nn as nn
from torch.optim.lr_scheduler import MultiStepLR
from torch.utils.data import DataLoader
from tqdm import tqdm, trange

from fitlog import FitLog
from ShuffleNet.model import ShuffleNetV2
from iframe_feeder import iframe_feeder


class ShuffleNetV2Driver:
    """Own the dataset, model, optimizer and logs for one training run."""

    def __init__(self, featd, feati, normd, normi, lab):
        """Build the data pipeline, model, optimizer and loggers.

        The five arguments are forwarded to ``iframe_feeder`` (note the
        feeder's own argument order: ``featd, feati, lab, normd, normi``).
        The script entry point passes path strings for all of them.
        """
        # Hyper-parameters.
        self.batchsize = 256
        self.lr = 0.01
        self.momentum = 0.9
        self.decay = 4e-5
        self.gamma = 0.1
        self.schedule = [200, 300]
        self.print_interval = 25
        self.n_epoch = 200

        # A value of -1 means the variable is unset, i.e. single-process mode.
        self.local_rank = int(os.getenv("LOCAL_RANK", -1))
        self.RANK = int(os.getenv("RANK", -1))
        self.distributed = self.local_rank >= 0

        self.feeder = iframe_feeder(featd, feati, lab, normd, normi)
        self.feeder.set_mode('train')

        if self.distributed:
            self.device = torch.device('cuda', self.local_rank)
            self.sampler = torch.utils.data.distributed.DistributedSampler(
                self.feeder)
            # The sampler does the shuffling, so the loader must not.
            self.loader = DataLoader(self.feeder, batch_size=self.batchsize,
                                     sampler=self.sampler, shuffle=False)
        else:
            # Fall back to CPU so single-process runs work without a GPU.
            self.device = torch.device(
                'cuda' if torch.cuda.is_available() else 'cpu')
            self.sampler = None
            self.loader = DataLoader(self.feeder, batch_size=self.batchsize,
                                     shuffle=True, num_workers=0)

        self.model = ShuffleNetV2(num_classes=2, scale=0.5, SE=True,
                                  residual=True)
        self.fitlog = FitLog()
        self.detail_log = FitLog(prefix='dt_')
        print("device info: " + str(self.device))

        self.optimizer = torch.optim.SGD(self.model.parameters(),
                                         self.lr, momentum=self.momentum,
                                         weight_decay=self.decay,
                                         nesterov=True)
        self.scheduler = MultiStepLR(self.optimizer, self.schedule, self.gamma)
        self.criterion = torch.nn.CrossEntropyLoss()

        print('self_device:', self.device)
        print('self_local_rank:', self.local_rank)
        self.model.to(self.device)
        # Only wrap in DDP when a process group exists; wrapping with
        # device_ids=[-1] in single-process mode would fail.
        if self.distributed:
            self.model = nn.parallel.DistributedDataParallel(
                self.model,
                device_ids=[self.local_rank],
                output_device=self.local_rank,
                find_unused_parameters=True,
            )

    def get_device(self):
        """Return the CUDA device if available, else CPU, and log the choice.

        Not called by ``__init__``; kept for callers that want the
        auto-detected device. MLU detection is not implemented.
        """
        if torch.cuda.is_available():
            self.fitlog.append('cuda', True, True)
            return torch.device('cuda')
        self.fitlog.append('cpu', True, True)
        return torch.device('cpu')

    def train(self, epoch):
        """Run one training pass over the loader.

        ``epoch`` is used for log messages and, in distributed mode, to
        reseed the sampler's shuffle.
        """
        self.feeder.set_mode('train')
        self.model.train()
        if self.sampler is not None:
            # Without this, DistributedSampler yields the same order every
            # epoch.
            self.sampler.set_epoch(epoch)

        nbatch = len(self.loader)
        print('nbatch:', nbatch)
        for batch_idx, (feats, labs) in enumerate(tqdm(self.loader)):
            feats = feats.to(self.device)
            labs = labs.to(self.device)

            self.optimizer.zero_grad()
            res = self.model(feats)
            loss = self.criterion(res, labs)
            loss.backward()
            self.optimizer.step()

            if batch_idx % self.print_interval == 0:
                xstr = "Train: epoch: {} batch: {}/{}, loss: {:.6f}".format(
                    epoch, batch_idx, nbatch, loss.item())
                tqdm.write(xstr)
                self.fitlog.append(xstr, True, True)

    def validate(self):
        """Evaluate the model with the feeder in 'test' mode and log metrics.

        Iterates the same loader as training after switching the feeder's
        mode. The reported loss is the mean of per-batch losses; accuracy is
        over all samples seen. Only the labels and predictions of the LAST
        batch are written to the detail log.
        NOTE(review): in distributed mode each rank appears to evaluate only
        its sampler shard, so metrics are per-rank -- confirm this is intended.
        """
        self.feeder.set_mode('test')
        self.model.eval()

        loss_val = 0
        n_correct = 0
        n_total = 0
        labs = None
        pred = None
        with torch.no_grad():
            for feats, labs in tqdm(self.loader):
                feats = feats.to(self.device)
                labs = labs.to(self.device)
                res = self.model(feats)
                loss_val += self.criterion(res, labs).item()
                _, pred = res.max(1)
                n_correct += pred.eq(labs).sum().item()
                n_total += labs.shape[0]

        if n_total == 0:
            # Nothing was evaluated; avoid dividing by zero below.
            tqdm.write("Validation: no samples to evaluate")
            return

        loss_val = loss_val / len(self.loader)
        acc = n_correct / n_total * 100

        self.detail_log.append(str(labs.tolist()))
        self.detail_log.append(str(pred.tolist()))

        xstr = "Validation: avg loss: {:.4f}, avg acc: {:.4f}%".format(
            loss_val, acc)
        tqdm.write(xstr)
        self.fitlog.append(xstr, True, True)

    def finish(self):
        """Release the feeder and close both log files."""
        self.feeder.finish()
        self.fitlog.close()
        self.detail_log.close()

    def run(self):
        """Train and validate for ``n_epoch`` epochs, stepping the LR schedule."""
        for epoch in range(1, self.n_epoch + 1):
            self.train(epoch)
            self.validate()
            # MultiStepLR only takes effect if stepped once per epoch.
            self.scheduler.step()


if __name__ == "__main__":
    datafolder = "data/"
    # Only set up the process group when launched by a distributed launcher;
    # the driver handles single-process mode itself.
    if int(os.getenv("LOCAL_RANK", -1)) >= 0:
        torch.distributed.init_process_group(backend="nccl")
    driver = ShuffleNetV2Driver(datafolder + "s2_ftimgd",
                                datafolder + "s2_ftimgi",
                                datafolder + "s2_normd",
                                datafolder + "s2_normi",
                                datafolder + "s2_label.json")
    try:
        driver.run()
    finally:
        # Close logs and the feeder even if training raises.
        driver.finish()