"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "66582492d35edd2cd929dad8d668c982fa617211"
Unverified Commit 27a7fe7a authored by Sam Shleifer, committed by GitHub

examples/seq2seq: never override $WANDB_PROJECT (#5407)

parent 32d20314
@@ -71,7 +71,7 @@ Load it with `BartForConditionalGeneration.from_pretrained(f'{output_dir}/best_t
 - If you want to run experiments on improving the summarization finetuning process, try the XSUM Shared Task (below). It's faster to train than CNNDM because the summaries are shorter.
 - For CNN/DailyMail, the default `val_max_target_length` and `test_max_target_length` will truncate the ground truth labels, resulting in slightly higher rouge scores. To get accurate rouge scores, you should rerun calculate_rouge on the `{output_dir}/test_generations.txt` file saved by `trainer.test()`
 - `--max_target_length=60 --val_max_target_length=60 --test_max_target_length=100` is a reasonable setting for XSUM.
-- `wandb` can be used by specifying `--logger wandb_shared` or `--logger wandb`. It is useful for reproducibility.
+- `wandb` can be used by specifying `--logger wandb`. It is useful for reproducibility. Specify the environment variable `WANDB_PROJECT='hf_xsum'` to do the XSUM shared task.
 - This warning can be safely ignored:
 > "Some weights of BartForConditionalGeneration were not initialized from the model checkpoint at facebook/bart-large-xsum and are newly initialized: ['final_logits_bias']"
 - Both finetuning and eval are 30% faster with `--fp16`. For that you need to [install apex](https://github.com/NVIDIA/apex#quick-start).
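Note on the rouge bullet in the hunk above: because the default `val_max_target_length`/`test_max_target_length` truncate the labels, accurate scores require rescoring the saved generations against the full references. A minimal sketch of that rescoring step, using the standalone `rouge_score` package rather than this repo's `calculate_rouge` helper; the file paths are assumptions based on the README (`{output_dir}/test_generations.txt` plus a reference file such as `{data_dir}/test.target`):

```python
# Hedged sketch, not the repo's calculate_rouge: rescore saved generations
# against full-length references with the `rouge_score` package.
from rouge_score import rouge_scorer


def rescore(gen_path: str, ref_path: str) -> dict:
    scorer = rouge_scorer.RougeScorer(["rouge1", "rouge2", "rougeL"], use_stemmer=True)
    with open(gen_path) as g, open(ref_path) as r:
        pairs = list(zip(g.readlines(), r.readlines()))
    # Average the F-measure of each rouge variant over the test set.
    totals = {k: 0.0 for k in ("rouge1", "rouge2", "rougeL")}
    for gen, ref in pairs:
        scores = scorer.score(ref.strip(), gen.strip())  # score(target, prediction)
        for k in totals:
            totals[k] += scores[k].fmeasure
    return {k: total / len(pairs) for k, total in totals.items()}


# Hypothetical paths: output_dir=xsum_frozen_embs, data_dir points at the XSUM data.
print(rescore("xsum_frozen_embs/test_generations.txt", "xsum/test.target"))
```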
@@ -111,14 +111,14 @@ Compare XSUM results with others by using `--logger wandb_shared`. This requires
 Here is an example command, but you can do whatever you want. Hopefully this will make debugging and collaboration easier!
 ```bash
-./finetune.sh \
+WANDB_PROJECT='hf_xsum' ./finetune.sh \
 --data_dir $XSUM_DIR \
 --output_dir xsum_frozen_embs \
 --model_name_or_path facebook/bart-large \
---logger wandb_shared \
 --train_batch_size 16 --eval_batch_size 16 --freeze_embeds --freeze_encoder \
 --num_train_epochs 6 \
---max_target_length=60 --val_max_target_length=60 --test_max_target_length=100
+--max_target_length=60 --val_max_target_length=60 --test_max_target_length=100 \
+--logger wandb
 ```
 You can see your wandb logs [here](https://app.wandb.ai/sshleifer/hf_xsum?workspace=user-)
@@ -298,8 +298,6 @@ def main(args, model=None) -> SummarizationModule:
         model: SummarizationModule = SummarizationModule(args)
     else:
         model: SummarizationModule = TranslationModule(args)
-    dataset = Path(args.data_dir).name
     if (
         args.logger == "default"
         or args.fast_dev_run
@@ -310,12 +308,12 @@ def main(args, model=None) -> SummarizationModule:
     elif args.logger == "wandb":
         from pytorch_lightning.loggers import WandbLogger
-        logger = WandbLogger(name=model.output_dir.name, project=dataset)
+        logger = WandbLogger(name=model.output_dir.name)
     elif args.logger == "wandb_shared":
         from pytorch_lightning.loggers import WandbLogger
-        logger = WandbLogger(name=model.output_dir.name, project=f"hf_{dataset}")
+        logger = WandbLogger(name=model.output_dir.name)
     trainer: pl.Trainer = generic_train(
         model,
         args,
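For context on the two finetune.py hunks above: once `project=` is no longer passed, `WandbLogger` leaves project selection to `wandb` itself, which reads `WANDB_PROJECT` from the environment when it is set; hence the commit title. A minimal sketch of that fallback (the `hf_xsum` project and `xsum_frozen_embs` run name are just the examples used elsewhere in this diff):

```python
# Sketch of the post-change behavior, not code from this repo.
import os

from pytorch_lightning.loggers import WandbLogger

# Typically set by the user, e.g. `WANDB_PROJECT='hf_xsum' ./finetune.sh ...`;
# if it is unset, wandb falls back to its own default project.
os.environ.setdefault("WANDB_PROJECT", "hf_xsum")

# No `project=` kwarg: wandb.init() resolves the project from the environment
# instead of being overridden with a dataset-derived name.
logger = WandbLogger(name="xsum_frozen_embs")
```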