import torch
import torch.nn as nn
# from peselibs_config import get_lib_path
import sys
# sys.path.append(get_lib_path())
# import DDP  # local helper module, not used in this script
import torchvision.models.mobilenet as mobilenet
from datawork import *  # provides Fer2013Dataset
from sklearn.metrics import accuracy_score
from fitlog import FitLog
from torch.utils.data import DataLoader
import numpy as np  # needed for np.mean/np.std below
import time
import torch.distributed as dist
import os
import datetime
import argparse

parser = argparse.ArgumentParser(description='PyTorch MobileNetV2/FER2013 distributed training')
parser.add_argument('--world-size', default=-1, type=int,
                    help='total number of distributed processes '
                         '(passed directly to init_process_group)')
parser.add_argument('--rank', default=-1, type=int,
                    help='global rank of this process')
parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,
                    help='url used to set up distributed training')
parser.add_argument('--dist-backend', default='nccl', type=str,
                    help='distributed backend')

g_debug = False


class MobileNetV2Driver:
    def __init__(self, args):  # DDP: system initialization
        self.nclass = 9
        self.batch_size = 128  # 64
        self.nepoch = 50
        self.nround = 10
        self.lr = 0.00001
        self.loader = None
        self.test_loader = None
        self.dataset = None
        self.device = None
        self.args = args
        # fall back to the non-distributed path when init_ddp() did not
        # initialize a process group (e.g. a single-GPU machine)
        self.local_rank = args.rank if dist.is_available() and dist.is_initialized() else None

        # model & device
        self.model = mobilenet.MobileNetV2(num_classes=self.nclass)
        print("local_rank:{}".format(self.local_rank))
        self._init_device()
        self.model.to(self.device)
        print('device:', self.device)
        if self.local_rank is not None:
            # one process per GPU, four GPUs per node, hence the `% 4` mapping
            self.model = nn.parallel.DistributedDataParallel(
                self.model,
                device_ids=[self.local_rank % 4],
                output_device=self.local_rank % 4,
                find_unused_parameters=True)

        self.criterion = nn.CrossEntropyLoss()
        # self.optimizer = torch.optim.SGD(self.model.parameters(), lr=self.lr)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr,
                                          betas=(0.9, 0.999))
        # self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=1, gamma=0.98)

        print('loading dataset ...')
        # Fer2013Dataset takes the local device index; use 0 when running
        # without distributed training
        self.dataset = Fer2013Dataset(self.local_rank % 4 if self.local_rank is not None else 0)
        print('dataset loaded')
        try:
            self.sampler = torch.utils.data.distributed.DistributedSampler(
                self.dataset, num_replicas=args.world_size, rank=args.rank)
        except Exception:
            # invalid rank/world_size (e.g. the defaults of -1) means we are
            # not running distributed
            self.sampler = None

        if self.local_rank is not None:
            # the sampler already shards and shuffles, so shuffle must be False
            self.loader = DataLoader(self.dataset, batch_size=self.batch_size,
                                     sampler=self.sampler, shuffle=False)
        else:
            self.loader = DataLoader(self.dataset, batch_size=self.batch_size,
                                     shuffle=True)
        # validation order does not matter, so no shuffling is needed
        self.test_loader = DataLoader(self.dataset, batch_size=self.batch_size,
                                      shuffle=False)
        print('dataset setup done')

    def _init_device(self):
        if self.local_rank is not None:
            self.device = torch.device('cuda', self.local_rank % 4)
        elif torch.cuda.is_available():
            self.device = torch.device('cuda')
        else:
            self.device = torch.device('cpu')

    def init_dataset(self, seed):
        self.dataset.randomization(seed)

    def train(self):
        best_acc = 0
        best_acc_at = 0
        if self.local_rank in (0, None):
            self.fitlog = FitLog("logs/")
            self.jishilog = FitLog("logs/", prefix='jishi')  # 'jishi' = per-batch timing log
            self.dlog = FitLog("logs/", prefix='pred')
        st_time = time.time()
        for epoch in range(self.nepoch):
            if self.sampler is not None:
                # re-seed the sampler so each epoch sees a different shard order
                self.sampler.set_epoch(epoch)
            self.dataset.set_mode("train")
            self.model.train()
            all_loss = []
            for batch_idx, (data, target) in enumerate(self.loader):
                data, target = data.to(self.device), target.to(self.device)
                if self.local_rank in (0, None):
                    jishi1 = time.time()
                self.optimizer.zero_grad()
                output = self.model(data)
                loss = self.criterion(output, target)
                loss.backward()
                self.optimizer.step()
                if self.local_rank in (0, None):
                    jishi2 = time.time()
                    jishi2_log = '****epc:{},process:{}/{},start:{},end:{},duration:{}****'.format(
                        epoch, batch_idx * len(data), len(self.loader.dataset),
                        jishi1, jishi2, jishi2 - jishi1)
                    self.jishilog.append(jishi2_log)
                    print(jishi2_log)
                all_loss.append(loss.item())
                t1 = time.time()
                duration = t1 - st_time
                if batch_idx % 10 == 0 and self.local_rank in (0, None):
                    btstr = 'epc: {} [{}/{} ({:.0f}%)] loss: {:.6f} b-acc: {:.3f} @:{},curtime:{},duration:{}'.format(
                        epoch, batch_idx * len(data), len(self.loader.dataset),
                        100. * batch_idx / len(self.loader), loss.item(),
                        best_acc, best_acc_at, t1, duration)
                    self.fitlog.append(btstr)
                if g_debug:
                    break
            if self.local_rank in (0, None):
                # save the unwrapped module's weights so the checkpoint does
                # not depend on the DDP wrapper
                to_save = self.model.module if isinstance(
                    self.model, nn.parallel.DistributedDataParallel) else self.model
                torch.save(to_save.state_dict(), './mobilenet.pth')
                t1 = time.time()
                duration = t1 - st_time
                acc, vloss, vloss_std, all_pred, all_tar = self._validate()
                epcstr = '****epc:{},loss:{:.6f},loss_std:{:.6f},vloss:{:.6f},vloss_std:{:.6f},acc:{:.3f},duration:{}****'.format(
                    epoch, np.mean(all_loss), np.std(all_loss), vloss, vloss_std, acc, duration)
                self.dlog.append(epcstr + ",preds:{},plabs:{}".format(all_pred, all_tar))
                if acc > best_acc:
                    best_acc = acc
                    best_acc_at = epoch
                print(epcstr)
            if g_debug:
                break
        if self.local_rank in (0, None):
            self.fitlog.close()
            self.dlog.close()
            self.jishilog.close()

    def _validate(self):
        self.model.eval()
        self.dataset.set_mode('test')
        all_pred = []
        all_tar = []
        all_loss = []
        with torch.no_grad():
            for i, (ft, labs) in enumerate(self.test_loader):
                ft, labs = ft.to(self.device), labs.to(self.device)
                output = self.model(ft)
                loss = self.criterion(output, labs)
                preds = torch.argmax(output, dim=1).cpu().numpy().tolist()
                all_pred.extend(preds)
                all_tar.extend(labs.cpu().numpy().tolist())
                all_loss.append(loss.item())
                if i % 100 == 0:
                    print('validating @ batch {}'.format(i))
                if g_debug:
                    break
        # compute accuracy once over all predictions instead of averaging
        # running accuracies, which over-weights early batches
        acc = accuracy_score(all_tar, all_pred)
        return acc, np.mean(all_loss), np.std(all_loss), all_pred, all_tar

    def run(self, iround):
        self.init_dataset(iround)
        self.train()


def init_ddp(args, visible_devices='0,1,2,3'):
    if torch.cuda.device_count() > 1:
        # os.environ['HIP_VISIBLE_DEVICES'] = visible_devices
        local_rank = args.rank  # int(os.environ["LOCAL_RANK"])
        print("local_rank:" + str(local_rank))
        print(args.dist_backend, args.dist_url, args.world_size, args.rank)
        torch.distributed.init_process_group(backend=args.dist_backend,
                                             init_method=args.dist_url,
                                             world_size=args.world_size,
                                             rank=args.rank)
        # local_rank = torch.distributed.get_rank()
        torch.cuda.set_device(local_rank % 4)
        return local_rank
    else:
        return None


if __name__ == '__main__':
    print("torch.cuda.is_available", torch.cuda.is_available())
    args = parser.parse_args()
    t0 = time.time()
    print(torch.cuda.device_count())
    local_rank = init_ddp(args)
    print(local_rank)
    driver = MobileNetV2Driver(args)
    # print("round {}".format(sys.argv[1]))
    iround = 1
    driver.run(iround)
    t1 = time.time()
    # time.time() returns seconds, so report the elapsed time directly
    print("result_time=", t1 - t0)
    # tear down the process group so all ranks exit cleanly
    if dist.is_initialized():
        dist.destroy_process_group()
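# A minimal launch sketch, assuming one process per GPU with global ranks and
# four GPUs per node (the `local_rank % 4` device mapping above implies this
# layout). The script name `train_mobilenet.py` and `<master-ip>` are
# placeholders, not part of this repo; --dist-url must point at a TCP endpoint
# reachable by every node (here using the script's default port 23456).
#
#   # node 0: launch ranks 0-3, one per GPU
#   python train_mobilenet.py --world-size 8 --rank 0 \
#       --dist-url tcp://<master-ip>:23456 --dist-backend nccl
#   # ... repeat for --rank 1, 2, 3 on node 0
#
#   # node 1: launch ranks 4-7
#   python train_mobilenet.py --world-size 8 --rank 4 \
#       --dist-url tcp://<master-ip>:23456 --dist-backend nccl
#   # ... repeat for --rank 5, 6, 7 on node 1
#
# Every process must use the same --world-size and --dist-url; rank 0's node
# hosts the rendezvous endpoint, and init_process_group() blocks until all
# world-size processes have joined.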