Commit 0e1045f0 authored by lvzhen

Revert "Merge branch 'master' into 'master'"

This reverts merge request !2
parent 467ec853
{
"cells": [
{
"cell_type": "markdown",
"source": [
"# 使用微调的方法让模型对新闻分类更加准确\n",
"\n",
"在本实操手册中,开发者将使用ChatGLM3-6B base模型,对新闻分类数据集进行微调,并使用微调后的模型进行推理。\n",
"本操作手册使用公开数据集,数据集中包含了新闻标题和新闻关键词,开发者需要根据这些信息,将新闻分类到15个类别中的一个。\n",
"为了体现模型高效的学习能力,以及让用户更快的学习本手册,我们只使用了数据集中的一小部分数据,实际上,数据集中包含了超过40万条新闻数据。\n",
"\n",
"## 硬件要求\n",
"本实践手册需要使用 FP16 精度的模型进行推理,因此,我们推荐使用至少 16GB 显存的 英伟达 GPU 来完成本实践手册。\n"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "markdown",
"source": [
"首先,我们将原始的数据集格式转换为用于微调的`jsonl`格式,以方便进行微调。"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 1,
"outputs": [],
"source": [
"import json\n",
"\n",
"# 路径可以根据实际情况修改\n",
"input_file_path = 'data/toutiao_cat_data_example.txt'\n",
"output_file_path = 'data/toutiao_cat_data_example.jsonl'\n",
"\n",
"# 提示词\n",
"prompt_prefix = \"\"\"\n",
"你是一个专业的新闻专家,请根据我提供的新闻信息,包括新闻标题,新闻关键词等信息,你要对每一行新闻类别进行分类并告诉我结果,不要返回其他信息和多于的文字,这些类别是:\n",
"news_story\n",
"news_culture\n",
"news_sports\n",
"news_finance\n",
"news_house\n",
"news_car\n",
"news_edu\n",
"news_tech\n",
"news_military\n",
"news_travel\n",
"news_world\n",
"stock\n",
"news_agriculture\n",
"news_game\n",
"请选择其中一个类别并返回,你只要返回类别的名称,不要返回其他信息。让我们开始吧:\n",
"\"\"\"\n",
"\n",
"# 分类代码和名称的映射\n",
"category_map = {\n",
" \"100\": \"news_story\",\n",
" \"101\": \"news_culture\",\n",
" \"102\": \"news_entertainment\",\n",
" \"103\": \"news_sports\",\n",
" \"104\": \"news_finance\",\n",
" \"106\": \"news_house\",\n",
" \"107\": \"news_car\",\n",
" \"108\": \"news_edu\",\n",
" \"109\": \"news_tech\",\n",
" \"110\": \"news_military\",\n",
" \"112\": \"news_travel\",\n",
" \"113\": \"news_world\",\n",
" \"114\": \"stock\",\n",
" \"115\": \"news_agriculture\",\n",
" \"116\": \"news_game\"\n",
"}\n",
"\n",
"def process_line(line):\n",
" # 分割每行数据\n",
" parts = line.strip().split('_!_')\n",
" if len(parts) != 5:\n",
" return None\n",
"\n",
" # 提取所需字段\n",
" _, category_code, _, news_title, news_keywords = parts\n",
"\n",
" # 构造 JSON 对象\n",
" news_title = news_title if news_title else \"无\"\n",
" news_keywords = news_keywords if news_keywords else \"无\"\n",
" json_obj = {\n",
" \"context\": prompt_prefix + f\"新闻标题: {news_title}\\n 新闻关键词: {news_keywords}\\n\",\n",
" \"target\": category_map.get(category_code, \"无\")\n",
" }\n",
" return json_obj\n",
"\n",
"def convert_to_jsonl(input_path, output_path):\n",
" with open(input_path, 'r', encoding='utf-8') as infile, \\\n",
" open(output_path, 'w', encoding='utf-8') as outfile:\n",
" for line in infile:\n",
" json_obj = process_line(line)\n",
" if json_obj:\n",
" json_line = json.dumps(json_obj, ensure_ascii=False)\n",
" outfile.write(json_line + '\\n')\n",
"\n",
"# 运行转换函数\n",
"convert_to_jsonl(input_file_path, output_file_path)"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"start_time": "2023-11-24T13:43:59.281779Z",
"end_time": "2023-11-24T13:43:59.330679Z"
}
}
},
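{
"cell_type": "markdown",
"source": [
"Optionally, we can take a quick look at the generated `jsonl` file to confirm that each line carries the expected `context` and `target` fields. This check is only an illustrative addition and assumes the conversion above completed successfully."
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"# Illustrative sanity check (not part of the original workflow):\n",
"# read the first converted sample and confirm its structure.\n",
"with open(output_file_path, 'r', encoding='utf-8') as f:\n",
"    first_sample = json.loads(f.readline())\n",
"\n",
"print(list(first_sample.keys()))  # expected: ['context', 'target']\n",
"print(first_sample['target'])     # the category label of the first sample"
],
"metadata": {
"collapsed": false
}
},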
{
"cell_type": "markdown",
"source": [
"## 使用没有微调的模型进行推理\n",
"首先,我们先试用原本的模基座模型进行推理,并查看效果。"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 2,
"outputs": [],
"source": [
"PROMPT = \"\"\"\n",
"你是一个专业的新闻专家,请根据我提供的新闻信息,包括新闻标题,新闻关键词等信息,你要对每一行新闻类别进行分类并告诉我结果,不要返回其他信息和多于的文字,这些类别是:\n",
"news_story\n",
"news_culture\n",
"news_sports\n",
"news_finance\n",
"news_house\n",
"news_car\n",
"news_edu\n",
"news_tech\n",
"news_military\n",
"news_travel\n",
"news_world\n",
"stock\n",
"news_agriculture\n",
"news_game\n",
"请选择其中一个类别并返回,你只要返回类别的名称,不要返回其他信息。让我们开始吧:\n",
"新闻标题:华为手机扛下敌人子弹,是什么技术让其在战争中大放异彩?\n",
"新闻关键词: 华为手机\n",
"\"\"\""
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"start_time": "2023-11-24T13:43:59.377486Z",
"end_time": "2023-11-24T13:43:59.392276Z"
}
}
},
{
"cell_type": "code",
"execution_count": 3,
"outputs": [
{
"data": {
"text/plain": "Loading checkpoint shards: 0%| | 0/7 [00:00<?, ?it/s]",
"application/vnd.jupyter.widget-view+json": {
"version_major": 2,
"version_minor": 0,
"model_id": "ed12c09cc7f443bab2f99b3ca7e99716"
}
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from transformers import AutoModel, AutoTokenizer\n",
"import torch\n",
"# 参数设置\n",
"model_path = \"/Models/chatglm3-6b-base\"\n",
"tokenizer_path = model_path\n",
"device = \"cuda\"\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, trust_remote_code=True)\n",
"model = AutoModel.from_pretrained(model_path, load_in_8bit=False, trust_remote_code=True).to(device)"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"start_time": "2023-11-24T13:43:59.377486Z",
"end_time": "2023-11-24T13:44:11.180343Z"
}
}
},
{
"cell_type": "code",
"execution_count": 4,
"outputs": [
{
"data": {
"text/plain": "'news_house'"
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"max_new_tokens = 1024\n",
"temperature = 0.4\n",
"top_p = 0.9\n",
"inputs = tokenizer(PROMPT, return_tensors=\"pt\").to(device)\n",
"response = model.generate(input_ids=inputs[\"input_ids\"],max_new_tokens=max_new_tokens,temperature=temperature,top_p=top_p,do_sample=True)\n",
"response = response[0, inputs[\"input_ids\"].shape[-1]:]\n",
"origin_answer = tokenizer.decode(response, skip_special_tokens=True)\n",
"origin_answer"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"start_time": "2023-11-24T13:44:11.183956Z",
"end_time": "2023-11-24T13:44:12.313278Z"
}
}
},
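{
"cell_type": "markdown",
"source": [
"As an optional aside (not part of the original workflow): the cell above samples with `do_sample=True` and `temperature=0.4`, so the returned label can vary between runs. For a single-label classification task, greedy decoding is a common way to make the output reproducible. The sketch below assumes `model`, `tokenizer`, and `inputs` from the previous cells are still in memory."
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"# Illustrative sketch: greedy decoding for a deterministic label\n",
"greedy_ids = model.generate(input_ids=inputs[\"input_ids\"], max_new_tokens=16, do_sample=False)\n",
"greedy_answer = tokenizer.decode(greedy_ids[0, inputs[\"input_ids\"].shape[-1]:], skip_special_tokens=True)\n",
"greedy_answer"
],
"metadata": {
"collapsed": false
}
},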
{
"cell_type": "code",
"execution_count": 5,
"outputs": [],
"source": [
"del model\n",
"torch.cuda.empty_cache()"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"start_time": "2023-11-24T13:44:12.313278Z",
"end_time": "2023-11-24T13:44:12.463365Z"
}
}
},
{
"cell_type": "markdown",
"source": [
"我们可以发现,模型并没有正确的对新闻进行分类。这可能是由模型训练阶段的数据导致的问题。那么,我们通过微调这个模型,能不能实现更好的效果呢?"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "markdown",
"source": [
"## 使用官方的微调脚本进行微调\n",
"\n",
"在完成数据的切割之后,我们需要按照官方提供好的方式进行微调。我们使用的模型为`chatglm3-6b-base`基座模型,该模型相对于Chat模型,更容易上手微调,且更符合本章节的应用场景。\n",
"\n",
"我们将对应的参数设置好后,就可以直接执行下面的代码进行微调。该代码使用`Lora`方案进行微调,成本相较于全参微调大幅度降低。\n"
],
"metadata": {
"collapsed": false
}
},
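{
"cell_type": "markdown",
"source": [
"Before running the actual training cell below, here is an optional, illustrative sketch of why `LoRA` is so much cheaper than full-parameter fine-tuning: instead of updating all of the model's weights, it injects small low-rank adapter matrices and trains only those. The sketch uses the `peft` library with assumed hyperparameters (`r=8`, `lora_alpha=32`, adapters on `query_key_value`); with `r=8` the trainable parameter count works out to about 1.9M, consistent with the `Number of trainable parameters = 1,949,696` line in the training log. The official script configures LoRA internally, so this cell is for illustration only and does not need to be run."
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"# Illustrative sketch only (assumes the `peft` library is installed); the official\n",
"# fine-tuning script configures LoRA by itself with its own hyperparameters.\n",
"import torch\n",
"from transformers import AutoModel\n",
"from peft import LoraConfig, get_peft_model\n",
"\n",
"base = AutoModel.from_pretrained(\"/Models/chatglm3-6b-base\", trust_remote_code=True, torch_dtype=torch.float16)\n",
"lora_config = LoraConfig(\n",
"    r=8,\n",
"    lora_alpha=32,\n",
"    lora_dropout=0.1,\n",
"    target_modules=[\"query_key_value\"],  # the attention projection in ChatGLM blocks\n",
"    task_type=\"CAUSAL_LM\",\n",
")\n",
"peft_model = get_peft_model(base, lora_config)\n",
"# Prints the trainable parameter count, a tiny fraction of the ~6B total\n",
"peft_model.print_trainable_parameters()"
],
"metadata": {
"collapsed": false
}
},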
{
"cell_type": "code",
"execution_count": 6,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[2023-11-24 13:44:13,160] torch.distributed.run: [WARNING] master_addr is only used for static rdzv_backend and when rdzv_endpoint is not specified.\r\n",
"11/24/2023 13:44:15 - WARNING - __main__ - Process rank: 0, device: cuda:0, n_gpu: 1distributed training: True, 16-bits training: False\r\n",
"11/24/2023 13:44:15 - INFO - __main__ - Training/evaluation parameters Seq2SeqTrainingArguments(\r\n",
"_n_gpu=1,\r\n",
"adafactor=False,\r\n",
"adam_beta1=0.9,\r\n",
"adam_beta2=0.999,\r\n",
"adam_epsilon=1e-08,\r\n",
"auto_find_batch_size=False,\r\n",
"bf16=False,\r\n",
"bf16_full_eval=False,\r\n",
"data_seed=None,\r\n",
"dataloader_drop_last=False,\r\n",
"dataloader_num_workers=0,\r\n",
"dataloader_pin_memory=True,\r\n",
"ddp_backend=None,\r\n",
"ddp_broadcast_buffers=None,\r\n",
"ddp_bucket_cap_mb=None,\r\n",
"ddp_find_unused_parameters=None,\r\n",
"ddp_timeout=1800,\r\n",
"debug=[],\r\n",
"deepspeed=None,\r\n",
"disable_tqdm=False,\r\n",
"dispatch_batches=None,\r\n",
"do_eval=False,\r\n",
"do_predict=False,\r\n",
"do_train=False,\r\n",
"eval_accumulation_steps=None,\r\n",
"eval_delay=0,\r\n",
"eval_steps=None,\r\n",
"evaluation_strategy=no,\r\n",
"fp16=False,\r\n",
"fp16_backend=auto,\r\n",
"fp16_full_eval=False,\r\n",
"fp16_opt_level=O1,\r\n",
"fsdp=[],\r\n",
"fsdp_config={'min_num_params': 0, 'xla': False, 'xla_fsdp_grad_ckpt': False},\r\n",
"fsdp_min_num_params=0,\r\n",
"fsdp_transformer_layer_cls_to_wrap=None,\r\n",
"full_determinism=False,\r\n",
"generation_config=None,\r\n",
"generation_max_length=None,\r\n",
"generation_num_beams=None,\r\n",
"gradient_accumulation_steps=2,\r\n",
"gradient_checkpointing=False,\r\n",
"gradient_checkpointing_kwargs=None,\r\n",
"greater_is_better=None,\r\n",
"group_by_length=False,\r\n",
"half_precision_backend=auto,\r\n",
"hub_always_push=False,\r\n",
"hub_model_id=None,\r\n",
"hub_private_repo=False,\r\n",
"hub_strategy=every_save,\r\n",
"hub_token=<HUB_TOKEN>,\r\n",
"ignore_data_skip=False,\r\n",
"include_inputs_for_metrics=False,\r\n",
"include_tokens_per_second=False,\r\n",
"jit_mode_eval=False,\r\n",
"label_names=None,\r\n",
"label_smoothing_factor=0.0,\r\n",
"learning_rate=2e-05,\r\n",
"length_column_name=length,\r\n",
"load_best_model_at_end=False,\r\n",
"local_rank=0,\r\n",
"log_level=passive,\r\n",
"log_level_replica=warning,\r\n",
"log_on_each_node=True,\r\n",
"logging_dir=output/news-20231124-134412-2e-05/runs/Nov24_13-44-15_iZwz90pbe3r8jeoaaobbf0Z,\r\n",
"logging_first_step=False,\r\n",
"logging_nan_inf_filter=True,\r\n",
"logging_steps=1.0,\r\n",
"logging_strategy=steps,\r\n",
"lr_scheduler_type=linear,\r\n",
"max_grad_norm=1.0,\r\n",
"max_steps=300,\r\n",
"metric_for_best_model=None,\r\n",
"mp_parameters=,\r\n",
"neftune_noise_alpha=None,\r\n",
"no_cuda=False,\r\n",
"num_train_epochs=3.0,\r\n",
"optim=adamw_torch,\r\n",
"optim_args=None,\r\n",
"output_dir=output/news-20231124-134412-2e-05,\r\n",
"overwrite_output_dir=False,\r\n",
"past_index=-1,\r\n",
"per_device_eval_batch_size=8,\r\n",
"per_device_train_batch_size=1,\r\n",
"predict_with_generate=False,\r\n",
"prediction_loss_only=False,\r\n",
"push_to_hub=False,\r\n",
"push_to_hub_model_id=None,\r\n",
"push_to_hub_organization=None,\r\n",
"push_to_hub_token=<PUSH_TO_HUB_TOKEN>,\r\n",
"ray_scope=last,\r\n",
"remove_unused_columns=True,\r\n",
"report_to=[],\r\n",
"resume_from_checkpoint=None,\r\n",
"run_name=output/news-20231124-134412-2e-05,\r\n",
"save_on_each_node=False,\r\n",
"save_safetensors=True,\r\n",
"save_steps=100,\r\n",
"save_strategy=steps,\r\n",
"save_total_limit=None,\r\n",
"seed=42,\r\n",
"skip_memory_metrics=True,\r\n",
"sortish_sampler=False,\r\n",
"split_batches=False,\r\n",
"tf32=None,\r\n",
"torch_compile=False,\r\n",
"torch_compile_backend=None,\r\n",
"torch_compile_mode=None,\r\n",
"torchdynamo=None,\r\n",
"tpu_metrics_debug=False,\r\n",
"tpu_num_cores=None,\r\n",
"use_cpu=False,\r\n",
"use_ipex=False,\r\n",
"use_legacy_prediction_loop=False,\r\n",
"use_mps_device=False,\r\n",
"warmup_ratio=0.0,\r\n",
"warmup_steps=0,\r\n",
"weight_decay=0.0,\r\n",
")\r\n",
"[INFO|tokenization_utils_base.py:2020] 2023-11-24 13:44:15,353 >> loading file tokenizer.model\r\n",
"[INFO|tokenization_utils_base.py:2020] 2023-11-24 13:44:15,353 >> loading file added_tokens.json\r\n",
"[INFO|tokenization_utils_base.py:2020] 2023-11-24 13:44:15,353 >> loading file special_tokens_map.json\r\n",
"[INFO|tokenization_utils_base.py:2020] 2023-11-24 13:44:15,353 >> loading file tokenizer_config.json\r\n",
"[INFO|tokenization_utils_base.py:2020] 2023-11-24 13:44:15,353 >> loading file tokenizer.json\r\n",
"[INFO|configuration_utils.py:715] 2023-11-24 13:44:15,453 >> loading configuration file /Models/chatglm3-6b-base/config.json\r\n",
"[INFO|configuration_utils.py:715] 2023-11-24 13:44:15,454 >> loading configuration file /Models/chatglm3-6b-base/config.json\r\n",
"[INFO|configuration_utils.py:777] 2023-11-24 13:44:15,454 >> Model config ChatGLMConfig {\r\n",
" \"_name_or_path\": \"/Models/chatglm3-6b-base\",\r\n",
" \"add_bias_linear\": false,\r\n",
" \"add_qkv_bias\": true,\r\n",
" \"apply_query_key_layer_scaling\": true,\r\n",
" \"apply_residual_connection_post_layernorm\": false,\r\n",
" \"architectures\": [\r\n",
" \"ChatGLMModel\"\r\n",
" ],\r\n",
" \"attention_dropout\": 0.0,\r\n",
" \"attention_softmax_in_fp32\": true,\r\n",
" \"auto_map\": {\r\n",
" \"AutoConfig\": \"configuration_chatglm.ChatGLMConfig\",\r\n",
" \"AutoModel\": \"modeling_chatglm.ChatGLMForConditionalGeneration\",\r\n",
" \"AutoModelForCausalLM\": \"modeling_chatglm.ChatGLMForConditionalGeneration\",\r\n",
" \"AutoModelForSeq2SeqLM\": \"modeling_chatglm.ChatGLMForConditionalGeneration\",\r\n",
" \"AutoModelForSequenceClassification\": \"modeling_chatglm.ChatGLMForSequenceClassification\"\r\n",
" },\r\n",
" \"bias_dropout_fusion\": true,\r\n",
" \"classifier_dropout\": null,\r\n",
" \"eos_token_id\": 2,\r\n",
" \"ffn_hidden_size\": 13696,\r\n",
" \"fp32_residual_connection\": false,\r\n",
" \"hidden_dropout\": 0.0,\r\n",
" \"hidden_size\": 4096,\r\n",
" \"kv_channels\": 128,\r\n",
" \"layernorm_epsilon\": 1e-05,\r\n",
" \"model_type\": \"chatglm\",\r\n",
" \"multi_query_attention\": true,\r\n",
" \"multi_query_group_num\": 2,\r\n",
" \"num_attention_heads\": 32,\r\n",
" \"num_layers\": 28,\r\n",
" \"original_rope\": true,\r\n",
" \"pad_token_id\": 0,\r\n",
" \"padded_vocab_size\": 65024,\r\n",
" \"post_layer_norm\": true,\r\n",
" \"pre_seq_len\": null,\r\n",
" \"prefix_projection\": false,\r\n",
" \"quantization_bit\": 0,\r\n",
" \"rmsnorm\": true,\r\n",
" \"seq_length\": 32768,\r\n",
" \"tie_word_embeddings\": false,\r\n",
" \"torch_dtype\": \"float16\",\r\n",
" \"transformers_version\": \"4.35.2\",\r\n",
" \"use_cache\": true,\r\n",
" \"vocab_size\": 65024\r\n",
"}\r\n",
"\r\n",
"[INFO|modeling_utils.py:3118] 2023-11-24 13:44:15,517 >> loading weights file /Models/chatglm3-6b-base/pytorch_model.bin.index.json\r\n",
"[INFO|configuration_utils.py:791] 2023-11-24 13:44:15,517 >> Generate config GenerationConfig {\r\n",
" \"eos_token_id\": 2,\r\n",
" \"pad_token_id\": 0\r\n",
"}\r\n",
"\r\n",
"Loading checkpoint shards: 100%|██████████████████| 7/7 [00:07<00:00, 1.02s/it]\r\n",
"[INFO|modeling_utils.py:3950] 2023-11-24 13:44:22,704 >> All model checkpoint weights were used when initializing ChatGLMForConditionalGeneration.\r\n",
"\r\n",
"[INFO|modeling_utils.py:3958] 2023-11-24 13:44:22,704 >> All the weights of ChatGLMForConditionalGeneration were initialized from the model checkpoint at /Models/chatglm3-6b-base.\r\n",
"If your task is similar to the task the model of the checkpoint was trained on, you can already use ChatGLMForConditionalGeneration for predictions without further training.\r\n",
"[INFO|modeling_utils.py:3525] 2023-11-24 13:44:22,706 >> Generation config file not found, using a generation config created from the model config.\r\n",
"Train dataset size: 4999\r\n",
"Sanity Check >>>>>>>>>>>>>\r\n",
" '[gMASK]': 64790 -> -100\r\n",
" 'sop': 64792 -> -100\r\n",
" '': 30910 -> -100\r\n",
" '\\n': 13 -> -100\r\n",
" '你': 54622 -> -100\r\n",
" '是一个': 32103 -> -100\r\n",
" '专业的': 34917 -> -100\r\n",
" '新闻': 31935 -> -100\r\n",
" '专家': 32114 -> -100\r\n",
" ',': 31123 -> -100\r\n",
" '请': 55073 -> -100\r\n",
" '根据': 31793 -> -100\r\n",
" '我': 54546 -> -100\r\n",
" '提供的': 35100 -> -100\r\n",
" '新闻': 31935 -> -100\r\n",
" '信息': 31707 -> -100\r\n",
" ',': 31123 -> -100\r\n",
" '包括': 31779 -> -100\r\n",
" '新闻': 31935 -> -100\r\n",
" '标题': 34490 -> -100\r\n",
" ',': 31123 -> -100\r\n",
" '新闻': 31935 -> -100\r\n",
" '关键词': 35075 -> -100\r\n",
" '等信息': 46172 -> -100\r\n",
" ',': 31123 -> -100\r\n",
" '你要': 34526 -> -100\r\n",
" '对': 54570 -> -100\r\n",
" '每一': 32467 -> -100\r\n",
" '行': 54560 -> -100\r\n",
" '新闻': 31935 -> -100\r\n",
" '类别': 38724 -> -100\r\n",
" '进行': 31636 -> -100\r\n",
" '分类': 33328 -> -100\r\n",
" '并': 54724 -> -100\r\n",
" '告诉我': 38953 -> -100\r\n",
" '结果': 31951 -> -100\r\n",
" ',': 31123 -> -100\r\n",
" '不要': 31844 -> -100\r\n",
" '返回': 34891 -> -100\r\n",
" '其他': 31722 -> -100\r\n",
" '信息和': 52701 -> -100\r\n",
" '多': 54573 -> -100\r\n",
" '于': 54579 -> -100\r\n",
" '的文字': 48746 -> -100\r\n",
" ',': 31123 -> -100\r\n",
" '这些': 31704 -> -100\r\n",
" '类别': 38724 -> -100\r\n",
" '是': 54532 -> -100\r\n",
" ':': 30954 -> -100\r\n",
" '\\n': 13 -> -100\r\n",
" 'news': 8480 -> -100\r\n",
" '_': 30962 -> -100\r\n",
" 'story': 12553 -> -100\r\n",
" '\\n': 13 -> -100\r\n",
" 'news': 8480 -> -100\r\n",
" '_': 30962 -> -100\r\n",
" 'culture': 27458 -> -100\r\n",
" '\\n': 13 -> -100\r\n",
" 'news': 8480 -> -100\r\n",
" '_': 30962 -> -100\r\n",
" 's': 30917 -> -100\r\n",
" 'ports': 3915 -> -100\r\n",
" '\\n': 13 -> -100\r\n",
" 'news': 8480 -> -100\r\n",
" '_': 30962 -> -100\r\n",
" 'fin': 6242 -> -100\r\n",
" 'ance': 562 -> -100\r\n",
" '\\n': 13 -> -100\r\n",
" 'news': 8480 -> -100\r\n",
" '_': 30962 -> -100\r\n",
" 'house': 4199 -> -100\r\n",
" '\\n': 13 -> -100\r\n",
" 'news': 8480 -> -100\r\n",
" '_': 30962 -> -100\r\n",
" 'car': 6747 -> -100\r\n",
" '\\n': 13 -> -100\r\n",
" 'news': 8480 -> -100\r\n",
" '_': 30962 -> -100\r\n",
" 'edu': 7473 -> -100\r\n",
" '\\n': 13 -> -100\r\n",
" 'news': 8480 -> -100\r\n",
" '_': 30962 -> -100\r\n",
" 'tech': 12232 -> -100\r\n",
" '\\n': 13 -> -100\r\n",
" 'news': 8480 -> -100\r\n",
" '_': 30962 -> -100\r\n",
" 'mil': 20477 -> -100\r\n",
" 'itary': 2733 -> -100\r\n",
" '\\n': 13 -> -100\r\n",
" 'news': 8480 -> -100\r\n",
" '_': 30962 -> -100\r\n",
" 'tra': 7441 -> -100\r\n",
" 'vel': 609 -> -100\r\n",
" '\\n': 13 -> -100\r\n",
" 'news': 8480 -> -100\r\n",
" '_': 30962 -> -100\r\n",
" 'world': 8515 -> -100\r\n",
" '\\n': 13 -> -100\r\n",
" 'stock': 14148 -> -100\r\n",
" '\\n': 13 -> -100\r\n",
" 'news': 8480 -> -100\r\n",
" '_': 30962 -> -100\r\n",
" 'ag': 369 -> -100\r\n",
" 'ric': 995 -> -100\r\n",
" 'ulture': 4768 -> -100\r\n",
" '\\n': 13 -> -100\r\n",
" 'news': 8480 -> -100\r\n",
" '_': 30962 -> -100\r\n",
" 'game': 8947 -> -100\r\n",
" '\\n': 13 -> -100\r\n",
" '请': 55073 -> -100\r\n",
" '选择': 31768 -> -100\r\n",
" '其中一个': 46753 -> -100\r\n",
" '类别': 38724 -> -100\r\n",
" '并': 54724 -> -100\r\n",
" '返回': 34891 -> -100\r\n",
" ',': 31123 -> -100\r\n",
" '你': 54622 -> -100\r\n",
" '只要': 32100 -> -100\r\n",
" '返回': 34891 -> -100\r\n",
" '类': 54931 -> -100\r\n",
" '别的': 34752 -> -100\r\n",
" '名称': 33624 -> -100\r\n",
" ',': 31123 -> -100\r\n",
" '不要': 31844 -> -100\r\n",
" '返回': 34891 -> -100\r\n",
" '其他': 31722 -> -100\r\n",
" '信息': 31707 -> -100\r\n",
" '。': 31155 -> -100\r\n",
" '让我们': 32817 -> -100\r\n",
" '开始': 31699 -> -100\r\n",
" '吧': 55370 -> -100\r\n",
" ':': 30954 -> -100\r\n",
" '\\n': 13 -> -100\r\n",
" '新闻': 31935 -> -100\r\n",
" '标题': 34490 -> -100\r\n",
" ':': 30954 -> -100\r\n",
" '': 30910 -> -100\r\n",
" '京城': 46921 -> -100\r\n",
" '最': 54628 -> -100\r\n",
" '值得': 32421 -> -100\r\n",
" '你来': 52586 -> -100\r\n",
" '场': 54686 -> -100\r\n",
" '文化': 31653 -> -100\r\n",
" '之旅': 35383 -> -100\r\n",
" '的': 54530 -> -100\r\n",
" '博物馆': 32964 -> -100\r\n",
" '\\n': 13 -> -100\r\n",
" '新闻': 45302 -> -100\r\n",
" '关键词': 35075 -> -100\r\n",
" ':': 30954 -> -100\r\n",
" '': 30910 -> -100\r\n",
" '保利': 46340 -> -100\r\n",
" '集团': 31839 -> -100\r\n",
" ',': 30932 -> -100\r\n",
" '马': 54988 -> -100\r\n",
" '未': 54933 -> -100\r\n",
" '都': 54606 -> -100\r\n",
" ',': 30932 -> -100\r\n",
" '中国': 31626 -> -100\r\n",
" '科学技术': 35587 -> -100\r\n",
" '馆': 55294 -> -100\r\n",
" ',': 30932 -> -100\r\n",
" '博物馆': 32964 -> -100\r\n",
" ',': 30932 -> -100\r\n",
" '新中国': 35873 -> -100\r\n",
" '\\n': 13 -> -100\r\n",
" 'news': 2374 -> 2374\r\n",
" '_': 30962 -> 30962\r\n",
" 'culture': 27458 -> 27458\r\n",
" '': 2 -> 2\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
" '': 0 -> -100\r\n",
"<<<<<<<<<<<<< Sanity Check\r\n",
"[INFO|trainer.py:544] 2023-11-24 13:44:24,838 >> max_steps is given, it will override any value given in num_train_epochs\r\n",
"[INFO|trainer.py:1723] 2023-11-24 13:44:25,787 >> ***** Running training *****\r\n",
"[INFO|trainer.py:1724] 2023-11-24 13:44:25,787 >> Num examples = 4,999\r\n",
"[INFO|trainer.py:1725] 2023-11-24 13:44:25,787 >> Num Epochs = 1\r\n",
"[INFO|trainer.py:1726] 2023-11-24 13:44:25,787 >> Instantaneous batch size per device = 1\r\n",
"[INFO|trainer.py:1729] 2023-11-24 13:44:25,787 >> Total train batch size (w. parallel, distributed & accumulation) = 2\r\n",
"[INFO|trainer.py:1730] 2023-11-24 13:44:25,787 >> Gradient Accumulation steps = 2\r\n",
"[INFO|trainer.py:1731] 2023-11-24 13:44:25,788 >> Total optimization steps = 300\r\n",
"[INFO|trainer.py:1732] 2023-11-24 13:44:25,788 >> Number of trainable parameters = 1,949,696\r\n",
" 0%| | 0/300 [00:00<?, ?it/s][W reducer.cpp:1346] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator())\r\n",
"{'loss': 0.6584, 'learning_rate': 1.9933333333333334e-05, 'epoch': 0.0} \r\n",
"{'loss': 1.1299, 'learning_rate': 1.9866666666666667e-05, 'epoch': 0.0} \r\n",
"{'loss': 1.2722, 'learning_rate': 1.98e-05, 'epoch': 0.0} \r\n",
"{'loss': 0.5076, 'learning_rate': 1.9733333333333336e-05, 'epoch': 0.0} \r\n",
"{'loss': 1.5129, 'learning_rate': 1.9666666666666666e-05, 'epoch': 0.0} \r\n",
"{'loss': 1.7524, 'learning_rate': 1.9600000000000002e-05, 'epoch': 0.0} \r\n",
"{'loss': 0.7554, 'learning_rate': 1.9533333333333335e-05, 'epoch': 0.0} \r\n",
"{'loss': 0.834, 'learning_rate': 1.9466666666666668e-05, 'epoch': 0.0} \r\n",
"{'loss': 2.0967, 'learning_rate': 1.94e-05, 'epoch': 0.0} \r\n",
"{'loss': 0.2806, 'learning_rate': 1.9333333333333333e-05, 'epoch': 0.0} \r\n",
"{'loss': 1.0696, 'learning_rate': 1.926666666666667e-05, 'epoch': 0.0} \r\n",
"{'loss': 0.5934, 'learning_rate': 1.9200000000000003e-05, 'epoch': 0.0} \r\n",
"{'loss': 0.8784, 'learning_rate': 1.9133333333333335e-05, 'epoch': 0.01} \r\n",
"{'loss': 1.9795, 'learning_rate': 1.9066666666666668e-05, 'epoch': 0.01} \r\n",
"{'loss': 0.953, 'learning_rate': 1.9e-05, 'epoch': 0.01} \r\n",
"{'loss': 0.1516, 'learning_rate': 1.8933333333333334e-05, 'epoch': 0.01} \r\n",
"{'loss': 0.7275, 'learning_rate': 1.886666666666667e-05, 'epoch': 0.01} \r\n",
"{'loss': 0.2672, 'learning_rate': 1.88e-05, 'epoch': 0.01} \r\n",
"{'loss': 0.152, 'learning_rate': 1.8733333333333336e-05, 'epoch': 0.01} \r\n",
"{'loss': 0.3321, 'learning_rate': 1.866666666666667e-05, 'epoch': 0.01} \r\n",
"{'loss': 0.5358, 'learning_rate': 1.86e-05, 'epoch': 0.01} \r\n",
"{'loss': 0.3262, 'learning_rate': 1.8533333333333334e-05, 'epoch': 0.01} \r\n",
"{'loss': 0.3477, 'learning_rate': 1.8466666666666667e-05, 'epoch': 0.01} \r\n",
"{'loss': 0.2667, 'learning_rate': 1.8400000000000003e-05, 'epoch': 0.01} \r\n",
"{'loss': 0.3765, 'learning_rate': 1.8333333333333333e-05, 'epoch': 0.01} \r\n",
"{'loss': 0.1945, 'learning_rate': 1.826666666666667e-05, 'epoch': 0.01} \r\n",
"{'loss': 0.2321, 'learning_rate': 1.8200000000000002e-05, 'epoch': 0.01} \r\n",
"{'loss': 0.094, 'learning_rate': 1.8133333333333335e-05, 'epoch': 0.01} \r\n",
"{'loss': 0.2454, 'learning_rate': 1.8066666666666668e-05, 'epoch': 0.01} \r\n",
"{'loss': 0.0523, 'learning_rate': 1.8e-05, 'epoch': 0.01} \r\n",
"{'loss': 0.9747, 'learning_rate': 1.7933333333333333e-05, 'epoch': 0.01} \r\n",
"{'loss': 0.7521, 'learning_rate': 1.7866666666666666e-05, 'epoch': 0.01} \r\n",
"{'loss': 0.7373, 'learning_rate': 1.7800000000000002e-05, 'epoch': 0.01} \r\n",
"{'loss': 0.2056, 'learning_rate': 1.7733333333333335e-05, 'epoch': 0.01} \r\n",
"{'loss': 0.0107, 'learning_rate': 1.7666666666666668e-05, 'epoch': 0.01} \r\n",
"{'loss': 0.0525, 'learning_rate': 1.76e-05, 'epoch': 0.01} \r\n",
"{'loss': 0.0377, 'learning_rate': 1.7533333333333337e-05, 'epoch': 0.01} \r\n",
"{'loss': 0.8701, 'learning_rate': 1.7466666666666667e-05, 'epoch': 0.02} \r\n",
"{'loss': 0.1365, 'learning_rate': 1.7400000000000003e-05, 'epoch': 0.02} \r\n",
"{'loss': 0.763, 'learning_rate': 1.7333333333333336e-05, 'epoch': 0.02} \r\n",
"{'loss': 0.6814, 'learning_rate': 1.726666666666667e-05, 'epoch': 0.02} \r\n",
"{'loss': 0.066, 'learning_rate': 1.72e-05, 'epoch': 0.02} \r\n",
"{'loss': 0.543, 'learning_rate': 1.7133333333333334e-05, 'epoch': 0.02} \r\n",
"{'loss': 0.1186, 'learning_rate': 1.706666666666667e-05, 'epoch': 0.02} \r\n",
"{'loss': 0.0753, 'learning_rate': 1.7e-05, 'epoch': 0.02} \r\n",
"{'loss': 0.1376, 'learning_rate': 1.6933333333333336e-05, 'epoch': 0.02} \r\n",
"{'loss': 0.0322, 'learning_rate': 1.686666666666667e-05, 'epoch': 0.02} \r\n",
"{'loss': 0.2514, 'learning_rate': 1.6800000000000002e-05, 'epoch': 0.02} \r\n",
"{'loss': 0.9785, 'learning_rate': 1.6733333333333335e-05, 'epoch': 0.02} \r\n",
"{'loss': 0.4355, 'learning_rate': 1.6666666666666667e-05, 'epoch': 0.02} \r\n",
"{'loss': 0.064, 'learning_rate': 1.66e-05, 'epoch': 0.02} \r\n",
"{'loss': 0.0915, 'learning_rate': 1.6533333333333333e-05, 'epoch': 0.02} \r\n",
"{'loss': 0.0999, 'learning_rate': 1.646666666666667e-05, 'epoch': 0.02} \r\n",
"{'loss': 0.0058, 'learning_rate': 1.64e-05, 'epoch': 0.02} \r\n",
"{'loss': 0.3339, 'learning_rate': 1.6333333333333335e-05, 'epoch': 0.02} \r\n",
"{'loss': 0.604, 'learning_rate': 1.6266666666666668e-05, 'epoch': 0.02} \r\n",
"{'loss': 0.3101, 'learning_rate': 1.62e-05, 'epoch': 0.02} \r\n",
"{'loss': 0.149, 'learning_rate': 1.6133333333333334e-05, 'epoch': 0.02} \r\n",
"{'loss': 0.1271, 'learning_rate': 1.606666666666667e-05, 'epoch': 0.02} \r\n",
"{'loss': 0.2854, 'learning_rate': 1.6000000000000003e-05, 'epoch': 0.02} \r\n",
"{'loss': 0.3322, 'learning_rate': 1.5933333333333336e-05, 'epoch': 0.02} \r\n",
"{'loss': 0.0556, 'learning_rate': 1.586666666666667e-05, 'epoch': 0.02} \r\n",
"{'loss': 0.1277, 'learning_rate': 1.58e-05, 'epoch': 0.03} \r\n",
"{'loss': 1.4146, 'learning_rate': 1.5733333333333334e-05, 'epoch': 0.03} \r\n",
"{'loss': 0.0079, 'learning_rate': 1.5666666666666667e-05, 'epoch': 0.03} \r\n",
"{'loss': 0.0576, 'learning_rate': 1.5600000000000003e-05, 'epoch': 0.03} \r\n",
"{'loss': 0.049, 'learning_rate': 1.5533333333333333e-05, 'epoch': 0.03} \r\n",
"{'loss': 0.4167, 'learning_rate': 1.546666666666667e-05, 'epoch': 0.03} \r\n",
"{'loss': 0.1903, 'learning_rate': 1.54e-05, 'epoch': 0.03} \r\n",
"{'loss': 0.0972, 'learning_rate': 1.5333333333333334e-05, 'epoch': 0.03} \r\n",
"{'loss': 0.056, 'learning_rate': 1.5266666666666667e-05, 'epoch': 0.03} \r\n",
"{'loss': 0.0937, 'learning_rate': 1.5200000000000002e-05, 'epoch': 0.03} \r\n",
"{'loss': 0.0168, 'learning_rate': 1.5133333333333335e-05, 'epoch': 0.03} \r\n",
"{'loss': 0.4977, 'learning_rate': 1.5066666666666668e-05, 'epoch': 0.03} \r\n",
"{'loss': 0.7477, 'learning_rate': 1.5000000000000002e-05, 'epoch': 0.03} \r\n",
"{'loss': 0.0054, 'learning_rate': 1.4933333333333335e-05, 'epoch': 0.03} \r\n",
"{'loss': 0.0664, 'learning_rate': 1.4866666666666668e-05, 'epoch': 0.03} \r\n",
"{'loss': 0.1106, 'learning_rate': 1.48e-05, 'epoch': 0.03} \r\n",
"{'loss': 0.3886, 'learning_rate': 1.4733333333333335e-05, 'epoch': 0.03} \r\n",
"{'loss': 0.0432, 'learning_rate': 1.4666666666666666e-05, 'epoch': 0.03} \r\n",
"{'loss': 0.9208, 'learning_rate': 1.46e-05, 'epoch': 0.03} \r\n",
"{'loss': 0.047, 'learning_rate': 1.4533333333333335e-05, 'epoch': 0.03} \r\n",
"{'loss': 0.184, 'learning_rate': 1.4466666666666668e-05, 'epoch': 0.03} \r\n",
"{'loss': 0.0166, 'learning_rate': 1.4400000000000001e-05, 'epoch': 0.03} \r\n",
"{'loss': 0.0051, 'learning_rate': 1.4333333333333334e-05, 'epoch': 0.03} \r\n",
"{'loss': 0.0221, 'learning_rate': 1.4266666666666668e-05, 'epoch': 0.03} \r\n",
"{'loss': 0.0343, 'learning_rate': 1.4200000000000001e-05, 'epoch': 0.03} \r\n",
"{'loss': 0.0725, 'learning_rate': 1.4133333333333334e-05, 'epoch': 0.04} \r\n",
"{'loss': 0.612, 'learning_rate': 1.4066666666666669e-05, 'epoch': 0.04} \r\n",
"{'loss': 0.0592, 'learning_rate': 1.4e-05, 'epoch': 0.04} \r\n",
"{'loss': 0.3324, 'learning_rate': 1.3933333333333334e-05, 'epoch': 0.04} \r\n",
"{'loss': 0.4796, 'learning_rate': 1.3866666666666669e-05, 'epoch': 0.04} \r\n",
"{'loss': 0.0342, 'learning_rate': 1.38e-05, 'epoch': 0.04} \r\n",
"{'loss': 0.1784, 'learning_rate': 1.3733333333333335e-05, 'epoch': 0.04} \r\n",
"{'loss': 0.007, 'learning_rate': 1.3666666666666667e-05, 'epoch': 0.04} \r\n",
"{'loss': 0.1859, 'learning_rate': 1.3600000000000002e-05, 'epoch': 0.04} \r\n",
"{'loss': 0.1524, 'learning_rate': 1.3533333333333333e-05, 'epoch': 0.04} \r\n",
"{'loss': 0.0107, 'learning_rate': 1.3466666666666668e-05, 'epoch': 0.04} \r\n",
"{'loss': 0.2768, 'learning_rate': 1.3400000000000002e-05, 'epoch': 0.04} \r\n",
"{'loss': 0.0235, 'learning_rate': 1.3333333333333333e-05, 'epoch': 0.04} \r\n",
" 33%|█████████████▋ | 100/300 [02:29<04:59, 1.50s/it][INFO|tokenization_utils_base.py:2428] 2023-11-24 13:46:55,413 >> tokenizer config file saved in output/news-20231124-134412-2e-05/checkpoint-100/tokenizer_config.json\r\n",
"[INFO|tokenization_utils_base.py:2437] 2023-11-24 13:46:55,413 >> Special tokens file saved in output/news-20231124-134412-2e-05/checkpoint-100/special_tokens_map.json\r\n",
"{'loss': 0.0081, 'learning_rate': 1.3266666666666668e-05, 'epoch': 0.04} \r\n",
"{'loss': 0.0147, 'learning_rate': 1.3200000000000002e-05, 'epoch': 0.04} \r\n",
"{'loss': 0.0156, 'learning_rate': 1.3133333333333334e-05, 'epoch': 0.04} \r\n",
"{'loss': 0.0188, 'learning_rate': 1.3066666666666668e-05, 'epoch': 0.04} \r\n",
"{'loss': 0.1504, 'learning_rate': 1.3000000000000001e-05, 'epoch': 0.04} \r\n",
"{'loss': 0.1548, 'learning_rate': 1.2933333333333334e-05, 'epoch': 0.04} \r\n",
"{'loss': 0.1723, 'learning_rate': 1.2866666666666667e-05, 'epoch': 0.04} \r\n",
"{'loss': 0.0258, 'learning_rate': 1.2800000000000001e-05, 'epoch': 0.04} \r\n",
"{'loss': 0.0329, 'learning_rate': 1.2733333333333336e-05, 'epoch': 0.04} \r\n",
"{'loss': 0.0042, 'learning_rate': 1.2666666666666667e-05, 'epoch': 0.04} \r\n",
"{'loss': 0.2594, 'learning_rate': 1.2600000000000001e-05, 'epoch': 0.04} \r\n",
"{'loss': 0.3348, 'learning_rate': 1.2533333333333336e-05, 'epoch': 0.04} \r\n",
"{'loss': 0.1135, 'learning_rate': 1.2466666666666667e-05, 'epoch': 0.05} \r\n",
"{'loss': 0.0225, 'learning_rate': 1.2400000000000002e-05, 'epoch': 0.05} \r\n",
"{'loss': 0.0036, 'learning_rate': 1.2333333333333334e-05, 'epoch': 0.05} \r\n",
"{'loss': 0.0105, 'learning_rate': 1.2266666666666667e-05, 'epoch': 0.05} \r\n",
"{'loss': 0.0565, 'learning_rate': 1.22e-05, 'epoch': 0.05} \r\n",
"{'loss': 0.0286, 'learning_rate': 1.2133333333333335e-05, 'epoch': 0.05} \r\n",
"{'loss': 0.5644, 'learning_rate': 1.206666666666667e-05, 'epoch': 0.05} \r\n",
"{'loss': 0.3605, 'learning_rate': 1.2e-05, 'epoch': 0.05} \r\n",
"{'loss': 0.6441, 'learning_rate': 1.1933333333333335e-05, 'epoch': 0.05} \r\n",
"{'loss': 0.0111, 'learning_rate': 1.186666666666667e-05, 'epoch': 0.05} \r\n",
"{'loss': 0.0098, 'learning_rate': 1.18e-05, 'epoch': 0.05} \r\n",
"{'loss': 0.9559, 'learning_rate': 1.1733333333333335e-05, 'epoch': 0.05} \r\n",
"{'loss': 0.0042, 'learning_rate': 1.1666666666666668e-05, 'epoch': 0.05} \r\n",
"{'loss': 0.4136, 'learning_rate': 1.16e-05, 'epoch': 0.05} \r\n",
"{'loss': 0.5268, 'learning_rate': 1.1533333333333334e-05, 'epoch': 0.05} \r\n",
"{'loss': 0.1325, 'learning_rate': 1.1466666666666668e-05, 'epoch': 0.05} \r\n",
"{'loss': 0.0051, 'learning_rate': 1.14e-05, 'epoch': 0.05} \r\n",
"{'loss': 0.0006, 'learning_rate': 1.1333333333333334e-05, 'epoch': 0.05} \r\n",
"{'loss': 0.007, 'learning_rate': 1.1266666666666668e-05, 'epoch': 0.05} \r\n",
"{'loss': 0.441, 'learning_rate': 1.1200000000000001e-05, 'epoch': 0.05} \r\n",
"{'loss': 0.0102, 'learning_rate': 1.1133333333333334e-05, 'epoch': 0.05} \r\n",
"{'loss': 0.0415, 'learning_rate': 1.1066666666666669e-05, 'epoch': 0.05} \r\n",
"{'loss': 0.0206, 'learning_rate': 1.1000000000000001e-05, 'epoch': 0.05} \r\n",
"{'loss': 0.1697, 'learning_rate': 1.0933333333333334e-05, 'epoch': 0.05} \r\n",
"{'loss': 0.2054, 'learning_rate': 1.0866666666666667e-05, 'epoch': 0.05} \r\n",
"{'loss': 0.005, 'learning_rate': 1.0800000000000002e-05, 'epoch': 0.06} \r\n",
"{'loss': 0.9604, 'learning_rate': 1.0733333333333333e-05, 'epoch': 0.06} \r\n",
"{'loss': 0.0043, 'learning_rate': 1.0666666666666667e-05, 'epoch': 0.06} \r\n",
"{'loss': 0.4815, 'learning_rate': 1.0600000000000002e-05, 'epoch': 0.06} \r\n",
"{'loss': 0.4817, 'learning_rate': 1.0533333333333333e-05, 'epoch': 0.06} \r\n",
"{'loss': 1.0441, 'learning_rate': 1.0466666666666668e-05, 'epoch': 0.06} \r\n",
"{'loss': 1.2198, 'learning_rate': 1.04e-05, 'epoch': 0.06} \r\n",
"{'loss': 0.019, 'learning_rate': 1.0333333333333335e-05, 'epoch': 0.06} \r\n",
"{'loss': 0.0127, 'learning_rate': 1.0266666666666668e-05, 'epoch': 0.06} \r\n",
"{'loss': 0.0606, 'learning_rate': 1.02e-05, 'epoch': 0.06} \r\n",
"{'loss': 0.1299, 'learning_rate': 1.0133333333333335e-05, 'epoch': 0.06} \r\n",
"{'loss': 0.0678, 'learning_rate': 1.0066666666666666e-05, 'epoch': 0.06} \r\n",
"{'loss': 0.0037, 'learning_rate': 1e-05, 'epoch': 0.06} \r\n",
"{'loss': 0.4082, 'learning_rate': 9.933333333333334e-06, 'epoch': 0.06} \r\n",
"{'loss': 0.1784, 'learning_rate': 9.866666666666668e-06, 'epoch': 0.06} \r\n",
"{'loss': 0.0074, 'learning_rate': 9.800000000000001e-06, 'epoch': 0.06} \r\n",
"{'loss': 0.148, 'learning_rate': 9.733333333333334e-06, 'epoch': 0.06} \r\n",
"{'loss': 0.0005, 'learning_rate': 9.666666666666667e-06, 'epoch': 0.06} \r\n",
"{'loss': 0.3039, 'learning_rate': 9.600000000000001e-06, 'epoch': 0.06} \r\n",
"{'loss': 0.0024, 'learning_rate': 9.533333333333334e-06, 'epoch': 0.06} \r\n",
"{'loss': 0.2657, 'learning_rate': 9.466666666666667e-06, 'epoch': 0.06} \r\n",
"{'loss': 0.0117, 'learning_rate': 9.4e-06, 'epoch': 0.06} \r\n",
"{'loss': 0.001, 'learning_rate': 9.333333333333334e-06, 'epoch': 0.06} \r\n",
"{'loss': 0.0045, 'learning_rate': 9.266666666666667e-06, 'epoch': 0.06} \r\n",
"{'loss': 0.0013, 'learning_rate': 9.200000000000002e-06, 'epoch': 0.06} \r\n",
"{'loss': 0.4838, 'learning_rate': 9.133333333333335e-06, 'epoch': 0.07} \r\n",
"{'loss': 0.164, 'learning_rate': 9.066666666666667e-06, 'epoch': 0.07} \r\n",
"{'loss': 0.0084, 'learning_rate': 9e-06, 'epoch': 0.07} \r\n",
"{'loss': 0.0207, 'learning_rate': 8.933333333333333e-06, 'epoch': 0.07} \r\n",
"{'loss': 0.3102, 'learning_rate': 8.866666666666668e-06, 'epoch': 0.07} \r\n",
"{'loss': 0.0723, 'learning_rate': 8.8e-06, 'epoch': 0.07} \r\n",
"{'loss': 0.0009, 'learning_rate': 8.733333333333333e-06, 'epoch': 0.07} \r\n",
"{'loss': 0.1778, 'learning_rate': 8.666666666666668e-06, 'epoch': 0.07} \r\n",
"{'loss': 1.0444, 'learning_rate': 8.6e-06, 'epoch': 0.07} \r\n",
"{'loss': 0.1915, 'learning_rate': 8.533333333333335e-06, 'epoch': 0.07} \r\n",
"{'loss': 0.0191, 'learning_rate': 8.466666666666668e-06, 'epoch': 0.07} \r\n",
"{'loss': 0.0044, 'learning_rate': 8.400000000000001e-06, 'epoch': 0.07} \r\n",
"{'loss': 0.5965, 'learning_rate': 8.333333333333334e-06, 'epoch': 0.07} \r\n",
"{'loss': 0.0201, 'learning_rate': 8.266666666666667e-06, 'epoch': 0.07} \r\n",
"{'loss': 1.1224, 'learning_rate': 8.2e-06, 'epoch': 0.07} \r\n",
"{'loss': 0.0041, 'learning_rate': 8.133333333333334e-06, 'epoch': 0.07} \r\n",
"{'loss': 0.0071, 'learning_rate': 8.066666666666667e-06, 'epoch': 0.07} \r\n",
"{'loss': 0.0885, 'learning_rate': 8.000000000000001e-06, 'epoch': 0.07} \r\n",
"{'loss': 0.394, 'learning_rate': 7.933333333333334e-06, 'epoch': 0.07} \r\n",
"{'loss': 0.0002, 'learning_rate': 7.866666666666667e-06, 'epoch': 0.07} \r\n",
"{'loss': 0.2249, 'learning_rate': 7.800000000000002e-06, 'epoch': 0.07} \r\n",
"{'loss': 0.2116, 'learning_rate': 7.733333333333334e-06, 'epoch': 0.07} \r\n",
"{'loss': 0.0159, 'learning_rate': 7.666666666666667e-06, 'epoch': 0.07} \r\n",
"{'loss': 0.0027, 'learning_rate': 7.600000000000001e-06, 'epoch': 0.07} \r\n",
"{'loss': 0.1615, 'learning_rate': 7.533333333333334e-06, 'epoch': 0.07} \r\n",
"{'loss': 0.0401, 'learning_rate': 7.4666666666666675e-06, 'epoch': 0.08} \r\n",
"{'loss': 0.0089, 'learning_rate': 7.4e-06, 'epoch': 0.08} \r\n",
"{'loss': 0.0712, 'learning_rate': 7.333333333333333e-06, 'epoch': 0.08} \r\n",
"{'loss': 0.5022, 'learning_rate': 7.266666666666668e-06, 'epoch': 0.08} \r\n",
"{'loss': 0.1132, 'learning_rate': 7.2000000000000005e-06, 'epoch': 0.08} \r\n",
"{'loss': 0.3318, 'learning_rate': 7.133333333333334e-06, 'epoch': 0.08} \r\n",
"{'loss': 0.0071, 'learning_rate': 7.066666666666667e-06, 'epoch': 0.08} \r\n",
"{'loss': 0.6094, 'learning_rate': 7e-06, 'epoch': 0.08} \r\n",
"{'loss': 0.0022, 'learning_rate': 6.9333333333333344e-06, 'epoch': 0.08} \r\n",
"{'loss': 0.4046, 'learning_rate': 6.866666666666667e-06, 'epoch': 0.08} \r\n",
"{'loss': 0.3455, 'learning_rate': 6.800000000000001e-06, 'epoch': 0.08} \r\n",
"{'loss': 0.5033, 'learning_rate': 6.733333333333334e-06, 'epoch': 0.08} \r\n",
"{'loss': 0.1898, 'learning_rate': 6.666666666666667e-06, 'epoch': 0.08} \r\n",
" 67%|███████████████████████████▎ | 200/300 [05:00<02:30, 1.51s/it][INFO|tokenization_utils_base.py:2428] 2023-11-24 13:49:25,887 >> tokenizer config file saved in output/news-20231124-134412-2e-05/checkpoint-200/tokenizer_config.json\r\n",
"[INFO|tokenization_utils_base.py:2437] 2023-11-24 13:49:25,887 >> Special tokens file saved in output/news-20231124-134412-2e-05/checkpoint-200/special_tokens_map.json\r\n",
"{'loss': 0.2219, 'learning_rate': 6.600000000000001e-06, 'epoch': 0.08} \r\n",
"{'loss': 0.0261, 'learning_rate': 6.533333333333334e-06, 'epoch': 0.08} \r\n",
"{'loss': 0.0268, 'learning_rate': 6.466666666666667e-06, 'epoch': 0.08} \r\n",
"{'loss': 0.2255, 'learning_rate': 6.4000000000000006e-06, 'epoch': 0.08} \r\n",
"{'loss': 0.0023, 'learning_rate': 6.333333333333333e-06, 'epoch': 0.08} \r\n",
"{'loss': 0.003, 'learning_rate': 6.266666666666668e-06, 'epoch': 0.08} \r\n",
"{'loss': 0.0272, 'learning_rate': 6.200000000000001e-06, 'epoch': 0.08} \r\n",
"{'loss': 0.0585, 'learning_rate': 6.133333333333334e-06, 'epoch': 0.08} \r\n",
"{'loss': 0.0094, 'learning_rate': 6.066666666666667e-06, 'epoch': 0.08} \r\n",
"{'loss': 0.163, 'learning_rate': 6e-06, 'epoch': 0.08} \r\n",
"{'loss': 0.0036, 'learning_rate': 5.933333333333335e-06, 'epoch': 0.08} \r\n",
"{'loss': 0.0781, 'learning_rate': 5.8666666666666675e-06, 'epoch': 0.08} \r\n",
"{'loss': 0.0089, 'learning_rate': 5.8e-06, 'epoch': 0.09} \r\n",
"{'loss': 0.5827, 'learning_rate': 5.733333333333334e-06, 'epoch': 0.09} \r\n",
"{'loss': 0.0053, 'learning_rate': 5.666666666666667e-06, 'epoch': 0.09} \r\n",
"{'loss': 0.0179, 'learning_rate': 5.600000000000001e-06, 'epoch': 0.09} \r\n",
"{'loss': 1.8615, 'learning_rate': 5.533333333333334e-06, 'epoch': 0.09} \r\n",
"{'loss': 0.3423, 'learning_rate': 5.466666666666667e-06, 'epoch': 0.09} \r\n",
"{'loss': 0.5502, 'learning_rate': 5.400000000000001e-06, 'epoch': 0.09} \r\n",
"{'loss': 0.1772, 'learning_rate': 5.333333333333334e-06, 'epoch': 0.09} \r\n",
"{'loss': 0.075, 'learning_rate': 5.2666666666666665e-06, 'epoch': 0.09} \r\n",
"{'loss': 0.0059, 'learning_rate': 5.2e-06, 'epoch': 0.09} \r\n",
"{'loss': 0.0042, 'learning_rate': 5.133333333333334e-06, 'epoch': 0.09} \r\n",
"{'loss': 0.0213, 'learning_rate': 5.0666666666666676e-06, 'epoch': 0.09} \r\n",
"{'loss': 0.0174, 'learning_rate': 5e-06, 'epoch': 0.09} \r\n",
"{'loss': 1.0351, 'learning_rate': 4.933333333333334e-06, 'epoch': 0.09} \r\n",
"{'loss': 0.0026, 'learning_rate': 4.866666666666667e-06, 'epoch': 0.09} \r\n",
"{'loss': 0.0026, 'learning_rate': 4.800000000000001e-06, 'epoch': 0.09} \r\n",
"{'loss': 0.5251, 'learning_rate': 4.7333333333333335e-06, 'epoch': 0.09} \r\n",
"{'loss': 0.2601, 'learning_rate': 4.666666666666667e-06, 'epoch': 0.09} \r\n",
"{'loss': 0.006, 'learning_rate': 4.600000000000001e-06, 'epoch': 0.09} \r\n",
"{'loss': 0.8758, 'learning_rate': 4.533333333333334e-06, 'epoch': 0.09} \r\n",
"{'loss': 0.2939, 'learning_rate': 4.4666666666666665e-06, 'epoch': 0.09} \r\n",
"{'loss': 0.0141, 'learning_rate': 4.4e-06, 'epoch': 0.09} \r\n",
"{'loss': 0.001, 'learning_rate': 4.333333333333334e-06, 'epoch': 0.09} \r\n",
"{'loss': 0.0176, 'learning_rate': 4.266666666666668e-06, 'epoch': 0.09} \r\n",
"{'loss': 0.4934, 'learning_rate': 4.2000000000000004e-06, 'epoch': 0.09} \r\n",
"{'loss': 0.3413, 'learning_rate': 4.133333333333333e-06, 'epoch': 0.1} \r\n",
"{'loss': 0.0059, 'learning_rate': 4.066666666666667e-06, 'epoch': 0.1} \r\n",
"{'loss': 0.0379, 'learning_rate': 4.000000000000001e-06, 'epoch': 0.1} \r\n",
"{'loss': 0.7871, 'learning_rate': 3.9333333333333335e-06, 'epoch': 0.1} \r\n",
"{'loss': 0.0035, 'learning_rate': 3.866666666666667e-06, 'epoch': 0.1} \r\n",
"{'loss': 0.3481, 'learning_rate': 3.8000000000000005e-06, 'epoch': 0.1} \r\n",
"{'loss': 0.1827, 'learning_rate': 3.7333333333333337e-06, 'epoch': 0.1} \r\n",
"{'loss': 0.4669, 'learning_rate': 3.6666666666666666e-06, 'epoch': 0.1} \r\n",
"{'loss': 0.8078, 'learning_rate': 3.6000000000000003e-06, 'epoch': 0.1} \r\n",
"{'loss': 1.1314, 'learning_rate': 3.5333333333333335e-06, 'epoch': 0.1} \r\n",
"{'loss': 0.0124, 'learning_rate': 3.4666666666666672e-06, 'epoch': 0.1} \r\n",
"{'loss': 0.0116, 'learning_rate': 3.4000000000000005e-06, 'epoch': 0.1} \r\n",
"{'loss': 0.2296, 'learning_rate': 3.3333333333333333e-06, 'epoch': 0.1} \r\n",
"{'loss': 0.0037, 'learning_rate': 3.266666666666667e-06, 'epoch': 0.1} \r\n",
"{'loss': 0.0882, 'learning_rate': 3.2000000000000003e-06, 'epoch': 0.1} \r\n",
"{'loss': 0.0073, 'learning_rate': 3.133333333333334e-06, 'epoch': 0.1} \r\n",
"{'loss': 0.013, 'learning_rate': 3.066666666666667e-06, 'epoch': 0.1} \r\n",
"{'loss': 0.0027, 'learning_rate': 3e-06, 'epoch': 0.1} \r\n",
"{'loss': 0.7103, 'learning_rate': 2.9333333333333338e-06, 'epoch': 0.1} \r\n",
"{'loss': 0.0022, 'learning_rate': 2.866666666666667e-06, 'epoch': 0.1} \r\n",
"{'loss': 0.4343, 'learning_rate': 2.8000000000000003e-06, 'epoch': 0.1} \r\n",
"{'loss': 1.1652, 'learning_rate': 2.7333333333333336e-06, 'epoch': 0.1} \r\n",
"{'loss': 0.377, 'learning_rate': 2.666666666666667e-06, 'epoch': 0.1} \r\n",
"{'loss': 0.0025, 'learning_rate': 2.6e-06, 'epoch': 0.1} \r\n",
"{'loss': 0.072, 'learning_rate': 2.5333333333333338e-06, 'epoch': 0.1} \r\n",
"{'loss': 0.0035, 'learning_rate': 2.466666666666667e-06, 'epoch': 0.11} \r\n",
"{'loss': 0.0176, 'learning_rate': 2.4000000000000003e-06, 'epoch': 0.11} \r\n",
"{'loss': 0.0404, 'learning_rate': 2.3333333333333336e-06, 'epoch': 0.11} \r\n",
"{'loss': 0.0243, 'learning_rate': 2.266666666666667e-06, 'epoch': 0.11} \r\n",
"{'loss': 0.0197, 'learning_rate': 2.2e-06, 'epoch': 0.11} \r\n",
"{'loss': 0.0558, 'learning_rate': 2.133333333333334e-06, 'epoch': 0.11} \r\n",
"{'loss': 0.0396, 'learning_rate': 2.0666666666666666e-06, 'epoch': 0.11} \r\n",
"{'loss': 0.0042, 'learning_rate': 2.0000000000000003e-06, 'epoch': 0.11} \r\n",
"{'loss': 0.2281, 'learning_rate': 1.9333333333333336e-06, 'epoch': 0.11} \r\n",
"{'loss': 0.0594, 'learning_rate': 1.8666666666666669e-06, 'epoch': 0.11} \r\n",
"{'loss': 0.0016, 'learning_rate': 1.8000000000000001e-06, 'epoch': 0.11} \r\n",
"{'loss': 0.1489, 'learning_rate': 1.7333333333333336e-06, 'epoch': 0.11} \r\n",
"{'loss': 0.0053, 'learning_rate': 1.6666666666666667e-06, 'epoch': 0.11} \r\n",
"{'loss': 0.0131, 'learning_rate': 1.6000000000000001e-06, 'epoch': 0.11} \r\n",
"{'loss': 0.0517, 'learning_rate': 1.5333333333333334e-06, 'epoch': 0.11} \r\n",
"{'loss': 2.0054, 'learning_rate': 1.4666666666666669e-06, 'epoch': 0.11} \r\n",
"{'loss': 0.863, 'learning_rate': 1.4000000000000001e-06, 'epoch': 0.11} \r\n",
"{'loss': 0.0038, 'learning_rate': 1.3333333333333334e-06, 'epoch': 0.11} \r\n",
"{'loss': 0.0003, 'learning_rate': 1.2666666666666669e-06, 'epoch': 0.11} \r\n",
"{'loss': 0.0553, 'learning_rate': 1.2000000000000002e-06, 'epoch': 0.11} \r\n",
"{'loss': 0.0076, 'learning_rate': 1.1333333333333334e-06, 'epoch': 0.11} \r\n",
"{'loss': 0.0481, 'learning_rate': 1.066666666666667e-06, 'epoch': 0.11} \r\n",
"{'loss': 0.5921, 'learning_rate': 1.0000000000000002e-06, 'epoch': 0.11} \r\n",
"{'loss': 0.3559, 'learning_rate': 9.333333333333334e-07, 'epoch': 0.11} \r\n",
"{'loss': 0.0014, 'learning_rate': 8.666666666666668e-07, 'epoch': 0.11} \r\n",
"{'loss': 0.4355, 'learning_rate': 8.000000000000001e-07, 'epoch': 0.12} \r\n",
"{'loss': 0.0764, 'learning_rate': 7.333333333333334e-07, 'epoch': 0.12} \r\n",
"{'loss': 0.0035, 'learning_rate': 6.666666666666667e-07, 'epoch': 0.12} \r\n",
"{'loss': 0.0042, 'learning_rate': 6.000000000000001e-07, 'epoch': 0.12} \r\n",
"{'loss': 0.0086, 'learning_rate': 5.333333333333335e-07, 'epoch': 0.12} \r\n",
"{'loss': 0.2114, 'learning_rate': 4.666666666666667e-07, 'epoch': 0.12} \r\n",
"{'loss': 0.223, 'learning_rate': 4.0000000000000003e-07, 'epoch': 0.12} \r\n",
"{'loss': 0.0044, 'learning_rate': 3.3333333333333335e-07, 'epoch': 0.12} \r\n",
"{'loss': 0.6593, 'learning_rate': 2.666666666666667e-07, 'epoch': 0.12} \r\n",
"{'loss': 0.1374, 'learning_rate': 2.0000000000000002e-07, 'epoch': 0.12} \r\n",
"{'loss': 0.1037, 'learning_rate': 1.3333333333333336e-07, 'epoch': 0.12} \r\n",
"{'loss': 0.0181, 'learning_rate': 6.666666666666668e-08, 'epoch': 0.12} \r\n",
"{'loss': 0.0783, 'learning_rate': 0.0, 'epoch': 0.12} \r\n",
"100%|█████████████████████████████████████████| 300/300 [07:31<00:00, 1.51s/it][INFO|tokenization_utils_base.py:2428] 2023-11-24 13:51:56,849 >> tokenizer config file saved in output/news-20231124-134412-2e-05/checkpoint-300/tokenizer_config.json\r\n",
"[INFO|tokenization_utils_base.py:2437] 2023-11-24 13:51:56,849 >> Special tokens file saved in output/news-20231124-134412-2e-05/checkpoint-300/special_tokens_map.json\r\n",
"[INFO|trainer.py:1955] 2023-11-24 13:51:56,875 >> \r\n",
"\r\n",
"Training completed. Do not forget to share your model on huggingface.co/models =)\r\n",
"\r\n",
"\r\n",
"{'train_runtime': 451.0866, 'train_samples_per_second': 1.33, 'train_steps_per_second': 0.665, 'train_loss': 0.2702767427762349, 'epoch': 0.12}\r\n",
"100%|█████████████████████████████████████████| 300/300 [07:31<00:00, 1.50s/it]\r\n",
"[INFO|tokenization_utils_base.py:2428] 2023-11-24 13:51:56,887 >> tokenizer config file saved in output/news-20231124-134412-2e-05/tokenizer_config.json\r\n",
"[INFO|tokenization_utils_base.py:2437] 2023-11-24 13:51:56,887 >> Special tokens file saved in output/news-20231124-134412-2e-05/special_tokens_map.json\r\n"
]
}
],
"source": [
"!which python\n",
"import os\n",
"from datetime import datetime\n",
"import random\n",
"\n",
"# 定义变量\n",
"lr = 2e-5\n",
"num_gpus = 1\n",
"lora_rank = 8\n",
"lora_alpha = 32\n",
"lora_dropout = 0.1\n",
"max_source_len = 512\n",
"max_target_len = 128\n",
"dev_batch_size = 1\n",
"grad_accumularion_steps = 2\n",
"max_step = 300\n",
"save_interval = 100\n",
"max_seq_len = 512\n",
"logging_steps=1\n",
"\n",
"run_name = \"news\"\n",
"dataset_path = \"data/toutiao_cat_data_example.jsonl\"\n",
"datestr = datetime.now().strftime(\"%Y%m%d-%H%M%S\")\n",
"output_dir = f\"output/{run_name}-{datestr}-{lr}\"\n",
"master_port = random.randint(10000, 65535)\n",
"\n",
"os.makedirs(output_dir, exist_ok=True)\n",
"# 构建命令\n",
"command = f\"\"\"\n",
"/home/zr/Code/ChatGLM3/venv/bin/torchrun --standalone --nnodes=1 --nproc_per_node={num_gpus} ../finetune_basemodel_demo/finetune.py \\\n",
" --train_format input-output \\\n",
" --train_file {dataset_path} \\\n",
" --lora_rank {lora_rank} \\\n",
" --lora_alpha {lora_alpha} \\\n",
" --lora_dropout {lora_dropout} \\\n",
" --max_seq_length {max_seq_len} \\\n",
" --preprocessing_num_workers 1 \\\n",
" --model_name_or_path {model_path} \\\n",
" --output_dir {output_dir} \\\n",
" --per_device_train_batch_size 1 \\\n",
" --gradient_accumulation_steps 2 \\\n",
" --max_steps {max_step} \\\n",
" --logging_steps {logging_steps} \\\n",
" --save_steps {save_interval} \\\n",
" --learning_rate {lr}\n",
"\"\"\"\n",
"\n",
"# 在 Notebook 中执行命令\n",
"!{command}"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"start_time": "2023-11-24T13:44:12.468751Z",
"end_time": "2023-11-24T13:51:59.534815Z"
}
}
},
{
"cell_type": "markdown",
"source": [
"## 使用微调的模型进行推理预测\n",
"现在,我们已经完成了模型的微调,接下来,我们将使用微调后的模型进行推理。我们使用与微调时相同的提示词,并使用一些没有出现的模型效果来复现推理结果。"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 7,
"outputs": [
{
"data": {
"text/plain": "Loading checkpoint shards: 0%| | 0/7 [00:00<?, ?it/s]",
"application/vnd.jupyter.widget-view+json": {
"version_major": 2,
"version_minor": 0,
"model_id": "21b29455e65c4540886188d1fd5d68aa"
}
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import torch\n",
"from transformers import AutoModel, AutoTokenizer\n",
"import os\n",
"from peft import get_peft_model, LoraConfig, TaskType\n",
"\n",
"# 参数设置\n",
"lora_path = output_dir + \"pytorch_model.bin\"\n",
"\n",
"# 加载分词器和模型\n",
"tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, trust_remote_code=True)\n",
"model = AutoModel.from_pretrained(model_path, load_in_8bit=False, trust_remote_code=True).to(device)\n",
"\n",
"# LoRA 模型配置\n",
"peft_config = LoraConfig(\n",
" task_type=TaskType.CAUSAL_LM, inference_mode=True,\n",
" target_modules=['query_key_value'],\n",
" r=8, lora_alpha=32, lora_dropout=0.1\n",
")\n",
"model = get_peft_model(model, peft_config)\n",
"if os.path.exists(lora_path):\n",
" model.load_state_dict(torch.load(lora_path), strict=False)\n"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"start_time": "2023-11-24T13:51:59.522046Z",
"end_time": "2023-11-24T13:52:12.130299Z"
}
}
},
{
"cell_type": "markdown",
"source": [
"我们使用同样的提示词进行推理。"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 8,
"outputs": [],
"source": [
"inputs = tokenizer(PROMPT, return_tensors=\"pt\").to(device)\n",
"response = model.generate(input_ids=inputs[\"input_ids\"],max_new_tokens=max_new_tokens,temperature=temperature,top_p=top_p,do_sample=True)\n",
"response = response[0, inputs[\"input_ids\"].shape[-1]:]\n",
"response = tokenizer.decode(response, skip_special_tokens=True)"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"start_time": "2023-11-24T13:52:12.132308Z",
"end_time": "2023-11-24T13:52:12.296837Z"
}
}
},
{
"cell_type": "code",
"execution_count": 9,
"outputs": [
{
"data": {
"text/plain": "'news_tech'"
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"response"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"start_time": "2023-11-24T13:52:12.296837Z",
"end_time": "2023-11-24T13:52:12.297910Z"
}
}
},
{
"cell_type": "markdown",
"source": [
"这一次,模型成功给出了理想的答案。我们结束实操训练,删除模型并释放显存。"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 10,
"outputs": [],
"source": [
"del model\n",
"torch.cuda.empty_cache()"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"start_time": "2023-11-24T13:54:07.355796Z",
"end_time": "2023-11-24T13:54:07.507254Z"
}
}
},
{
"cell_type": "markdown",
"source": [
"## 总结\n",
"在本实践手册中,我们让开发者体验了 `ChatGLM3-6B` 模型在经过微调前后在 `新闻标题分类` 任务中表现。\n",
"我们可以发现:\n",
"\n",
"在具有混淆的新闻分类中,原始的模型可能受到了误导,不能有效的进行分类,而经过简单微调后的模型,已经具备了正确分类的能力。\n",
"因此,对于有更高要求的专业分类任务,我们可以使用微调的方式对模型进行简单微调,实现更好的任务完成效果。\n"
],
"metadata": {
"collapsed": false
}
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
# ChatGLM3-6B-base 微调示例
本目录提供 ChatGLM3-6B-base 模型的微调示例,目前仅包含 LoRA 微调。
如果将模型下载到了本地,本文和代码中的 `THUDM/chatglm3-6b-base` 字段均应替换为相应地址以从本地加载模型。
运行示例需要 `python>=3.10`,除基础的 `torch` 依赖外,示例代码运行还需要依赖
```bash
pip install -r requirements.txt
```
## 多轮对话格式
`base` 模型不具备对话能力,仅能够生成单轮回复。如果你需要多轮对话能力,请使用 `Chat` 模型进行微调。
## 数据集要求
格式上,请使用 `alpaca` 数据集转换后的 `jsonl` 格式,每行是一个如下的 JSON 对象:
```json
{"context": "hello", "target": "hi,I am ChatGLM3"}
```
其中,`context`是对话的上文,也就是模型的输入,`target`是对话的下文,也就是模型的输出。
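下面给出一个把若干样本写成上述 `jsonl` 格式的最小示意(其中的文件名与样本内容均为假设值):
```python
import json

# 假设的样本:context 为模型输入,target 为期望的模型输出
samples = [
    {"context": "hello", "target": "hi,I am ChatGLM3"},
    {"context": "请把下面的句子翻译成英文:今天天气很好", "target": "The weather is nice today."},
]

# jsonl 格式:每行写入一个 JSON 对象
with open("my_train_data.jsonl", "w", encoding="utf-8") as f:
    for sample in samples:
        f.write(json.dumps(sample, ensure_ascii=False) + "\n")
```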
## 微调模型
以下脚本提供了微调模型的参考方式。
```bash
./scripts/finetune_lora.sh # 使用Lora微调
```
### 提示
1. 微调代码在开始训练前,会先打印首条训练数据的预处理信息,显示为
```log
Sanity Check >>>>>>>>>>>>>
'[gMASK]': 64790 -> -100
'sop': 64792 -> -100
'<|system|>': 64794 -> -100
'': 30910 -> -100
'\n': 13 -> -100
'Answer': 20115 -> -100
'the': 267 -> -100
'following': 1762 -> -100
...
'know': 683 -> -100
'the': 267 -> -100
'response': 3010 -> -100
'details': 3296 -> -100
'.': 30930 -> -100
'<|assistant|>': 64796 -> -100
'': 30910 -> 30910
'\n': 13 -> 13
'I': 307 -> 307
'need': 720 -> 720
'to': 289 -> 289
'use': 792 -> 792
...
'': 0 -> -100
'': 0 -> -100 (有若干个)
<<<<<<<<<<<<< Sanity Check
```
字样,每行依次表示一个 detokenized string, token_id 和 target_id。可在日志中查看这部分的 `loss_mask` 是否符合预期。若不符合,可能需要调整代码或数据。
2. 参考显存用量
- 按照官方脚本的默认参数运行,每一张显卡占用显存为 `23GB`
3. 若尝试后发现显存不足,可以考虑
- 尝试降低 `DEV_BATCH_SIZE` 并提升 `GRAD_ACCUMULATION_STEPS`
- 尝试降低 `MAX_SEQ_LEN`,但是这可能会影响模型的性能
## 注意事项
+ 基座模型不具备对话能力,仅能够生成单轮回复。如果你希望使用多轮对话模型,使用Chat模型进行微调。
+ 请注意,运行本脚本前,你还需要安装本目录下 `requirements.txt` 中列出的所有依赖。
from dataclasses import dataclass, field
from typing import Optional
@dataclass
class ModelArguments:
"""
Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
"""
model_name_or_path: str = field(
metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
)
lora_checkpoint: str = field(
default=None, metadata={"help": "Path to lora checkpoints"}
)
config_name: Optional[str] = field(
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
)
tokenizer_name: Optional[str] = field(
default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
)
cache_dir: Optional[str] = field(
default=None,
metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
)
use_fast_tokenizer: bool = field(
default=True,
metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
)
model_revision: str = field(
default="main",
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
)
use_auth_token: bool = field(
default=False,
metadata={
"help": (
"Will use the token generated when running `huggingface-cli login` (necessary to use this script "
"with private models)."
)
},
)
resize_position_embeddings: Optional[bool] = field(
default=None,
metadata={
"help": (
"Whether to automatically resize the position embeddings if `max_source_length` exceeds "
"the model's position embeddings."
)
},
)
quantization_bit: Optional[int] = field(
default=None,
metadata={
"help": (
"An optional parameter specifying the number of bits used for quantization. "
"Quantization is a process that reduces the model size by limiting the number of "
"bits that represent each weight in the model. A lower number of bits can reduce "
"the model size and speed up inference, but might also decrease model accuracy. "
"If not set (None), quantization is not applied."
)
},
)
lora_rank: Optional[int] = field(
default=8,
metadata={
"help": (
"balancing between complexity and model flexibility. A higher rank allows more "
"complex adaptations but increases the number of parameters and computational cost."
)
},
)
lora_alpha: Optional[float] = field(
default=32,
metadata={
"help": (
"A higher value results in more significant adjustments, potentially improving adaptation to new tasks or data, "
"but might also risk overfitting. A lower value makes smaller adjustments, possibly maintaining better generalization."
)
}, )
lora_dropout: Optional[float] = field(
default=0.1,
metadata={
"help": (
"during training to prevent the model from overly relying on specific patterns in the training data. "
"Higher dropout rates can improve model generalization but may reduce learning efficiency."
)
},
)
@dataclass
class DataTrainingArguments:
"""
Arguments pertaining to what data we are going to input our model for training and eval.
"""
train_file: Optional[str] = field(
default=None, metadata={"help": "The input training data file (a jsonlines or csv file)."}
)
max_seq_length: Optional[int] = field(
default=2048,
metadata={
"help": (
"The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated."
)
},
)
max_source_length: Optional[int] = field(
default=1024,
metadata={
"help": (
"The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded."
)
},
)
max_target_length: Optional[int] = field(
default=128,
metadata={
"help": (
"The maximum total sequence length for target text after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded."
)
},
)
train_format: str = field(
default=None, metadata={"help": "The format of the training data file (mulit-turn or input-output)"},
)
overwrite_cache: bool = field(
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
)
preprocessing_num_workers: Optional[int] = field(
default=None,
metadata={"help": "The number of processes to use for the preprocessing."},
)
max_seq_length: Optional[int] = field(
default=1024,
metadata={
"help": (
"The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded."
)
},
)
pad_to_max_length: bool = field(
default=False,
metadata={
"help": (
"Whether to pad all samples to model maximum sentence length. "
"If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
"efficient on GPU but very bad for TPU."
)
},
)
max_train_samples: Optional[int] = field(
default=None,
metadata={
"help": (
"For debugging purposes or quicker training, truncate the number of training examples to this "
"value if set."
)
},
)
def __post_init__(self):
extension = self.train_file.split(".")[-1]
assert extension in {"jsonl", "json"}, "`train_file` should be a jsonl or a json file."
assert self.train_format in {"multi-turn", "input-output"}
#!/usr/bin/env python
# coding=utf-8
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Fine-tuning the library models for sequence to sequence.
"""
# You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
# Adapted from
import logging
import os
import sys
import torch
import json
import transformers
from transformers import (
AutoModel,
AutoTokenizer,
DataCollatorForSeq2Seq,
HfArgumentParser,
Seq2SeqTrainingArguments,
set_seed,
)
from trainer import LoRATrainer
from arguments import ModelArguments, DataTrainingArguments
from peft import get_peft_model, LoraConfig, TaskType, prepare_model_for_kbit_training
from preprocess_utils import sanity_check, InputOutputDataset
logger = logging.getLogger(__name__)
class CastOutputToFloat(torch.nn.Module):
def __init__(self, layer):
super().__init__()
self.layer = layer
def forward(self, *args, **kwargs):
return self.layer(*args, **kwargs).float()
def main():
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
# If we pass only one argument to the script and it's the path to a json file,
# let's parse it to get our arguments.
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
# Setup logging
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
handlers=[logging.StreamHandler(sys.stdout)],
)
if training_args.should_log:
# The default of training_args.log_level is passive, so we set log level at info here to have that default.
transformers.utils.logging.set_verbosity_info()
log_level = training_args.get_process_log_level()
logger.setLevel(log_level)
# datasets.utils.logging.set_verbosity(log_level)
transformers.utils.logging.set_verbosity(log_level)
transformers.utils.logging.enable_default_handler()
transformers.utils.logging.enable_explicit_format()
# Log on each process the small summary:
logger.warning(
f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
+ f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
)
logger.info(f"Training/evaluation parameters {training_args}")
# Set seed before initializing model.
set_seed(training_args.seed)
# Load pretrained model and tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)
model = AutoModel.from_pretrained(model_args.model_name_or_path, trust_remote_code=True).half().cuda()
if model_args.quantization_bit is not None:
print(f"Quantized to {model_args.quantization_bit} bit")
model = model.quantize(model_args.quantization_bit)
with open(data_args.train_file, "r", encoding="utf-8") as f:
if data_args.train_file.endswith(".json"):
train_data = json.load(f)
elif data_args.train_file.endswith(".jsonl"):
train_data = [json.loads(line) for line in f]
if data_args.train_format == "input-output":
train_dataset = InputOutputDataset(
train_data,
tokenizer,
data_args.max_source_length,
data_args.max_target_length,
)
else:
raise ValueError(f"Unknown train format: {data_args.train_format}")
print(f"Train dataset size: {len(train_dataset)}")
#if training_args.local_rank < 1:
sanity_check(train_dataset[0]['input_ids'], train_dataset[0]['labels'], tokenizer)
# Apply PEFT configuration
peft_config = LoraConfig(
task_type=TaskType.CAUSAL_LM,
inference_mode=False,
r=model_args.lora_rank,
target_modules=['query_key_value'],
lora_alpha=model_args.lora_alpha,
lora_dropout=model_args.lora_dropout,
)
model = get_peft_model(model, peft_config).to("cuda")
# 确保梯度检查点和模型并行化设置正确
#model.gradient_checkpointing_enable()
model.enable_input_require_grads()
model.is_parallelizable = True
model.model_parallel = True # 可以尝试暂时关闭模型并行化来看是否解决问题
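# 将输出层结果转换为 float32:fp16 训练时让交叉熵损失在全精度下计算,数值更稳定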
model.lm_head = CastOutputToFloat(model.transformer.output_layer)
model.config.use_cache = False
# Data collator
data_collator = DataCollatorForSeq2Seq(
tokenizer,
model=model,
label_pad_token_id=-100,
pad_to_multiple_of=None,
padding=False
)
# Initialize our Trainer
trainer = LoRATrainer(
model=model,
args=training_args,
train_dataset=train_dataset,
tokenizer=tokenizer,
data_collator=data_collator,
)
checkpoint = None
if training_args.resume_from_checkpoint is not None:
checkpoint = training_args.resume_from_checkpoint
# model.gradient_checkpointing_enable()
model.enable_input_require_grads()
trainer.train(resume_from_checkpoint=checkpoint)
trainer.save_model() # Saves the tokenizer too for easy upload
trainer.save_state()
if __name__ == "__main__":
main()
import argparse
import torch
from transformers import AutoConfig, AutoModel, AutoTokenizer
import os
from peft import get_peft_model, LoraConfig, TaskType
# Argument Parser Setup
parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str, default=None,
help="The directory of the model")
parser.add_argument("--tokenizer", type=str, default=None, help="Tokenizer path")
parser.add_argument("--lora-path", type=str, default=None,
help="Path to the LoRA model checkpoint")
parser.add_argument("--device", type=str, default="cuda", help="Device to use for computation")
parser.add_argument("--max-new-tokens", type=int, default=128, help="Maximum new tokens for generation")
parser.add_argument("--lora-alpha", type=float, default=32, help="LoRA alpha")
parser.add_argument("--lora-rank", type=int, default=8, help="LoRA r")
parser.add_argument("--lora-dropout", type=float, default=0.1, help="LoRA dropout")
args = parser.parse_args()
if args.tokenizer is None:
args.tokenizer = args.model
# Model and Tokenizer Configuration
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer, trust_remote_code=True)
model = AutoModel.from_pretrained(args.model, load_in_8bit=False, trust_remote_code=True, device_map="auto").to(
args.device)
# LoRA Model Configuration
peft_config = LoraConfig(
task_type=TaskType.CAUSAL_LM, inference_mode=True,
target_modules=['query_key_value'],
r=args.lora_rank, lora_alpha=args.lora_alpha, lora_dropout=args.lora_dropout
)
model = get_peft_model(model, peft_config)
if os.path.exists(args.lora_path):
model.load_state_dict(torch.load(args.lora_path), strict=False)
# Interactive Prompt
while True:
prompt = input("Prompt: ")
inputs = tokenizer(prompt, return_tensors="pt").to(args.device)
response = model.generate(input_ids=inputs["input_ids"],
max_length=inputs["input_ids"].shape[-1] + args.max_new_tokens)
response = response[0, inputs["input_ids"].shape[-1]:]
print("Response:", tokenizer.decode(response, skip_special_tokens=True))
from transformers import PreTrainedTokenizer
from torch.utils.data import Dataset
from typing import Dict, List
def sanity_check(tokens: List[int], target: List[int], tokenizer: PreTrainedTokenizer):
print("Sanity Check >>>>>>>>>>>>>")
for t, m in zip(tokens, target):
decoded = tokenizer.tokenizer.index_special_tokens[t] \
if t in tokenizer.tokenizer.index_special_tokens \
else tokenizer.decode([t])
print("%20s: %6d -> %6d" % (repr(decoded), t, m))
print("<<<<<<<<<<<<< Sanity Check")
assert len(tokens) == len(target), f"length mismatch: {len(tokens)} vs {len(target)}"
class InputOutputDataset(Dataset):
def __init__(self, data: List[dict], tokenizer: PreTrainedTokenizer, max_source_length: int, max_target_length: int):
super(InputOutputDataset, self).__init__()
self.tokenizer = tokenizer
self.max_source_length = max_source_length
self.max_target_length = max_target_length
self.max_seq_length = max_source_length + max_target_length + 1
self.data = data
def __len__(self):
return len(self.data)
def __getitem__(self, i) -> dict:
data_item = self.data[i]
a_ids = self.tokenizer.encode(text=data_item['context'], add_special_tokens=True, truncation=True,
max_length=self.max_source_length)
b_ids = self.tokenizer.encode(text=data_item['target'], add_special_tokens=False, truncation=True,
max_length=self.max_target_length)
context_length = len(a_ids)
input_ids = a_ids + b_ids + [self.tokenizer.eos_token_id]
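# labels 的上文部分先用 pad_token_id 占位,随后统一映射为 -100,
# 这样只有 target 部分(以及结尾的 EOS)参与 loss 计算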
labels = [self.tokenizer.pad_token_id] * context_length + b_ids + [self.tokenizer.eos_token_id]
pad_len = self.max_seq_length - len(input_ids)
input_ids = input_ids + [self.tokenizer.pad_token_id] * pad_len
labels = labels + [self.tokenizer.pad_token_id] * pad_len
labels = [(l if l != self.tokenizer.pad_token_id else -100) for l in labels]
assert len(input_ids) == len(labels), f"length mismatch: {len(input_ids)} vs {len(labels)}"
return {
"input_ids": input_ids,
"labels": labels
}
tqdm
datasets
fsspec
astunparse
peft
accelerate
sentencepiece
#! /usr/bin/env bash
set -ex
LR=1e-4
NUM_GPUS=4
LORA_RANK=8
LORA_ALPHA=32
LORA_DROPOUT=0.1
MAX_SOURCE_LEN=512
MAX_TARGET_LEN=128
DEV_BATCH_SIZE=1
GRAD_ACCUMULATION_STEPS=1
MAX_STEP=500
SAVE_INTERVAL=50
MAX_SEQ_LEN=512
RUN_NAME=text
BASE_MODEL_PATH=THUDM/chatglm3-6b-base
DATASET_PATH=data/alpaca_data.jsonl
DATESTR=`date +%Y%m%d-%H%M%S`
OUTPUT_DIR=output/${RUN_NAME}-${DATESTR}-${LR}
MASTER_PORT=$(shuf -n 1 -i 10000-65535)
mkdir -p $OUTPUT_DIR
torchrun --standalone --nnodes=1 --nproc_per_node=$NUM_GPUS finetune.py \
--train_format input-output \
--train_file $DATASET_PATH \
--lora_rank $LORA_RANK \
--lora_alpha $LORA_ALPHA \
--lora_dropout $LORA_DROPOUT \
--max_seq_length $MAX_SEQ_LEN \
--preprocessing_num_workers 1 \
--model_name_or_path $BASE_MODEL_PATH \
--output_dir $OUTPUT_DIR \
--per_device_train_batch_size $DEV_BATCH_SIZE \
--gradient_accumulation_steps $GRAD_ACCUMULATION_STEPS \
--max_steps $MAX_STEP \
--logging_steps 1 \
--save_steps $SAVE_INTERVAL \
--learning_rate $LR 2>&1 | tee ${OUTPUT_DIR}/train.log
import argparse
import json
import tqdm
def format_example(example: dict) -> dict:
context = f"Instruction: {example['instruction']}\n"
if example.get("input"):
context += f"Input: {example['input']}\n"
context += "Answer: "
target = example["output"]
return {"context": context, "target": target}
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--data_path", type=str, default="data/alpaca_data.json")
parser.add_argument("--save_path", type=str, default="data/alpaca_data.jsonl")
args = parser.parse_args()
print("args:", args)
with open(args.data_path) as f:
examples = json.load(f)
with open(args.save_path, 'w') as f:
for example in tqdm.tqdm(examples, desc="formatting.."):
f.write(json.dumps(format_example(example), ensure_ascii=False) + '\n')
if __name__ == "__main__":
main()
# coding=utf-8
# Copyright 2020-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
The Trainer class, to easily train a 🤗 Transformers from scratch or finetune it on a new task.
"""
import os
from typing import Optional
from transformers import Trainer
import torch
from transformers.modeling_utils import PreTrainedModel, unwrap_model
from transformers.utils import logging
logger = logging.get_logger(__name__)
WEIGHTS_NAME = "pytorch_model.pt"
TRAINING_ARGS_NAME = "training_args.bin"
class LoRATrainer(Trainer):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def compute_loss(self, model, inputs, return_outputs=False):
return model(**inputs).loss
def save_model(self, output_dir=None, _internal_call=False):
output_dir = output_dir if output_dir is not None else self.args.output_dir
os.makedirs(output_dir, exist_ok=True)
logger.info(f"Saving model checkpoint to {output_dir}")
model_to_save = unwrap_model(self.model)
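# Only persist parameters with requires_grad=True (i.e. the LoRA deltas),
# so the saved checkpoint stays much smaller than the full base model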
saved_params = {
k: v.to("cuda") for k, v in model_to_save.named_parameters() if v.requires_grad
}
torch.save(saved_params, os.path.join(output_dir, WEIGHTS_NAME))
if self.tokenizer is not None:
self.tokenizer.save_pretrained(output_dir)
torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
# ChatGLM3-6B 微调示例
本目录提供 ChatGLM3-6B 模型的微调示例,包括全量微调和 P-Tuning v2。格式上,提供多轮对话微调样例和输入输出格式微调样例。
如果将模型下载到了本地,本文和代码中的 `THUDM/chatglm3-6b` 字段均应替换为相应地址以从本地加载模型。
运行示例需要 `python>=3.10`,除基础的 `torch` 依赖外,示例代码运行还需要依赖
```bash
pip install -r requirements.txt
```
## 多轮对话格式
多轮对话微调示例采用 ChatGLM3 对话格式约定,对不同角色添加不同的 `loss_mask`,从而在一次前向计算中为多轮回复同时计算 `loss`。
### 数据格式和预处理
对于数据文件,样例采用如下格式
如果您仅希望微调模型的对话能力,而非工具能力,您应该按照以下格式整理数据。
```json
[
{
"conversations": [
{
"role": "system",
"content": "<system prompt text>"
},
{
"role": "user",
"content": "<user prompt text>"
},
{
"role": "assistant",
"content": "<assistant response text>"
},
// ... Multi Turn
{
"role": "user",
"content": "<user prompt text>"
},
{
"role": "assistant",
"content": "<assistant response text>"
}
]
}
// ...
]
```
**请注意,这种方法在微调的step较多的情况下会影响到模型的工具调用功能**
如果您希望微调模型的对话和工具能力,您应该按照以下格式整理数据。
```json
[
{
"tools": [
// available tools, format is not restricted
],
"conversations": [
{
"role": "system",
"content": "<system prompt text>"
},
{
"role": "user",
"content": "<user prompt text>"
},
{
"role": "assistant",
"content": "<assistant thought to text>"
},
{
"role": "tool",
"name": "<name of the tool to be called",
"parameters": {
"<parameter_name>": "<parameter_value>"
},
"observation": "<observation>"
// don't have to be string
},
{
"role": "assistant",
"content": "<assistant response to observation>"
},
// ... Multi Turn
{
"role": "user",
"content": "<user prompt text>"
},
{
"role": "assistant",
"content": "<assistant response text>"
}
]
}
// ...
]
```
- 关于工具描述的 system prompt 无需手动插入,预处理时会将 `tools` 字段使用 `json.dumps(..., ensure_ascii=False)` 格式化后插入为首条 system prompt。
- 每种角色可以附带一个 `bool` 类型的 `loss` 字段,表示该角色所预测的内容是否参与 `loss` 计算。若没有该字段,样例实现中默认对 `system`、`user` 不计算 `loss`,其余角色则计算 `loss`。
- `tool` 并不是 ChatGLM3 中的原生角色,这里的 `tool` 在预处理阶段将被自动转化为一个具有工具调用 `metadata` 的 `assistant` 角色(默认计算 `loss`)和一个表示工具返回值的 `observation` 角色(不计算 `loss`),转化方式可参考下方示意。
- 目前暂未实现 `Code interpreter` 的微调任务。
- `system` 角色为可选角色,但若存在 `system` 角色,其必须出现在 `user` 角色之前,且一个完整的对话数据(无论单轮或者多轮对话)只能出现一次 `system` 角色。
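下面是一个最小示意(逻辑取自本目录 `preprocess_utils.py` 中的 `format_function_call`,其中的工具名与参数均为假设值),展示 `tool` 轮次中的参数在预处理时会被渲染成怎样的 Python 调用代码:
```python
import ast
import astunparse

# 与 preprocess_utils.py 中相同的做法:把工具调用参数渲染为一段 Python 调用代码
def format_function_call(function_name: str, parameters: dict) -> str:
    keywords = [
        ast.keyword(arg=name, value=ast.Constant(value))
        for name, value in parameters.items()
    ]
    call = ast.Call(func=ast.Name(id=function_name), args=[], keywords=keywords)
    return astunparse.unparse(call).strip()

# 假设的工具名与参数
print(format_function_call("tool_call", {"city": "Beijing", "unit": "celsius"}))
# 预期输出类似: tool_call(city='Beijing', unit='celsius')
```
预处理时,这段调用代码会被包裹在一个 Python 代码块标记中,作为携带工具名 `metadata` 的 `assistant` 消息写入,其后再追加一条承载工具返回值的 `observation` 消息。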
作为示例,我们使用 ToolAlpaca 数据集来进行微调。首先,克隆 [ToolAlpaca 数据集](https://github.com/tangqiaoyu/ToolAlpaca),并使用
```bash
./scripts/format_tool_alpaca.py --path "ToolAlpaca/data/train_data.json"
```
将数据集处理成上述格式。在这里,我们有意将工具处理成了 `list[str]` 这样的自然语言形式,以观察模型在微调前后对工具定义的理解能力。
### 微调模型
以下脚本提供了微调模型的参考方式。
```bash
./scripts/finetune_ds_multiturn.sh # 全量微调
./scripts/finetune_pt_multiturn.sh # P-Tuning v2 微调
```
### 部署
我们更新了 ChatGLM3 的综合 Demo,使其可以部署微调后的模型 checkpoint。
对于全量微调,可以使用以下方式进行部署
```bash
cd ../composite_demo
MODEL_PATH="path to finetuned model checkpoint" TOKENIZER_PATH="THUDM/chatglm3-6b" streamlit run main.py
```
对于 P-Tuning v2 微调,可以使用以下方式进行部署
```bash
cd ../composite_demo
MODEL_PATH="THUDM/chatglm3-6b" PT_PATH="path to p-tuning checkpoint" streamlit run main.py
```
## 输入输出格式
对于输入-输出格式,样例采用如下输入格式
```json
[
{
"prompt": "<prompt text>",
"response": "<response text>"
}
// ...
]
```
预处理时,不会拼接任何角色标识符。
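如果你想自己构造这种输入输出格式的数据,下面是一个最小示意(文件名与样本内容均为假设值):
```python
import json

# 假设的样本:prompt 为模型输入,response 为期望的模型输出
samples = [
    {"prompt": "类型#裙*颜色#蓝色", "response": "这是一条蓝色的连衣裙,简洁大方。"},
    {"prompt": "类型#上衣*风格#休闲", "response": "这是一件休闲风格的上衣,舒适百搭。"},
]

# 保存为一个 JSON 数组文件;finetune.py 同样支持每行一个对象的 jsonl 文件
with open("my_input_output_data.json", "w", encoding="utf-8") as f:
    json.dump(samples, f, ensure_ascii=False, indent=2)
```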
作为示例,我们使用 AdvertiseGen 数据集来进行微调。从 [Google Drive](https://drive.google.com/file/d/13_vf0xRTQsyneRKdD1bZIr93vBGOczrk/view?usp=sharing) 或者 [Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/b3f119a008264b1cabd1/?dl=1) 下载处理好的 AdvertiseGen 数据集,将解压后的 `AdvertiseGen` 目录放到本目录下。
```bash
./scripts/format_advertise_gen.py --path "AdvertiseGen/train.json"
```
运行以上脚本,将数据集处理成上述格式。
### 微调模型
以下脚本提供了微调模型的参考方式。
```bash
./scripts/finetune_ds.sh # 全量微调
./scripts/finetune_pt.sh # P-Tuning v2 微调
```
### 推理验证
对于输入输出格式的微调,可使用 `inference.py` 进行基本的推理验证。
```bash
python inference.py \
--pt-checkpoint "path to p-tuning checkpoint" \
--model THUDM/chatglm3-6b
```
```bash
python inference.py \
--tokenizer THUDM/chatglm3-6b \
--model "path to finetuned model checkpoint"
```
### 提示
1. 微调代码在开始训练前,会先打印首条训练数据的预处理信息,显示为
```log
Sanity Check >>>>>>>>>>>>>
'[gMASK]': 64790 -> -100
'sop': 64792 -> -100
'<|system|>': 64794 -> -100
'': 30910 -> -100
'\n': 13 -> -100
'Answer': 20115 -> -100
'the': 267 -> -100
'following': 1762 -> -100
...
'know': 683 -> -100
'the': 267 -> -100
'response': 3010 -> -100
'details': 3296 -> -100
'.': 30930 -> -100
'<|assistant|>': 64796 -> -100
'': 30910 -> 30910
'\n': 13 -> 13
'I': 307 -> 307
'need': 720 -> 720
'to': 289 -> 289
'use': 792 -> 792
...
<<<<<<<<<<<<< Sanity Check
```
字样,每行依次表示一个 detokenized string, token_id 和 target_id。可在日志中查看这部分的 `loss_mask` 是否符合预期。若不符合,可能需要调整代码或数据。
2. 参考显存用量
- P-Tuning V2 `PRE_SEQ_LEN=128`, `DEV_BATCH_SIZE=1`, `GRAD_ACCUMULARION_STEPS=16`, `MAX_SEQ_LEN=2048` 配置下约需要 21GB 显存。
- 全量微调时,`./scripts/finetune_ds_multiturn.sh` 中的配置(`MAX_SEQ_LEN=2048`, `DEV_BATCH_SIZE=16`, `GRAD_ACCUMULARION_STEPS=1`)恰好用满 4 * 80GB 显存。
3. 若尝试后发现显存不足,可以考虑
- 尝试降低 `DEV_BATCH_SIZE` 并提升 `GRAD_ACCUMULARION_STEPS`
- 尝试添加 `--quantization_bit 8` 或 `--quantization_bit 4`
- `PRE_SEQ_LEN=128`, `DEV_BATCH_SIZE=1`, `GRAD_ACCUMULARION_STEPS=16`, `MAX_SEQ_LEN=1024` 配置下,`--quantization_bit 8` 约需 12GB 显存,`--quantization_bit 4` 约需 7.6GB 显存。
## 参考文献
```
@inproceedings{liu2022p,
title={P-tuning: Prompt tuning can be comparable to fine-tuning across scales and tasks},
author={Liu, Xiao and Ji, Kaixuan and Fu, Yicheng and Tam, Weng and Du, Zhengxiao and Yang, Zhilin and Tang, Jie},
booktitle={Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics (Volume 2: Short Papers)},
pages={61--68},
year={2022}
}
@misc{tang2023toolalpaca,
title={ToolAlpaca: Generalized Tool Learning for Language Models with 3000 Simulated Cases},
author={Qiaoyu Tang and Ziliang Deng and Hongyu Lin and Xianpei Han and Qiao Liang and Le Sun},
year={2023},
eprint={2306.05301},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
```
from dataclasses import dataclass, field
from typing import Optional
@dataclass
class ModelArguments:
"""
Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
"""
model_name_or_path: str = field(
metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
)
ptuning_checkpoint: str = field(
default=None, metadata={"help": "Path to p-tuning v2 checkpoints"}
)
config_name: Optional[str] = field(
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
)
tokenizer_name: Optional[str] = field(
default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
)
cache_dir: Optional[str] = field(
default=None,
metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
)
use_fast_tokenizer: bool = field(
default=True,
metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
)
model_revision: str = field(
default="main",
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
)
use_auth_token: bool = field(
default=False,
metadata={
"help": (
"Will use the token generated when running `huggingface-cli login` (necessary to use this script "
"with private models)."
)
},
)
resize_position_embeddings: Optional[bool] = field(
default=None,
metadata={
"help": (
"Whether to automatically resize the position embeddings if `max_source_length` exceeds "
"the model's position embeddings."
)
},
)
quantization_bit: Optional[int] = field(
default=None
)
pre_seq_len: Optional[int] = field(
default=None
)
prefix_projection: bool = field(
default=False
)
@dataclass
class DataTrainingArguments:
"""
Arguments pertaining to what data we are going to input our model for training and eval.
"""
train_file: Optional[str] = field(
default=None, metadata={"help": "The input training data file (a jsonlines or csv file)."}
)
max_seq_length: Optional[int] = field(
default=2048,
metadata={
"help": (
"The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated."
)
},
)
max_source_length: Optional[int] = field(
default=1024,
metadata={
"help": (
"The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded."
)
},
)
max_target_length: Optional[int] = field(
default=128,
metadata={
"help": (
"The maximum total sequence length for target text after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded."
)
},
)
train_format: str = field(
default=None, metadata={"help": "The format of the training data file (mulit-turn or input-output)"},
)
overwrite_cache: bool = field(
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
)
preprocessing_num_workers: Optional[int] = field(
default=None,
metadata={"help": "The number of processes to use for the preprocessing."},
)
max_seq_length: Optional[int] = field(
default=1024,
metadata={
"help": (
"The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded."
)
},
)
pad_to_max_length: bool = field(
default=False,
metadata={
"help": (
"Whether to pad all samples to model maximum sentence length. "
"If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
"efficient on GPU but very bad for TPU."
)
},
)
max_train_samples: Optional[int] = field(
default=None,
metadata={
"help": (
"For debugging purposes or quicker training, truncate the number of training examples to this "
"value if set."
)
},
)
def __post_init__(self):
extension = self.train_file.split(".")[-1]
assert extension in {"jsonl", "json"}, "`train_file` should be a jsonl or a json file."
assert self.train_format in {"multi-turn", "input-output"}
{
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu" :"auto",
"gradient_accumulation_steps": "auto",
"zero_allow_untested_optimizer": true,
"fp16": {
"enabled": "auto",
"loss_scale": 0,
"initial_scale_power": 16,
"loss_scale_window": 1000,
"hysteresis": 2,
"min_loss_scale": 1
},
"zero_optimization": {
"stage": 3,
"stage3_gather_16bit_weights_on_model_save": true,
"allgather_partitions": true,
"allgather_bucket_size": 5e8,
"overlap_comm": false,
"reduce_scatter": true,
"reduce_bucket_size": 5e8,
"contiguous_gradients" : true
}
}
#!/usr/bin/env python
# coding=utf-8
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Fine-tuning the library models for sequence to sequence.
"""
# You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
# Adapted from
import logging
import os
import sys
import torch
import json
import transformers
from transformers import (
AutoConfig,
AutoModel,
AutoTokenizer,
DataCollatorForSeq2Seq,
HfArgumentParser,
Seq2SeqTrainingArguments,
set_seed,
)
from trainer import PrefixTrainer
from arguments import ModelArguments, DataTrainingArguments
from preprocess_utils import sanity_check, MultiTurnDataset, InputOutputDataset
logger = logging.getLogger(__name__)
# import os
# os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:32"
def main():
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
# If we pass only one argument to the script and it's the path to a json file,
# let's parse it to get our arguments.
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
# Setup logging
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
handlers=[logging.StreamHandler(sys.stdout)],
)
if training_args.should_log:
# The default of training_args.log_level is passive, so we set log level at info here to have that default.
transformers.utils.logging.set_verbosity_info()
log_level = training_args.get_process_log_level()
logger.setLevel(log_level)
# datasets.utils.logging.set_verbosity(log_level)
transformers.utils.logging.set_verbosity(log_level)
transformers.utils.logging.enable_default_handler()
transformers.utils.logging.enable_explicit_format()
# Log on each process the small summary:
logger.warning(
f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
+ f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
)
logger.info(f"Training/evaluation parameters {training_args}")
# Set seed before initializing model.
set_seed(training_args.seed)
# Load pretrained model and tokenizer
config = AutoConfig.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)
config.pre_seq_len = model_args.pre_seq_len
config.prefix_projection = model_args.prefix_projection
tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)
if model_args.ptuning_checkpoint is not None:
model = AutoModel.from_pretrained(model_args.model_name_or_path, config=config, trust_remote_code=True)
prefix_state_dict = torch.load(os.path.join(model_args.ptuning_checkpoint, "pytorch_model.bin"))
new_prefix_state_dict = {}
for k, v in prefix_state_dict.items():
if k.startswith("transformer.prefix_encoder."):
new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
elif model_args.pre_seq_len is not None:
model = AutoModel.from_pretrained(model_args.model_name_or_path, config=config, trust_remote_code=True)#,empty_init=False)
else:
model = AutoModel.from_pretrained(model_args.model_name_or_path, config=config, trust_remote_code=True,empty_init=False)
if model_args.quantization_bit is not None:
print(f"Quantized to {model_args.quantization_bit} bit")
model = model.quantize(model_args.quantization_bit)
if model_args.pre_seq_len is not None:
# P-tuning v2
model = model.half()
model.transformer.prefix_encoder.float()
else:
# Finetune
model = model.float()
with open(data_args.train_file, "r", encoding="utf-8") as f:
if data_args.train_file.endswith(".json"):
train_data = json.load(f)
elif data_args.train_file.endswith(".jsonl"):
train_data = [json.loads(line) for line in f]
if data_args.train_format == "multi-turn":
train_dataset = MultiTurnDataset(
train_data,
tokenizer,
data_args.max_seq_length,
)
elif data_args.train_format == "input-output":
train_dataset = InputOutputDataset(
train_data,
tokenizer,
data_args.max_source_length,
data_args.max_target_length,
)
else:
raise ValueError(f"Unknown train format: {data_args.train_format}")
if training_args.local_rank < 1:
sanity_check(train_dataset[0]['input_ids'], train_dataset[0]['labels'], tokenizer)
# Data collator
data_collator = DataCollatorForSeq2Seq(
tokenizer,
model=model,
label_pad_token_id=-100,
pad_to_multiple_of=None,
padding=False
)
# Initialize our Trainer
trainer = PrefixTrainer(
model=model,
args=training_args,
train_dataset=train_dataset,
tokenizer=tokenizer,
data_collator=data_collator,
save_changed=model_args.pre_seq_len is not None
)
checkpoint = None
if training_args.resume_from_checkpoint is not None:
checkpoint = training_args.resume_from_checkpoint
model.gradient_checkpointing_enable()
model.enable_input_require_grads()
trainer.train(resume_from_checkpoint=checkpoint)
trainer.save_model() # Saves the tokenizer too for easy upload
trainer.save_state()
if __name__ == "__main__":
main()
import argparse
from transformers import AutoConfig, AutoModel, AutoTokenizer
import torch
import os
parser = argparse.ArgumentParser()
parser.add_argument("--pt-checkpoint", type=str, default=None, help="The checkpoint path")
parser.add_argument("--model", type=str, default=None, help="main model weights")
parser.add_argument("--tokenizer", type=str, default=None, help="main model weights")
parser.add_argument("--pt-pre-seq-len", type=int, default=128, help="The pre-seq-len used in p-tuning")
parser.add_argument("--device", type=str, default="cuda")
parser.add_argument("--max-new-tokens", type=int, default=128)
args = parser.parse_args()
if args.tokenizer is None:
args.tokenizer = args.model
if args.pt_checkpoint:
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer, trust_remote_code=True)
config = AutoConfig.from_pretrained(args.model, trust_remote_code=True, pre_seq_len=args.pt_pre_seq_len)
model = AutoModel.from_pretrained(args.model, config=config, trust_remote_code=True).cuda()
prefix_state_dict = torch.load(os.path.join(args.pt_checkpoint, "pytorch_model.bin"))
new_prefix_state_dict = {}
for k, v in prefix_state_dict.items():
if k.startswith("transformer.prefix_encoder."):
new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
else:
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer, trust_remote_code=True)
model = AutoModel.from_pretrained(args.model, trust_remote_code=True)
model = model.to(args.device)
while True:
prompt = input("Prompt:")
inputs = tokenizer(prompt, return_tensors="pt")
inputs = inputs.to(args.device)
response = model.generate(input_ids=inputs["input_ids"], max_length=inputs["input_ids"].shape[-1] + args.max_new_tokens)
response = response[0, inputs["input_ids"].shape[-1]:]
print("Response:", tokenizer.decode(response, skip_special_tokens=True))
import json
import ast
import astunparse
from transformers import PreTrainedTokenizer
from torch.utils.data import Dataset
from copy import deepcopy
from typing import Dict, List
# text constants
FUNCTION_CALL_NAME = 'tool_call'
FUNCTION_CALL_PREFIX = '```python\n'
FUNCTION_CALL_POSTFIX = '\n```'
TOOL_DEFINITION_PREFIX = 'Answer the following questions as best as you can. You have access to the following tools:\n'
CONVERSATOIN_KEY = 'conversations'
TOOL_DESC_KEY = 'tools'
def format_function_call(function_name: str, parameters: Dict[str, str]):
function_name = ast.Name(id=function_name)
keywords = [
ast.keyword(arg=arg_name, value=ast.Constant(arg_value))
for arg_name, arg_value in parameters.items()
]
func_call = ast.Call(func=function_name, args=[], keywords=keywords)
return astunparse.unparse(func_call).strip()
def format_conversation(item, tokenizer, conversation_key: str, tool_key: str):
conversations = deepcopy(item[conversation_key])
# Note: `loss_mask` here means whether *the prediction* of the token should take loss
tokens, loss_masks = [tokenizer.get_command("[gMASK]"), tokenizer.get_command("sop")], [0, 0]
def _update(_tokens: List[int], value: int = 1):
value = int(value)
tokens.extend(_tokens)
loss_masks.extend([value] * len(_tokens))
# insert system prompt for tools
if tool_key in item:
conversations.insert(0,
{
"role": "system",
"content": TOOL_DEFINITION_PREFIX + json.dumps(item[tool_key], indent=4, ensure_ascii=False)
}
)
for idx, conv in enumerate(conversations):
loss = conv.get("loss", True)
if conv['role'] in {'system', 'user'}:
loss = False
if conv['role'] == 'tool':
# function call python code
value = FUNCTION_CALL_PREFIX + format_function_call(FUNCTION_CALL_NAME, conv["parameters"]) + FUNCTION_CALL_POSTFIX
text = tokenizer.build_single_message("assistant", conv["name"], value)
_update(text, loss)
# function call result
value = conv.get('observation', None)
if not isinstance(value, str):
value = json.dumps(value, ensure_ascii=False)
text = tokenizer.build_single_message("observation", "", value)
_update(text, False)
else:
text = tokenizer.build_single_message(conv['role'], "", conv["content"])
_update(text, loss)
_update([tokenizer.eos_token_id], False)
assert len(tokens) == len(loss_masks), f"length mismatch: {len(tokens)} vs {len(loss_masks)}"
return tokens, loss_masks
def sanity_check(tokens: List[int], target: List[int], tokenizer: PreTrainedTokenizer):
print("Sanity Check >>>>>>>>>>>>>")
for t, m in zip(tokens, target):
decoded = tokenizer.tokenizer.index_special_tokens[t] \
if t in tokenizer.tokenizer.index_special_tokens \
else tokenizer.decode([t])
print("%20s: %6d -> %6d" % (repr(decoded), t, m))
print("<<<<<<<<<<<<< Sanity Check")
assert len(tokens) == len(target), f"length mismatch: {len(tokens)} vs {len(target)}"
class MultiTurnDataset(Dataset):
def __init__(self, data: List[dict], tokenizer: PreTrainedTokenizer, max_seq_length: int):
super(MultiTurnDataset, self).__init__()
self.tokenizer = tokenizer
self.max_seq_length = max_seq_length
self.data = data
def __len__(self):
return len(self.data)
def __getitem__(self, i) -> dict:
data_item = self.data[i]
tokens, loss_masks = format_conversation(data_item, self.tokenizer, CONVERSATOIN_KEY, TOOL_DESC_KEY)
# labels are used inside the model
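# the label at position t is the target of the prediction made at position t-1
# (the causal LM shifts logits/labels internally), hence the mask is shifted right by one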
target_based_loss_mask = [False] + loss_masks[:-1]
labels = [(t if m else -100) for t, m in zip(tokens, target_based_loss_mask)]
tokens = tokens[:self.max_seq_length]
labels = labels[:self.max_seq_length]
tokens += [self.tokenizer.pad_token_id] * (self.max_seq_length - len(tokens))
labels += [-100] * (self.max_seq_length - len(labels))
assert len(tokens) == len(labels), f"length mismatch: {len(tokens)} vs {len(labels)}"
return {
"input_ids": tokens,
"labels": labels
}
class InputOutputDataset(Dataset):
def __init__(self, data: List[dict], tokenizer: PreTrainedTokenizer, max_source_length: int, max_target_length: int):
super(InputOutputDataset, self).__init__()
self.tokenizer = tokenizer
self.max_source_length = max_source_length
self.max_target_length = max_target_length
self.max_seq_length = max_source_length + max_target_length + 1
self.data = data
def __len__(self):
return len(self.data)
def __getitem__(self, i) -> dict:
data_item = self.data[i]
a_ids = self.tokenizer.encode(text=data_item['prompt'], add_special_tokens=True, truncation=True,
max_length=self.max_source_length)
b_ids = self.tokenizer.encode(text=data_item['response'], add_special_tokens=False, truncation=True,
max_length=self.max_target_length)
context_length = len(a_ids)
input_ids = a_ids + b_ids + [self.tokenizer.eos_token_id]
labels = [self.tokenizer.pad_token_id] * context_length + b_ids + [self.tokenizer.eos_token_id]
pad_len = self.max_seq_length - len(input_ids)
input_ids = input_ids + [self.tokenizer.pad_token_id] * pad_len
labels = labels + [self.tokenizer.pad_token_id] * pad_len
labels = [(l if l != self.tokenizer.pad_token_id else -100) for l in labels]
assert len(input_ids) == len(labels), f"length mismatch: {len(input_ids)} vs {len(labels)}"
return {
"input_ids": input_ids,
"labels": labels
}