Commit efb44a83 authored by thomwolf

distributed in extract features

parent d9d7d1a4
@@ -25,12 +25,11 @@ import logging
 import json
 import re
-import tokenization
 import torch
 from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
 from torch.utils.data.distributed import DistributedSampler
+import tokenization
 from modeling import BertConfig, BertModel

 logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
@@ -226,8 +225,9 @@ def main():
     else:
         device = torch.device("cuda", args.local_rank)
         n_gpu = 1
-        # print("Initializing the distributed backend: NCCL")
-    print("device", device, "n_gpu", n_gpu)
+        # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
+        torch.distributed.init_process_group(backend='nccl')
+    logger.info("device", device, "n_gpu", n_gpu, "distributed training", bool(args.local_rank != -1))

     layer_indexes = [int(x) for x in args.layers.split(",")]
@@ -249,9 +249,12 @@ def main():
     if args.init_checkpoint is not None:
         model.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu'))
     model.to(device)
-    if n_gpu > 1:
-        model = nn.DataParallel(model)
+    if args.local_rank != -1:
+        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
+                                                          output_device=args.local_rank)
+    elif n_gpu > 1:
+        model = torch.nn.DataParallel(model)

     all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
     all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
@@ -268,7 +271,7 @@ def main():
     with open(args.output_file, "w", encoding='utf-8') as writer:
         for input_ids, input_mask, example_indices in eval_dataloader:
             input_ids = input_ids.to(device)
-            input_mask = input_mask.float().to(device)
+            input_mask = input_mask.to(device)

             all_encoder_layers, _ = model(input_ids, token_type_ids=None, attention_mask=input_mask)
             all_encoder_layers = all_encoder_layers
...
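For context, here is a minimal, self-contained sketch of the device-setup and model-wrapping pattern this commit introduces: one process per GPU when --local_rank is set, NCCL for process-group initialization, DistributedDataParallel around the model, and a fallback to DataParallel for single-process multi-GPU runs. The placeholder model and the argument parsing below are illustrative assumptions, not the repository's actual extract-features script.

import argparse
import torch

def main():
    parser = argparse.ArgumentParser()
    # torch.distributed.launch passes --local_rank to each spawned process;
    # -1 means we are running as a single (non-distributed) process.
    parser.add_argument("--local_rank", type=int, default=-1)
    args = parser.parse_args()

    if args.local_rank == -1:
        # Single process: use whatever GPUs are visible (or CPU).
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        # One process per GPU: pin this process to its GPU and join the
        # NCCL process group that synchronizes all workers.
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        torch.distributed.init_process_group(backend="nccl")

    model = torch.nn.Linear(768, 768)  # placeholder for the real BertModel
    model.to(device)

    if args.local_rank != -1:
        # Distributed mode: synchronize parameters/buffers across processes.
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.local_rank], output_device=args.local_rank)
    elif n_gpu > 1:
        # Single process, several GPUs: replicate the model per batch instead.
        model = torch.nn.DataParallel(model)

if __name__ == "__main__":
    main()

Such a script would typically be launched with something like python -m torch.distributed.launch --nproc_per_node=4 script.py, which sets the rendezvous environment variables (MASTER_ADDR, MASTER_PORT, RANK, WORLD_SIZE) read by init_process_group and supplies --local_rank to each worker.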