Commit 2a64107e authored by Rémi Louf, committed by Julien Chaumond

improve device usage

parent c0707a85
@@ -29,7 +29,7 @@ And move all the stories to the same folder. We will refer as `$DATA_PATH` the p
 python run_summarization.py \
     --documents_dir $DATA_PATH \
     --summaries_output_dir $SUMMARIES_PATH \ # optional
-    --visible_gpus 0,1,2 \
+    --to_cpu false \
     --batch_size 4 \
     --min_length 50 \
     --max_length 200 \
@@ -39,7 +39,7 @@ python run_summarization.py \
     --compute_rouge true
 ```
-The ROUGE scores will be displayed in the console at the end of evaluation and written in a `rouge_scores.txt` file. The scripts executes on GPU if one is available and if `to_cpu` is not set to `true`. Inference on multiple GPUs is not suported yet.
+The ROUGE scores will be displayed in the console at the end of evaluation and written in a `rouge_scores.txt` file. The script takes 30 hours to run on a single Tesla V100 GPU with a batch size of 10 (300,000 texts to summarize).

 ## Summarize any text

@@ -49,7 +49,7 @@ Put the documents that you would like to summarize in a folder (the path to whic
 python run_summarization.py \
     --documents_dir $DATA_PATH \
     --summaries_output_dir $SUMMARIES_PATH \ # optional
-    --visible_gpus 0,1,2 \
+    --to_cpu false \
     --batch_size 4 \
     --min_length 50 \
     --max_length 200 \
@@ -58,4 +58,4 @@ python run_summarization.py \
     --block_trigram true \
 ```
-If you want to compute ROUGE on another dataset you will need to tweak the stories/summaries import in `utils_summarization.py`
+You may want to play around with `min_length`, `max_length` and `alpha` to suit your use case. If you want to compute ROUGE on another dataset you will need to tweak the stories/summaries import in `utils_summarization.py` and tell it where to fetch the reference summaries.
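For intuition on `alpha`: in this family of models it is typically the exponent of a GNMT-style length penalty applied during beam search. That interpretation is an assumption here, not something the diff shows, so check the translator code to confirm. A minimal sketch:

```python
def length_penalty(length, alpha):
    # GNMT-style penalty: beam scores are divided by this factor, so a larger
    # alpha discounts long hypotheses less and tends to produce longer summaries.
    return ((5.0 + length) / 6.0) ** alpha

for alpha in (0.6, 0.95):
    print(alpha, [round(length_penalty(n, alpha), 2) for n in (10, 50, 200)])
```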
@@ -12,10 +12,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-""" Convert BertExtAbs's checkpoints
+""" Convert BertExtAbs's checkpoints.

-The file currently does not do much as we ended up copying the exact model
-structure, but I leave it here in case we ever want to refactor the model.
+The script looks like it is doing something trivial but it is not. The "weights"
+proposed by the authors are actually the entire model pickled. We need to load
+the model within the original codebase to be able to only save its `state_dict`.
 """
 import argparse
...
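The new docstring is worth unpacking: because the released checkpoint is the whole model pickled, unpickling it requires the original model class to be importable, and only then can the `state_dict` alone be saved. A minimal sketch of that flow (paths and flag names are illustrative, not the script's actual CLI):

```python
import argparse

import torch

parser = argparse.ArgumentParser()
parser.add_argument("--src_path", type=str, required=True, help="Pickled original model.")
parser.add_argument("--dst_path", type=str, required=True, help="Where to write the state_dict.")
args = parser.parse_args()

# Unpickling the checkpoint only works if the original model class can be
# imported, which is why the conversion has to run inside the original codebase.
model = torch.load(args.src_path, map_location="cpu")
torch.save(model.state_dict(), args.dst_path)
```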
@@ -847,14 +847,12 @@ class Translator(object):
         global_scores (:obj:`GlobalScorer`):
            object to rescore final translations
         copy_attn (bool): use copy attention during translation
-        cuda (bool): use cuda
         beam_trace (bool): trace beam search for debugging
         logger(logging.Logger): logger.
     """

     def __init__(self, args, model, vocab, symbols, global_scorer=None, logger=None):
         self.logger = logger
-        self.cuda = args.visible_gpus != "-1"

         self.args = args
         self.model = model
...
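With `self.cuda` gone, the translator can rely on a single `torch.device` carried by `args`. A minimal sketch of the pattern, with hypothetical names standing in for the real model and inputs:

```python
import torch

class DeviceAwareTranslator:
    """Sketch only: device handling through `args.device` instead of a cuda flag."""

    def __init__(self, args, model):
        self.args = args
        self.model = model.to(args.device)  # the model lives on the chosen device

    def translate(self, src):
        # Inputs are moved to the model's device right before the forward pass.
        return self.model(src.to(self.args.device))
```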
@@ -185,7 +185,7 @@ def save_summaries(summaries, path, original_document_name):
 def build_data_iterator(args, tokenizer):
     dataset = load_and_cache_examples(args, tokenizer)
     sampler = SequentialSampler(dataset)
-    collate_fn = lambda data: collate(data, tokenizer, block_size=512)
+    collate_fn = lambda data: collate(data, tokenizer, block_size=512, device=args.device)
     iterator = DataLoader(
         dataset, sampler=sampler, batch_size=args.batch_size, collate_fn=collate_fn,
     )
@@ -198,7 +198,7 @@ def load_and_cache_examples(args, tokenizer):
     return dataset

-def collate(data, tokenizer, block_size):
+def collate(data, tokenizer, block_size, device):
     """ Collate formats the data passed to the data loader.

     In particular we tokenize the data batch after batch to avoid keeping them
@@ -224,9 +224,9 @@ def collate(data, tokenizer, block_size):
     batch = Batch(
         document_names=names,
         batch_size=len(encoded_stories),
-        src=encoded_stories,
-        segs=encoder_token_type_ids,
-        mask_src=encoder_mask,
+        src=encoded_stories.to(device),
+        segs=encoder_token_type_ids.to(device),
+        mask_src=encoder_mask.to(device),
         tgt_str=summaries,
     )
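Taken together, the two hunks above bind the target device into the collate function so every batch is moved as it is assembled. A self-contained sketch of that pattern (`Batch` and the toy dataset here are stand-ins, not the script's classes):

```python
import collections

import torch
from torch.utils.data import DataLoader, SequentialSampler, TensorDataset

Batch = collections.namedtuple("Batch", ["src", "mask_src"])

def collate(examples, device):
    # Stack individual examples into batch tensors and move them to the
    # target device while the batch is being built.
    src = torch.stack([example[0] for example in examples])
    mask_src = (src != 0).long()
    return Batch(src=src.to(device), mask_src=mask_src.to(device))

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dataset = TensorDataset(torch.randint(0, 100, (8, 16)))
iterator = DataLoader(
    dataset,
    sampler=SequentialSampler(dataset),
    batch_size=4,
    collate_fn=lambda examples: collate(examples, device),
)

for batch in iterator:
    print(batch.src.device, batch.mask_src.shape)
```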
@@ -271,10 +271,10 @@ def main():
     )
     # EVALUATION options
     parser.add_argument(
-        "--visible_gpus",
-        default=-1,
-        type=int,
-        help="Number of GPUs with which to do the training.",
+        "--to_cpu",
+        default=False,
+        type=bool,
+        help="Whether to force the execution on CPU.",
     )
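One caveat with the new argument: argparse applies `type=bool` to the raw string, and `bool("false")` is `True`, so `--to_cpu false` on the command line actually sets the flag. A common workaround is a string-to-bool converter (a sketch, not part of this commit):

```python
import argparse

def str2bool(value):
    # bool("false") would be True, so parse the common textual spellings instead.
    if value.lower() in ("yes", "true", "t", "1"):
        return True
    if value.lower() in ("no", "false", "f", "0"):
        return False
    raise argparse.ArgumentTypeError("Boolean value expected.")

parser = argparse.ArgumentParser()
parser.add_argument("--to_cpu", default=False, type=str2bool,
                    help="Whether to force the execution on CPU.")
print(parser.parse_args(["--to_cpu", "false"]).to_cpu)  # False
```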
     parser.add_argument(
         "--batch_size", default=4, type=int, help="Batch size per GPU/CPU for training.",
@@ -311,8 +311,11 @@ def main():
         help="Whether to block the existence of repeating trigrams in the text generated by beam search.",
     )

     args = parser.parse_args()
-    args.device = torch.device("cpu") if args.visible_gpus == -1 else torch.device("cuda")
+
+    # Select device (distributed not available)
+    args.device = torch.device("cuda" if torch.cuda.is_available() and not args.to_cpu else "cpu")

+    # Check the existence of directories
     if not args.summaries_output_dir:
         args.summaries_output_dir = args.documents_dir
...