#! -*- coding:utf-8 -*-
# DistributedDataParallel (DDP) training example.
# Launch command:
#   python -m torch.distributed.launch --nproc_per_node=2 --nnodes=1 task_distributed_data_parallel.py

import os
# The visible GPUs may also be supplied from the command line instead.
# Kept ahead of the torch imports, as in the original statement order.
os.environ["CUDA_VISIBLE_DEVICES"]="0,1"
from bert4torch.tokenizers import Tokenizer
from bert4torch.models import build_transformer_model, BaseModel, BaseModelDDP
from bert4torch.layers import CRF
from bert4torch.snippets import sequence_padding, text_segmentate, ListDataset, seed_everything
import torch.nn as nn
import torch
import torch.optim as optim
import random
import numpy as np
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
import argparse

# The launcher tells each worker which local GPU it owns. Older launchers pass
# ``--local_rank``, PyTorch >= 2.0 passes ``--local-rank``, and torchrun only
# exports ``LOCAL_RANK``; accept all three so the script starts under any of
# them. With none of them present the default stays -1, as before.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--local_rank",
    "--local-rank",
    dest="local_rank",
    type=int,
    default=int(os.environ.get("LOCAL_RANK", -1)),
)
args = parser.parse_args()

torch.cuda.set_device(args.local_rank)
device = torch.device('cuda', args.local_rank)
torch.distributed.init_process_group(backend='nccl')

# Model settings
maxlen = 256
batch_size = 16
config_path = '/datasets/bert-base-chinese/config.json'
checkpoint_path = '/datasets/bert-base-chinese/pytorch_model.bin'
dict_path = '/datasets/bert-base-chinese/vocab.txt'

# NOTE(review): ``Model`` below sizes its output layer and CRF with
# ``len(categories)``, but the original script never defined ``categories``.
# This tag list is assumed from the corpus name in the dataset path below --
# confirm against the data before relying on it.
categories = ['O', 'B-LOC', 'I-LOC', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG']

# Fix the random seed
seed_everything(42)

# Build the tokenizer
tokenizer = Tokenizer(dict_path, do_lower_case=True)


class MyDataset(ListDataset):
    """Dataset that parses a character-per-line tagging file."""

    @staticmethod
    def load_data(filename):
        """Parse ``filename`` into a list of samples.

        The file is split into samples on blank lines (``'\\n\\n'``); every
        line of a sample is ``"<char> <flag>"`` separated by a single space.

        Returns:
            A list where each item is ``[text, [start, end, type], ...]``:
            ``text`` is the concatenation of the sample's characters and each
            following triple is one entity, with inclusive line indices
            ``start``/``end`` and ``type`` taken from ``flag[2:]``.
        """
        D = []
        with open(filename, encoding='utf-8') as f:
            content = f.read()

        for sample in content.split('\n\n'):
            if not sample:
                continue
            d = ['']
            for i, line in enumerate(sample.split('\n')):
                char, flag = line.split(' ')
                d[0] += char
                if flag[0] == 'B':
                    # A "B" flag opens a new entity that starts and ends here.
                    d.append([i, i, flag[2:]])
                elif flag[0] == 'I':
                    # An "I" flag extends the end of the most recent entity.
                    d[-1][1] = i
            D.append(d)
        return D


def collate_fn(batch):
    """Turn a list of ``(text, label)`` samples into padded tensors on ``device``.

    Returns:
        ``([token_ids, segment_ids, labels], None)`` where ``labels`` has been
        flattened to one dimension.

    NOTE(review): this expects two-element ``(text, label)`` samples, whereas
    ``MyDataset.load_data`` yields ``[text, [start, end, type], ...]``. The
    unpacking below only succeeds for samples with exactly one entity, and the
    "label" is then a ``[start, end, type]`` list. This function appears to be
    left over from a classification pipeline -- confirm the intended task.
    """
    batch_token_ids, batch_segment_ids, batch_labels = [], [], []
    for text, label in batch:
        token_ids, segment_ids = tokenizer.encode(text, maxlen=maxlen)
        batch_token_ids.append(token_ids)
        batch_segment_ids.append(segment_ids)
        batch_labels.append([label])

    batch_token_ids = torch.tensor(sequence_padding(batch_token_ids), dtype=torch.long, device=device)
    batch_segment_ids = torch.tensor(sequence_padding(batch_segment_ids), dtype=torch.long, device=device)
    batch_labels = torch.tensor(batch_labels, dtype=torch.long, device=device)
    return [batch_token_ids, batch_segment_ids, batch_labels.flatten()], None


# Load the dataset; the DistributedSampler is what the DataLoader draws from.
train_dataset = MyDataset('/datasets/bert-base-chinese/china-people-daily-ner-corpus/example.train')
train_sampler = DistributedSampler(train_dataset)
train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=batch_size, collate_fn=collate_fn)


# Model structure on top of BERT. The loss is computed inside ``forward`` and
# returned directly.
# NOTE(review): this class is not instantiated anywhere in this script.
class Model1(nn.Module):
    """BERT with pooled output -> dropout -> 2-way linear head -> cross-entropy."""

    def __init__(self) -> None:
        super().__init__()
        self.bert = build_transformer_model(config_path=config_path, checkpoint_path=checkpoint_path, with_pool=True)
        self.dropout = nn.Dropout(0.1)
        self.dense = nn.Linear(self.bert.configs['hidden_size'], 2)
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, token_ids, segment_ids, labels):
        """Return the cross-entropy loss of the 2-way head against ``labels``."""
        _, pooled_output = self.bert([token_ids, segment_ids])
        output = self.dropout(pooled_output)
        output = self.dense(output)
        loss = self.loss_fn(output, labels)
        return loss


class Model(BaseModel):
    """BERT encoder followed by a per-token linear layer and a CRF."""

    def __init__(self):
        super().__init__()
        self.bert = build_transformer_model(config_path=config_path, checkpoint_path=checkpoint_path, segment_vocab_size=0)
        self.fc = nn.Linear(768, len(categories))  # includes the first and last positions
        self.crf = CRF(len(categories))

    def forward(self, token_ids):
        """Return ``(emission_score, attention_mask)`` for ``token_ids``."""
        sequence_output = self.bert([token_ids])  # [btz, seq_len, hdsz]
        emission_score = self.fc(sequence_output)  # [btz, seq_len, tag_size]
        # Positions whose token id is greater than 0 are marked with 1.
        attention_mask = token_ids.gt(0).long()
        return emission_score, attention_mask

    def predict(self, token_ids):
        """Switch to eval mode and return the CRF-decoded path for ``token_ids``."""
        self.eval()
        with torch.no_grad():
            emission_score, attention_mask = self.forward(token_ids)
            best_path = self.crf.decode(emission_score, attention_mask)  # [btz, seq_len]
        return best_path


# NOTE(review): the pieces wired together below do not agree with each other.
# ``collate_fn`` emits three input tensors, ``Model.forward`` takes only
# ``token_ids`` and returns a tuple, and the loss passed to ``compile`` assumes
# the forward pass already returned a loss (as ``Model1`` does). Decide which
# pipeline is intended and align collate_fn / model / loss before training.
model = Model().to(device)

# Wrap the model for multi-GPU DDP training; master_rank is the local_rank
# that is used to print training progress.
model = BaseModelDDP(model, master_rank=0, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=False)

# Define the loss and optimizer; custom ones are supported here.
model.compile(
    loss=lambda x, _: x,  # pass through the loss computed in forward
    optimizer=optim.Adam(model.parameters(), lr=2e-5),
)

if __name__ == '__main__':
    model.fit(train_dataloader, epochs=20, steps_per_epoch=None)