Commit 15550ce0 authored by Julien Chaumond

[skip ci] remove local rank

parent 62427d08
@@ -115,7 +115,7 @@ class DataTrainingArguments:
     )


-def get_dataset(args: DataTrainingArguments, tokenizer: PreTrainedTokenizer, evaluate=False, local_rank=-1):
+def get_dataset(args: DataTrainingArguments, tokenizer: PreTrainedTokenizer, evaluate=False):
     file_path = args.eval_data_file if evaluate else args.train_data_file
     if args.line_by_line:
         return LineByLineTextDataset(tokenizer=tokenizer, file_path=file_path, block_size=args.block_size)
@@ -216,16 +216,8 @@ def main():
         data_args.block_size = min(data_args.block_size, tokenizer.max_len)

     # Get datasets
-    train_dataset = (
-        get_dataset(data_args, tokenizer=tokenizer, local_rank=training_args.local_rank)
-        if training_args.do_train
-        else None
-    )
-    eval_dataset = (
-        get_dataset(data_args, tokenizer=tokenizer, local_rank=training_args.local_rank, evaluate=True)
-        if training_args.do_eval
-        else None
-    )
+    train_dataset = get_dataset(data_args, tokenizer=tokenizer) if training_args.do_train else None
+    eval_dataset = get_dataset(data_args, tokenizer=tokenizer, evaluate=True) if training_args.do_eval else None
     data_collator = DataCollatorForLanguageModeling(
         tokenizer=tokenizer, mlm=data_args.mlm, mlm_probability=data_args.mlm_probability
     )
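A minimal sketch of the simplified call pattern after this commit (file paths and the tokenizer name are placeholders, not part of the commit; get_dataset and DataTrainingArguments are the ones diffed above). With local_rank gone, every process builds the dataset identically, and any distributed sampling is left to the Trainer:

    from transformers import AutoTokenizer

    # Placeholder paths -- substitute your own training/eval text files.
    data_args = DataTrainingArguments(train_data_file="train.txt", eval_data_file="eval.txt")
    tokenizer = AutoTokenizer.from_pretrained("roberta-base")

    # No local_rank argument anymore; evaluate toggles which file is read.
    train_dataset = get_dataset(data_args, tokenizer=tokenizer)
    eval_dataset = get_dataset(data_args, tokenizer=tokenizer, evaluate=True)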
@@ -68,6 +68,6 @@ class RobertaConfig(BertConfig):
     model_type = "roberta"

     def __init__(self, pad_token_id=1, bos_token_id=0, eos_token_id=2, **kwargs):
-        """Constructs FlaubertConfig.
+        """Constructs RobertaConfig.
         """
         super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
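Besides the docstring fix, this hunk documents RobertaConfig's defaults. A small hedged check (values taken straight from the constructor signature and model_type above):

    from transformers import RobertaConfig

    config = RobertaConfig()
    assert config.model_type == "roberta"
    # Default special-token ids from the __init__ signature in the hunk above.
    assert (config.bos_token_id, config.pad_token_id, config.eos_token_id) == (0, 1, 2)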