Commit 2a64107e authored by Rémi Louf, committed by Julien Chaumond

improve device usage

parent c0707a85
@@ -29,7 +29,7 @@ And move all the stories to the same folder. We will refer as `$DATA_PATH` the p
python run_summarization.py \
--documents_dir $DATA_PATH \
--summaries_output_dir $SUMMARIES_PATH \
--visible_gpus 0,1,2 \
--to_cpu false \
--batch_size 4 \
--min_length 50 \
--max_length 200 \
@@ -39,7 +39,7 @@ python run_summarization.py \
--compute_rouge true
```
The ROUGE scores will be displayed in the console at the end of evaluation and written in a `rouge_scores.txt` file.
The script executes on GPU if one is available and if `to_cpu` is not set to `true`. Inference on multiple GPUs is not supported yet. The ROUGE scores will be displayed in the console at the end of evaluation and written to a `rouge_scores.txt` file. The script takes about 30 hours on a single Tesla V100 GPU with a batch size of 10 (300,000 texts to summarize). Note that `summaries_output_dir` is optional; it defaults to `documents_dir`.
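For reference, the device selection boils down to one line; a minimal sketch of the behaviour (standalone, not part of the script):

```python
import torch

# Run on the GPU when one is available, unless --to_cpu is set to true.
to_cpu = False  # value parsed from the --to_cpu flag
device = torch.device("cuda" if torch.cuda.is_available() and not to_cpu else "cpu")
print(device)  # "cuda" on a GPU machine, "cpu" otherwise
```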
## Summarize any text
@@ -49,7 +49,7 @@ Put the documents that you would like to summarize in a folder (the path to whic
python run_summarization.py \
--documents_dir $DATA_PATH \
--summaries_output_dir $SUMMARIES_PATH \
--visible_gpus 0,1,2 \
--to_cpu false \
--batch_size 4 \
--min_length 50 \
--max_length 200 \
@@ -58,4 +58,4 @@ python run_summarization.py \
--block_trigram true
```
If you want to compute ROUGE on another dataset you will need to tweak the stories/summaries import in `utils_summarization.py`
You may want to play around with `min_length`, `max_length` and `alpha` to suit your use case. If you want to compute ROUGE on another dataset, you will need to tweak the stories/summaries import in `utils_summarization.py` and tell it where to fetch the reference summaries.
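For instance, to bias beam search toward longer summaries you could raise the minimum length together with the length penalty `alpha`; a hypothetical invocation (the values are illustrative, not recommendations):

```bash
python run_summarization.py \
    --documents_dir $DATA_PATH \
    --min_length 100 \
    --max_length 250 \
    --alpha 0.95
```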
@@ -12,10 +12,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Convert BertExtAbs's checkpoints
""" Convert BertExtAbs's checkpoints.
The file currently does not do much as we ended up copying the exact model
structure, but I leave it here in case we ever want to refactor the model.
The script looks like it is doing something trivial, but it is not: the "weights"
published by the authors are actually the entire model, pickled. We need to load
the model within the original codebase to be able to save only its `state_dict`.
"""
import argparse
......
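Concretely, the conversion described in the docstring amounts to unpickling the full model and re-saving only its parameters; a minimal sketch, assuming the original codebase is on the import path so the pickle can resolve its classes (file names are hypothetical):

```python
import torch

# The published checkpoint unpickles into the complete model object,
# which only works if the original model classes are importable.
model = torch.load("bertextabs_checkpoint.pt", map_location="cpu")

# Re-save only the parameters so the checkpoint no longer depends on
# the original code layout.
torch.save(model.state_dict(), "bertextabs_state_dict.pt")
```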
@@ -847,14 +847,12 @@ class Translator(object):
global_scores (:obj:`GlobalScorer`):
object to rescore final translations
copy_attn (bool): use copy attention during translation
cuda (bool): use cuda
beam_trace (bool): trace beam search for debugging
logger(logging.Logger): logger.
"""
def __init__(self, args, model, vocab, symbols, global_scorer=None, logger=None):
self.logger = logger
self.cuda = args.visible_gpus != "-1"
self.args = args
self.model = model
......
@@ -185,7 +185,7 @@ def save_summaries(summaries, path, original_document_name):
def build_data_iterator(args, tokenizer):
dataset = load_and_cache_examples(args, tokenizer)
sampler = SequentialSampler(dataset)
collate_fn = lambda data: collate(data, tokenizer, block_size=512)
collate_fn = lambda data: collate(data, tokenizer, block_size=512, device=args.device)
iterator = DataLoader(
dataset, sampler=sampler, batch_size=args.batch_size, collate_fn=collate_fn,
)
@@ -198,7 +198,7 @@ def load_and_cache_examples(args, tokenizer):
return dataset
def collate(data, tokenizer, block_size):
def collate(data, tokenizer, block_size, device):
""" Collate formats the data passed to the data loader.
In particular we tokenize the data batch after batch to avoid keeping them
@@ -224,9 +224,9 @@ def collate(data, tokenizer, block_size):
batch = Batch(
document_names=names,
batch_size=len(encoded_stories),
src=encoded_stories,
segs=encoder_token_type_ids,
mask_src=encoder_mask,
src=encoded_stories.to(device),
segs=encoder_token_type_ids.to(device),
mask_src=encoder_mask.to(device),
tgt_str=summaries,
)
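Passing `args.device` down to `collate` means each batch's tensors are moved as the batch is built, so the evaluation loop never handles device placement itself; a generic sketch of the pattern, independent of this codebase:

```python
import torch
from torch.utils.data import DataLoader, SequentialSampler, TensorDataset

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dataset = TensorDataset(torch.arange(8).view(4, 2))

# The collate function captures `device`, so batches come out of the
# loader already on the right device.
collate_fn = lambda batch: torch.stack([item[0] for item in batch]).to(device)

iterator = DataLoader(
    dataset, sampler=SequentialSampler(dataset), batch_size=2, collate_fn=collate_fn,
)
for batch in iterator:
    print(batch.device)
```

Note that moving tensors to CUDA inside `collate_fn` is only safe with the default single-process loading (`num_workers=0`), which is what the script uses.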
@@ -271,10 +271,10 @@ def main():
)
# EVALUATION options
parser.add_argument(
"--visible_gpus",
default=-1,
type=int,
help="Number of GPUs with which to do the training.",
"--to_cpu",
default=False,
type=bool,
help="Whether to force the execution on CPU.",
)
parser.add_argument(
"--batch_size", default=4, type=int, help="Batch size per GPU/CPU for training.",
@@ -311,8 +311,11 @@
help="Whether to block the existence of repeating trigrams in the text generated by beam search.",
)
args = parser.parse_args()
args.device = torch.device("cpu") if args.visible_gpus == -1 else torch.device("cuda")
# Select the device (distributed execution is not available)
args.device = torch.device("cuda" if torch.cuda.is_available() and not args.to_cpu else "cpu")
# Check the existence of directories
if not args.summaries_output_dir:
args.summaries_output_dir = args.documents_dir
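With the new flag, forcing CPU execution from the command line looks like this (a hypothetical invocation; the remaining flags keep their defaults):

```bash
python run_summarization.py \
    --documents_dir $DATA_PATH \
    --to_cpu true
```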
......